Unverified Commit ed455174 authored by J-shang, committed by GitHub

[Compression] evaluator - step2 (#4992)

parent a689e619
......@@ -5,3 +5,9 @@ Quickstart
PyTorch </tutorials/hpo_quickstart_pytorch/main>
TensorFlow </tutorials/hpo_quickstart_tensorflow/main>
.. toctree::
:hidden:
/tutorials/hpo_quickstart_pytorch/index
/tutorials/hpo_quickstart_tensorflow/index
Evaluator
=========
.. _compression-torch-evaluator:
TorchEvaluator
--------------
.. autoclass:: nni.compression.pytorch.TorchEvaluator
.. _compression-lightning-evaluator:
LightningEvaluator
------------------
.. autoclass:: nni.compression.pytorch.LightningEvaluator
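A minimal construction sketch for ``TorchEvaluator`` (the positional arguments and the training-function signature shown here are assumptions inferred from how the pruners in this change consume the evaluator; the autoclass signatures above are authoritative):

.. code-block:: python

    import nni
    import torch
    import torch.nn.functional as F
    from nni.compression.pytorch import TorchEvaluator

    model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 2))

    def training_func(model, optimizer, criterion, lr_scheduler=None, max_steps=None, max_epochs=None):
        # A plain training loop on random data; the evaluator/pruner controls the training
        # budget by passing max_steps or max_epochs.
        model.train()
        for _ in range(max_steps or 100):
            data, target = torch.randn(4, 8), torch.randint(0, 2, (4,))
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()

    # The optimizer class must be wrapped by nni.trace so the evaluator can re-create the
    # optimizer after the pruner wraps the model (see the traced_optimizer notes in the
    # pruner docstrings below).
    traced_optimizer = nni.trace(torch.optim.SGD)(model.parameters(), lr=0.01)
    evaluator = TorchEvaluator(training_func, traced_optimizer, F.cross_entropy)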
......@@ -8,5 +8,6 @@ Compression API Reference
Quantizer <quantizer>
Pruning Speedup <pruning_speedup>
Quantization Speedup <quantization_speedup>
Evaluator <evaluator>
Compression Utilities <utils>
Framework Related <framework>
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .utils import LightningEvaluator, TorchEvaluator
......@@ -119,7 +119,8 @@ class Compressor:
Detect all modules that should be compressed, and save the result in `self._modules_to_compress`.
The model will be instrumented and the user should never edit it after calling this method.
"""
assert self.bound_model is not None, 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
err_msg = 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.bound_model is not None, err_msg
if self._modules_to_compress is None:
self._modules_to_compress = []
......@@ -146,7 +147,8 @@ class Compressor:
Optional[Dict]
The retrieved configuration for this layer, if None, this layer should not be compressed.
"""
assert self.config_list is not None, 'No config_list set in this compressor, please use Compressor.reset(model, config_list) to set it.'
err_msg = 'No config_list set in this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.config_list is not None, err_msg
ret = None
for config in self.config_list:
......@@ -240,8 +242,10 @@ class Compressor:
Dict[int, List[str]]
A dict. The key is the config idx in config_list, the value is the module name list. i.e., {1: ['layer.0', 'layer.2']}.
"""
assert self.bound_model is not None, 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.config_list is not None, 'No config_list set in this compressor, please use Compressor.reset(model, config_list) to set it.'
err_msg = 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.bound_model is not None, err_msg
err_msg = 'No config_list set in this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.config_list is not None, err_msg
self._unwrap_model()
module_groups = {}
......@@ -323,6 +327,8 @@ class Compressor:
torch.nn.Module
model with specified modules compressed.
"""
assert self.bound_model is not None, 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.config_list is not None, 'No config_list set in this compressor, please use Compressor.reset(model, config_list) to set it.'
err_msg = 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.bound_model is not None, err_msg
err_msg = 'No config_list set in this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.config_list is not None, err_msg
return self.bound_model
......@@ -43,8 +43,8 @@ class PrunerModuleWrapper(Module):
pruning_target_mask_name = '{}_mask'.format(pruning_target_name)
pruning_target = getattr(self.module, pruning_target_name, None)
if hasattr(self.module, pruning_target_name) and pruning_target is not None:
setattr(self, pruning_target_name, Parameter(torch.empty(pruning_target.shape)))
self.register_buffer(pruning_target_mask_name, torch.ones(pruning_target.shape))
setattr(self, pruning_target_name, Parameter(torch.empty_like(pruning_target)))
self.register_buffer(pruning_target_mask_name, torch.ones_like(pruning_target))
else:
self.register_buffer(pruning_target_mask_name, None)
......@@ -67,11 +67,11 @@ class PrunerModuleWrapper(Module):
The best place to call this function is in `Pruner._unwrap_model()`.
"""
delattr(self.module, 'weight')
self.module.weight = Parameter(torch.empty(self.weight.size()))
self.module.weight = Parameter(torch.empty_like(self.weight))
self.module.weight.data = torch.mul(self.weight, self.weight_mask)
if hasattr(self.module, 'bias') and self.module.bias is not None:
delattr(self.module, 'bias')
self.module.bias = Parameter(torch.empty(self.bias.size()))
self.module.bias = Parameter(torch.empty_like(self.bias))
self.module.bias.data = torch.mul(self.bias, self.bias_mask)
def forward(self, *inputs):
......@@ -130,7 +130,8 @@ class Pruner(Compressor):
Wrap all modules that need to be compressed.
Different from the parent function, this calls `wrapper._weight2buffer()` after replacing the origin module with the wrapper.
"""
assert self.bound_model is not None, 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
err_msg = 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.bound_model is not None, err_msg
if not self.is_wrapped:
for _, wrapper in reversed(list(self.get_modules_wrapper().items())):
......@@ -143,7 +144,8 @@ class Pruner(Compressor):
Unwrap all modules that need to be compressed.
Different from the parent function, this calls `wrapper._weight2parameter()` after replacing the wrapper with the origin module.
"""
assert self.bound_model is not None, 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
err_msg = 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.bound_model is not None, err_msg
if self.is_wrapped:
for wrapper in self.get_modules_wrapper().values():
......@@ -165,8 +167,10 @@ class Pruner(Compressor):
self._unwrap_model()
parameter_name_map = {}
for name, param in self.bound_model.named_parameters():
# If the parameter name in under wrapped module is `xxx.weight` or `xxx.bias`, the name will not change after wrap.
# If the parameter name in under wrapped module is others, the name `xxx.param` will change to `xxx.module.param` after wrap.
# If the parameter name under the wrapped module is `xxx.weight` or `xxx.bias`,
# the name will not change after wrapping.
# For any other parameter, the name `xxx.param` will change to `xxx.module.param` after wrapping.
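# For example (illustrative names): 'conv1.weight' stays 'conv1.weight' after wrapping,
# while a custom parameter 'conv1.alpha' becomes 'conv1.module.alpha'.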
parameter_name_map[name] = wrapped_param_names[id(param)] if id(param) in wrapped_param_names else name
self._wrap_model()
return parameter_name_map
......@@ -183,14 +187,12 @@ class Pruner(Compressor):
The masks dict with format {'op_name': {'weight': mask, 'bias': mask}}.
"""
wrappers = self.get_modules_wrapper()
for name, layer_mask in masks.items():
assert name in wrappers, '{} is not in wrappers of this pruner, can not apply the mask.'.format(name)
if layer_mask.get('weight') is not None:
assert hasattr(wrappers[name], 'weight_mask'), 'There is no attribute weight_mask in wrapper.'
setattr(wrappers[name], 'weight_mask', layer_mask.get('weight'))
if layer_mask.get('bias') is not None:
assert hasattr(wrappers[name], 'bias_mask'), 'There is no attribute bias_mask in wrapper.'
setattr(wrappers[name], 'bias_mask', layer_mask.get('bias'))
for module_name, target_masks in masks.items():
assert module_name in wrappers, '{} is not in wrappers of this pruner, can not apply the mask.'.format(module_name)
for target_name, target_mask in target_masks.items():
assert hasattr(wrappers[module_name], f'{target_name}_mask'), f'There is no attribute {target_name}_mask in wrapper.'
target: Tensor = getattr(self.get_modules_wrapper()[module_name], target_name)
setattr(wrappers[module_name], f'{target_name}_mask', target_mask.to(target.device))
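# Illustrative masks dict (hypothetical module name 'conv1' with 16 output channels):
#   {'conv1': {'weight': torch.ones(16, 3, 3, 3), 'bias': torch.ones(16)}}
# Each mask is moved to the device of the corresponding wrapper target before being assigned.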
def compress(self) -> Tuple[Module, Dict[str, Dict[str, Tensor]]]:
"""
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from copy import deepcopy
from pathlib import Path
from typing import Dict, List, Callable, Optional, cast
from typing import Dict, List, Callable, Optional, cast, overload
import json_tricks
import torch
......@@ -11,12 +13,13 @@ from torch import Tensor
from torch.nn import Module
from nni.algorithms.compression.v2.pytorch.base import Task, TaskResult
from nni.algorithms.compression.v2.pytorch.utils import compute_sparsity, config_list_canonical
from nni.compression.pytorch.utils import count_flops_params
from .iterative_pruner import IterativePruner, PRUNER_DICT
from .tools import TaskGenerator
from .tools.rl_env import DDPG, AMCEnv
from ..utils import LightningEvaluator, TorchEvaluator, compute_sparsity, config_list_canonical
from ..utils.docstring import _EVALUATOR_DOCSTRING
class AMCTaskGenerator(TaskGenerator):
......@@ -41,8 +44,8 @@ class AMCTaskGenerator(TaskGenerator):
ddpg_params
The ddpg agent parameters.
target : str
'flops' or 'params'. Note that the sparsity in other pruners always means the parameters sparse, but in AMC, you can choose flops sparse.
This parameter is used to explain what the sparsity setting in config_list refers to.
'flops' or 'params'. Note that the sparsity in other pruners always refers to parameter sparsity,
but in AMC you can choose FLOPs sparsity. This parameter explains what the sparsity setting in config_list refers to.
"""
def __init__(self, total_episode: int, dummy_input: Tensor, origin_model: Module, origin_config_list: List[Dict],
......@@ -56,7 +59,7 @@ class AMCTaskGenerator(TaskGenerator):
self.config_list_copy = deepcopy(origin_config_list)
super().__init__(origin_model=origin_model, origin_masks=origin_masks, origin_config_list=origin_config_list,
log_dir=log_dir, keep_intermediate_result=keep_intermediate_result)
log_dir=log_dir, keep_intermediate_result=keep_intermediate_result, best_result_mode='maximize')
def init_pending_tasks(self) -> List[Task]:
origin_model = torch.load(self._origin_model_path)
......@@ -82,6 +85,8 @@ class AMCTaskGenerator(TaskGenerator):
return self.generate_tasks(task_result)
def generate_tasks(self, task_result: TaskResult) -> List[Task]:
self.temp_config_list = self.temp_config_list if hasattr(self, 'temp_config_list') else []
# append experience & update agent policy
if self.action is not None:
action, reward, observation, done = self.env.step(self.action, task_result.compact_model)
......@@ -106,7 +111,8 @@ class AMCTaskGenerator(TaskGenerator):
origin_model = torch.load(self._origin_model_path)
compact_model = task_result.compact_model
compact_model_masks = task_result.compact_model_masks
current2origin_sparsity, _, _ = compute_sparsity(origin_model, compact_model, compact_model_masks, self.temp_config_list)
current2origin_sparsity, _, _ = compute_sparsity(origin_model, compact_model, compact_model_masks,
self.temp_config_list) # type: ignore
self._tasks[task_result.task_id].state['current2origin_sparsity'] = current2origin_sparsity
current2origin_sparsity, _, _ = compute_sparsity(origin_model, compact_model, compact_model_masks, self.config_list_copy)
self._tasks[task_result.task_id].state['current_total_sparsity'] = current2origin_sparsity
......@@ -162,7 +168,7 @@ class AMCTaskGenerator(TaskGenerator):
class AMCPruner(IterativePruner):
r"""
__doc__ = r"""
AMC pruner leverages reinforcement learning to provide the model compression policy.
According to the author, this learning-based compression policy outperforms conventional rule-based compression policy by having a higher compression ratio,
better preserving the accuracy and freeing human labor.
......@@ -186,10 +192,11 @@ class AMCPruner(IterativePruner):
- op_names : Operation name to be pruned.
- op_partial_names: Operation partial names to be pruned, will be autocompleted by NNI.
- exclude : Set True then the layers setting by op_types and op_names will be excluded from pruning.
dummy_input : torch.Tensor
`dummy_input` is required for speedup and tracing the model in RL environment.
evaluator : Callable[[Module], float]
Evaluate the pruned model and give a score.
evaluator
``evaluator`` is used to replace the previous ``finetuner``, ``dummy_input`` and old ``evaluator`` API.
{evaluator_docstring}
The old API (``finetuner``, ``dummy_input`` and old ``evaluator``) is still supported and will be deprecated in v3.0.
If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__.
pruning_algorithm : str
Supported pruning algorithm ['l1', 'l2', 'fpgm', 'apoz', 'mean_activation', 'taylorfo'].
This iterative pruner will use the chosen corresponding pruner to prune the model in each iteration.
......@@ -197,8 +204,6 @@ class AMCPruner(IterativePruner):
The log directory used to save the results; you can find the best result under this folder.
keep_intermediate_result : bool
Whether to keep the intermediate results, including the intermediate model and masks during each iteration.
finetuner : Optional[Callable[[Module], None]]
The finetuner handled all finetune logic, use a pytorch module as input, will be called in each iteration.
ddpg_params : Dict
Configuration dict to configure the DDPG agent, any key unset will be set to default implicitly.
- hidden1: hidden num of first fully connect layer. Default: 300
......@@ -223,23 +228,42 @@ class AMCPruner(IterativePruner):
'flops' or 'params'. Note that the sparsity in other pruners always refers to parameter sparsity, but in AMC you can choose FLOPs sparsity.
This parameter explains what the sparsity setting in config_list refers to.
Examples
--------
>>> from nni.compression.pytorch.pruning import AMCPruner
>>> config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.5, 'max_sparsity_per_layer': 0.8}]
>>> dummy_input = torch.rand(...).to(device)
>>> evaluator = ...
>>> finetuner = ...
>>> pruner = AMCPruner(400, model, config_list, dummy_input, evaluator, finetuner=finetuner)
>>> pruner.compress()
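With the new evaluator API added in this change, the pruner can instead be constructed from a single ``LightningEvaluator`` or ``TorchEvaluator`` that carries the training loop, traced optimizer, criterion and dummy input (a sketch; see the evaluator reference page for construction details):
>>> from nni.compression.pytorch import TorchEvaluator
>>> evaluator = TorchEvaluator(...)
>>> pruner = AMCPruner(400, model, config_list, evaluator)
>>> pruner.compress()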
Notes
-----
The full script can be found :githublink:`here <examples/model_compress/pruning/amc_pruning_torch.py>`.
"""
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
@overload
def __init__(self, total_episode: int, model: Module, config_list: List[Dict], evaluator: LightningEvaluator | TorchEvaluator,
pruning_algorithm: str = 'l1', log_dir: str = '.', keep_intermediate_result: bool = False,
ddpg_params: dict = {}, pruning_params: dict = {}, target: str = 'flops'):
...
@overload
def __init__(self, total_episode: int, model: Module, config_list: List[Dict], dummy_input: Tensor,
evaluator: Callable[[Module], float], pruning_algorithm: str = 'l1', log_dir: str = '.',
keep_intermediate_result: bool = False, finetuner: Optional[Callable[[Module], None]] = None,
ddpg_params: dict = {}, pruning_params: dict = {}, target: str = 'flops'):
...
def __init__(self, total_episode: int, model: Module, config_list: List[Dict], *args, **kwargs):
new_api = ['evaluator', 'pruning_algorithm', 'log_dir', 'keep_intermediate_result', 'ddpg_params', 'pruning_params', 'target']
new_init_kwargs = {'pruning_algorithm': 'l1', 'log_dir': '.', 'keep_intermediate_result': False,
'ddpg_params': {}, 'pruning_params': {}, 'target': 'flops'}
old_api = ['dummy_input', 'evaluator', 'pruning_algorithm', 'log_dir', 'keep_intermediate_result', 'finetuner', 'ddpg_params',
'pruning_params', 'target']
old_init_kwargs = {'pruning_algorithm': 'l1', 'log_dir': '.', 'keep_intermediate_result': False, 'finetuner': None,
'ddpg_params': {}, 'pruning_params': {}, 'target': 'flops'}
init_kwargs = self._init_evaluator(model, new_api, new_init_kwargs, old_api, old_init_kwargs, args, kwargs)
pruning_algorithm = init_kwargs['pruning_algorithm']
log_dir = init_kwargs['log_dir']
keep_intermediate_result = init_kwargs['keep_intermediate_result']
ddpg_params = init_kwargs['ddpg_params']
pruning_params = init_kwargs['pruning_params']
target = init_kwargs['target']
dummy_input = self.dummy_input if not self.using_evaluator else self.evaluator.get_dummy_input()
assert pruning_algorithm in ['l1', 'l2', 'fpgm', 'apoz', 'mean_activation', 'taylorfo'], \
"Only support pruning_algorithm in ['l1', 'l2', 'fpgm', 'apoz', 'mean_activation', 'taylorfo']"
task_generator = AMCTaskGenerator(total_episode=total_episode,
......@@ -251,5 +275,9 @@ class AMCPruner(IterativePruner):
ddpg_params=ddpg_params,
target=target)
pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
super().__init__(pruner, task_generator, finetuner=finetuner, speedup=True, dummy_input=dummy_input,
evaluator=evaluator, reset_weight=False)
if self.using_evaluator:
super().__init__(pruner, task_generator, evaluator=self.evaluator, speedup=True, reset_weight=False)
else:
super().__init__(pruner, task_generator, finetuner=self.finetuner, speedup=True, dummy_input=self.dummy_input,
evaluator=self._evaluator, reset_weight=False) # type: ignore
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import logging
from pathlib import Path
from typing import Dict, List, Callable, Optional
from typing import Dict, List, Callable, Optional, overload
from torch import Tensor
from torch.nn import Module
from nni.algorithms.compression.v2.pytorch.utils import OptimizerConstructHelper
from .basic_pruner import ADMMPruner
from .iterative_pruner import IterativePruner, SimulatedAnnealingPruner
from .tools import LotteryTicketTaskGenerator
from ..utils import LightningEvaluator, TorchEvaluator, OptimizerConstructHelper
from ..utils.docstring import _EVALUATOR_DOCSTRING
_logger = logging.getLogger(__name__)
......@@ -21,10 +23,7 @@ class AutoCompressTaskGenerator(LotteryTicketTaskGenerator):
def __init__(self, total_iteration: int, origin_model: Module, origin_config_list: List[Dict],
origin_masks: Dict[str, Dict[str, Tensor]] = {}, sa_params: Dict = {}, log_dir: str = '.',
keep_intermediate_result: bool = False):
self.iterative_pruner = SimulatedAnnealingPruner(model=None,
config_list=None,
log_dir=Path(log_dir, 'SA'),
**sa_params)
self._sa_params = sa_params
super().__init__(total_iteration=total_iteration,
origin_model=origin_model,
origin_config_list=origin_config_list,
......@@ -36,12 +35,20 @@ class AutoCompressTaskGenerator(LotteryTicketTaskGenerator):
# TODO: replace with validation here
for config in config_list:
if 'sparsity' in config or 'sparsity_per_layer' in config:
_logger.warning('Only `total_sparsity` can be differentially allocated sparse ratio to each layer, `sparsity` or `sparsity_per_layer` will allocate fixed sparse ratio to layers. Make sure you know what this will lead to, otherwise please use `total_sparsity`.')
warn_msg = 'Only `total_sparsity` can be differentially allocated sparse ratio to each layer, ' + \
'`sparsity` or `sparsity_per_layer` will allocate fixed sparse ratio to layers. ' + \
'Make sure you know what this will lead to, otherwise please use `total_sparsity`.'
_logger.warning(warn_msg)
return super().reset(model, config_list, masks)
def _iterative_pruner_reset(self, model: Module, config_list: List[Dict] = [], masks: Dict[str, Dict[str, Tensor]] = {}):
self.iterative_pruner.task_generator._log_dir = Path(self._log_dir_root, 'SA')
self.iterative_pruner.reset(model, config_list=config_list, masks=masks)
if not hasattr(self, 'iterative_pruner'):
self.iterative_pruner = SimulatedAnnealingPruner(model=model,
config_list=config_list,
log_dir=Path(self._log_dir_root, 'SA'),
**self._sa_params)
else:
self.iterative_pruner.reset(model, config_list=config_list, masks=masks)
def allocate_sparsity(self, new_config_list: List[Dict], model: Module, masks: Dict[str, Dict[str, Tensor]]):
self._iterative_pruner_reset(model, new_config_list, masks)
......@@ -53,8 +60,9 @@ class AutoCompressTaskGenerator(LotteryTicketTaskGenerator):
class AutoCompressPruner(IterativePruner):
r"""
__doc__ = r"""
For a total iteration number :math:`N`, AutoCompressPruner prunes the model that survives the previous iteration with a fixed sparsity ratio (e.g., :math:`1-{(1-0.8)}^{(1/N)}`) to achieve the overall sparsity (e.g., :math:`0.8`):
""" + r"""
.. code-block:: bash
......@@ -65,35 +73,27 @@ class AutoCompressPruner(IterativePruner):
Parameters
----------
model : Module
model
The origin unwrapped pytorch model to be pruned.
config_list : List[Dict]
config_list
The origin config list provided by the user.
total_iteration : int
total_iteration
The total iteration number.
evaluator : Callable[[Module], float]
Evaluate the pruned model and give a score.
admm_params : Dict
admm_params
The parameters passed to the ADMMPruner.
- trainer : Callable[[Module, Optimizer, Callable].
A callable function used to train model or just inference. Take model, optimizer, criterion as input.
The model will be trained or inferenced `training_epochs` epochs.
- traced_optimizer : nni.common.serializer.Traceable(torch.optim.Optimizer)
The traced optimizer instance which the optimizer class is wrapped by nni.trace.
E.g. ``traced_optimizer = nni.trace(torch.nn.Adam)(model.parameters())``.
- criterion : Callable[[Tensor, Tensor], Tensor].
The criterion function used in trainer. Take model output and target value as input, and return the loss.
- evaluator : LightningEvaluator or TorchEvaluator.
The same with the evaluator of AutoCompressPruner input parameter.
- iterations : int.
The total iteration number in admm pruning algorithm.
- training_epochs : int.
The epoch number for training model in each iteration.
sa_params : Dict
sa_params
The parameters passed to the SimulatedAnnealingPruner.
- evaluator : Callable[[Module], float]. Required.
Evaluate the pruned model and give a score.
- evaluator : LightningEvaluator or TorchEvaluator.
The same with the evaluator of AutoCompressPruner input parameter.
- start_temperature : float. Default: `100`.
Start temperature of the simulated annealing process.
- stop_temperature : float. Default: `20`.
......@@ -104,54 +104,50 @@ class AutoCompressPruner(IterativePruner):
Initial perturbation magnitude to the sparsities. The magnitude decreases with current temperature.
- pruning_algorithm : str. Default: `'level'`.
Supported pruning algorithm ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz', 'mean_activation', 'taylorfo', 'admm'].
- pruning_params : Dict. Default: `{}`.
- pruning_params : Dict. Default: dict().
If the chosen pruning_algorithm has extra parameters, put them as a dict to pass in.
log_dir : str
log_dir
The log directory used to save the result, you can find the best result under this folder.
keep_intermediate_result : bool
keep_intermediate_result
Whether to keep the intermediate results, including the intermediate model and masks during each iteration.
finetuner : Optional[Callable[[Module], None]]
The finetuner handles all finetune logic, takes a pytorch module as input.
It will be called at the end of each iteration, usually for neutralizing the accuracy loss brought by the pruning in this iteration.
speedup : bool
evaluator
``evaluator`` is used to replace the previous ``finetuner``, ``dummy_input`` and old ``evaluator`` API.
{evaluator_docstring}
The old API (``finetuner``, ``dummy_input`` and old ``evaluator``) is still supported and will be deprecated in v3.0.
If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__.
speedup
If set True, speedup the model at the end of each iteration to make the pruned model compact.
dummy_input : Optional[torch.Tensor]
If `speedup` is True, `dummy_input` is required for tracing the model in speedup.
Examples
--------
>>> import nni
>>> from nni.compression.pytorch.pruning import AutoCompressPruner
>>> model = ...
>>> config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
>>> # make sure you have used nni.trace to wrap the optimizer class before initialize
>>> traced_optimizer = nni.trace(torch.optim.Adam)(model.parameters())
>>> trainer = ...
>>> criterion = ...
>>> evaluator = ...
>>> finetuner = ...
>>> admm_params = {
>>> 'trainer': trainer,
>>> 'traced_optimizer': traced_optimizer,
>>> 'criterion': criterion,
>>> 'iterations': 10,
>>> 'training_epochs': 1
>>> }
>>> sa_params = {
>>> 'evaluator': evaluator
>>> }
>>> pruner = AutoCompressPruner(model, config_list, 10, admm_params, sa_params, finetuner=finetuner)
>>> pruner.compress()
>>> _, model, masks, _, _ = pruner.get_best_result()
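With the new evaluator API, ``admm_params`` and ``sa_params`` take the shared evaluator instead of ``trainer``/``traced_optimizer``/``criterion``, and ``finetuner``/``dummy_input`` are no longer needed (a sketch):
>>> evaluator = ...  # a LightningEvaluator or TorchEvaluator
>>> admm_params = {'evaluator': evaluator, 'iterations': 10, 'training_epochs': 1}
>>> sa_params = {'evaluator': evaluator}
>>> pruner = AutoCompressPruner(model, config_list, 10, admm_params, sa_params, evaluator=evaluator)
>>> pruner.compress()
>>> _, model, masks, _, _ = pruner.get_best_result()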
Notes
-----
The full script can be found :githublink:`here <examples/model_compress/pruning/auto_compress_pruner.py>`.
"""
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
@overload
def __init__(self, model: Module, config_list: List[Dict], total_iteration: int, admm_params: Dict,
sa_params: Dict, log_dir: str = '.', keep_intermediate_result: bool = False,
evaluator: LightningEvaluator | TorchEvaluator | None = None, speedup: bool = False):
...
@overload
def __init__(self, model: Module, config_list: List[Dict], total_iteration: int, admm_params: Dict,
sa_params: Dict, log_dir: str = '.', keep_intermediate_result: bool = False,
finetuner: Optional[Callable[[Module], None]] = None, speedup: bool = False,
dummy_input: Optional[Tensor] = None, evaluator: Optional[Callable[[Module], float]] = None):
...
def __init__(self, model: Module, config_list: List[Dict], total_iteration: int, admm_params: Dict,
sa_params: Dict, log_dir: str = '.', keep_intermediate_result: bool = False,
*args, **kwargs):
new_api = ['evaluator', 'speedup']
new_init_kwargs = {'evaluator': None, 'speedup': False}
old_api = ['finetuner', 'speedup', 'dummy_input', 'evaluator']
old_init_kwargs = {'finetuner': None, 'evaluator': None, 'dummy_input': None, 'speedup': False}
init_kwargs = self._init_evaluator(model, new_api, new_init_kwargs, old_api, old_init_kwargs, args, kwargs)
speedup = init_kwargs['speedup']
task_generator = AutoCompressTaskGenerator(total_iteration=total_iteration,
origin_model=model,
origin_config_list=config_list,
......@@ -175,6 +171,10 @@ class AutoCompressPruner(IterativePruner):
else:
admm_params['granularity'] = 'fine-grained'
pruner = ADMMPruner(None, None, **admm_params)
super().__init__(pruner, task_generator, finetuner=finetuner, speedup=speedup, dummy_input=dummy_input,
evaluator=evaluator, reset_weight=False)
pruner = ADMMPruner(None, None, **admm_params) # type: ignore
if self.using_evaluator:
super().__init__(pruner, task_generator, evaluator=self.evaluator, speedup=speedup, reset_weight=False)
else:
super().__init__(pruner, task_generator, finetuner=self.finetuner, speedup=speedup, dummy_input=self.dummy_input,
evaluator=self._evaluator, reset_weight=False) # type: ignore
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from copy import deepcopy
import functools
import logging
from typing import List, Dict, Tuple, Callable, Optional
from typing import List, Dict, Tuple, Callable, Optional, overload
from schema import And, Or, Optional as SchemaOptional, SchemaError
import torch
......@@ -17,7 +20,13 @@ from ..base import Pruner
from .tools import (
DataCollector,
HookCollectorInfo,
WeightDataCollector,
TargetDataCollector,
EvaluatorBasedTargetDataCollector,
EvaluatorBasedHookDataCollector
)
# TODO: remove in nni v3.0.
from .tools import (
WeightTrainerBasedDataCollector,
SingleHookTrainerBasedDataCollector
)
......@@ -25,7 +34,7 @@ from .tools import (
from .tools import (
MetricsCalculator,
NormMetricsCalculator,
MultiDataNormMetricsCalculator,
HookDataNormMetricsCalculator,
DistMetricsCalculator,
APoZRankMetricsCalculator,
MeanRankMetricsCalculator
......@@ -39,7 +48,19 @@ from .tools import (
DependencyAwareAllocator
)
from ..utils import CompressorSchema, config_list_canonical, OptimizerConstructHelper, Scaling
from ..utils import (
CompressorSchema,
OptimizerConstructHelper,
Scaling,
Evaluator,
LightningEvaluator,
TorchEvaluator,
ForwardHook,
TensorHook,
config_list_canonical
)
from ..utils.docstring import _EVALUATOR_DOCSTRING
_logger = logging.getLogger(__name__)
......@@ -77,12 +98,9 @@ INTERNAL_SCHEMA = {
class BasicPruner(Pruner):
def __init__(self, model: Optional[Module], config_list: Optional[List[Dict]]):
self.data_collector: Optional[DataCollector] = None
self.metrics_calculator: Optional[MetricsCalculator] = None
self.sparsity_allocator: Optional[SparsityAllocator] = None
super().__init__(model, config_list)
data_collector: DataCollector
metrics_calculator: MetricsCalculator
sparsity_allocator: SparsityAllocator
def validate_config(self, model: Module, config_list: List[Dict]):
self._validate_config_before_canonical(model, config_list)
......@@ -114,7 +132,8 @@ class BasicPruner(Pruner):
Tuple[Module, Dict]
Return the wrapped model and mask.
"""
assert self.bound_model is not None and self.config_list is not None, 'Model and/or config_list are not set in this pruner, please set them by reset() before compress().'
err_msg = 'Model and/or config_list are not set in this pruner, please set them by reset() before compress().'
assert self.bound_model is not None and self.config_list is not None, err_msg
assert self.data_collector is not None and self.metrics_calculator is not None and self.sparsity_allocator is not None
data = self.data_collector.collect()
_logger.debug('Collected Data:\n%s', data)
......@@ -126,6 +145,67 @@ class BasicPruner(Pruner):
return self.bound_model, masks
_LEGACY_TRAINER = Callable[[Module, Optimizer, Callable], None]
_LEGACY_CRITERION = Callable[[Tensor, Tensor], Tensor]
# TODO: remove in nni v3.0.
class EvaluatorBasedPruner(BasicPruner):
evaluator: LightningEvaluator | TorchEvaluator
using_evaluator: bool
trainer: _LEGACY_TRAINER
traced_optimizer: Optimizer
criterion: _LEGACY_CRITERION
def _init_evaluator(self, model: Module, new_api: List[str], old_api: List[str], init_kwargs: Dict, args: Tuple,
kwargs: Dict) -> Dict:
# For the fake __init__ overload: parse args and kwargs, initialize evaluator or [trainer, traced_optimizer, criterion],
# and return the remaining arguments.
if (len(args) > 0 and isinstance(args[0], Evaluator)) or (len(args) == 0 and isinstance(kwargs.get('evaluator', None), Evaluator)):
init_kwargs = self._parse_args(new_api, args, kwargs, init_kwargs)
self.evaluator: LightningEvaluator | TorchEvaluator = init_kwargs.pop('evaluator')
if not self.evaluator._initialization_complete:
self.evaluator._init_optimizer_helpers(model) # type: ignore
self.using_evaluator = True
else:
init_kwargs = self._parse_args(old_api, args, kwargs, init_kwargs)
self.trainer: _LEGACY_TRAINER = init_kwargs.pop('trainer')
traced_optimizer: Optimizer | OptimizerConstructHelper = init_kwargs.pop('traced_optimizer')
self.criterion: _LEGACY_CRITERION = init_kwargs.pop('criterion')
if isinstance(traced_optimizer, OptimizerConstructHelper):
self.optimizer_helper = traced_optimizer
else:
self.optimizer_helper = OptimizerConstructHelper.from_trace(model, traced_optimizer)
self.using_evaluator = False
warn_msg = f"The old API ...{','.join(old_api)} will be deprecated after NNI v3.0, " + \
"please using the new one ...{','.join(new_api)}"
_logger.warning(warn_msg)
return init_kwargs
def _parse_args(self, arg_names: List, args: Tuple, kwargs: Dict, def_kwargs: Dict) -> Dict:
merged_kwargs = {arg_names[idx]: arg for idx, arg in enumerate(args)}
for key, value in kwargs.items():
if key in merged_kwargs:
raise TypeError(f"{self.__class__.__name__}.__init__() got multiple values for argument '{key}'")
merged_kwargs[key] = value
for key, value in def_kwargs.items():
if key not in merged_kwargs:
merged_kwargs[key] = value
diff = set(arg_names).difference(merged_kwargs.keys())
if diff:
raise TypeError(f"{self.__class__.__name__}.__init__() missing {len(diff)} required positional argument: {diff}")
diff = set(merged_kwargs.keys()).difference(arg_names)
if diff:
raise TypeError(f"{self.__class__.__name__}.__init__() got {len(diff)} unexpected keyword argument: {diff}")
return merged_kwargs
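# Example of the merge (hypothetical call): for SlimPruner's new API,
#   _parse_args(['evaluator', 'training_epochs', 'scale', 'mode'],
#               args=(evaluator, 1), kwargs={}, def_kwargs={'scale': 0.0001, 'mode': 'global'})
# returns {'evaluator': evaluator, 'training_epochs': 1, 'scale': 0.0001, 'mode': 'global'}.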
def compress(self) -> Tuple[Module, Dict]:
result = super().compress()
if self.using_evaluator:
self.evaluator.unbind_model()
return result
class LevelPruner(BasicPruner):
r"""
This is a basic pruner, which some papers call magnitude pruning or fine-grained pruning.
......@@ -133,9 +213,9 @@ class LevelPruner(BasicPruner):
Parameters
----------
model : torch.nn.Module
model
Model to be pruned.
config_list : List[Dict]
config_list
Supported keys:
- sparsity : This is to specify the sparsity for each layer in this config to be compressed.
- sparsity_per_layer : Equals to sparsity.
......@@ -143,7 +223,7 @@ class LevelPruner(BasicPruner):
- op_names : Operation names to be pruned.
- op_partial_names: Operation partial names to be pruned, will be autocompleted by NNI.
- exclude : Set True then the layers setting by op_types and op_names will be excluded from pruning.
mode : str
mode
'normal' or 'balance'.
If setting 'normal' mode, the target tensor will be pruned in a fine-grained way.
If setting 'balance' mode, a special sparse pattern will be chosen by the pruner. Take linear
......@@ -152,7 +232,7 @@ class LevelPruner(BasicPruner):
pattern have more chance to achieve better trade-off between model performance and hardware
acceleration. Please refer to the related paper for further information: `Balanced Sparsity for
Efficient DNN Inference on GPU <https://arxiv.org/pdf/1811.00206.pdf>`__.
balance_gran : list
balance_gran
``balance_gran`` is for the special balanced sparsity pattern. The default value is None, which means pruning
without balance awareness, namely normal fine-grained pruning.
If a list of int is passed, LevelPruner will prune the model in the granularity of a multi-dimension block.
......@@ -195,7 +275,8 @@ class LevelPruner(BasicPruner):
>>> pruner = LevelPruner(model, config_list)
>>> masked_model, masks = pruner.compress()
For detailed example please refer to :githublink:`examples/model_compress/pruning/level_pruning_torch.py <examples/model_compress/pruning/level_pruning_torch.py>`
For detailed example please refer to
:githublink:`examples/model_compress/pruning/level_pruning_torch.py <examples/model_compress/pruning/level_pruning_torch.py>`
"""
def __init__(self, model: Module, config_list: List[Dict], mode: str = "normal", balance_gran: Optional[List] = None):
......@@ -209,13 +290,13 @@ class LevelPruner(BasicPruner):
schema.validate(config_list)
def reset_tools(self):
if self.data_collector is None:
self.data_collector = WeightDataCollector(self)
if not hasattr(self, 'data_collector'):
self.data_collector = TargetDataCollector(self)
else:
self.data_collector.reset()
if self.metrics_calculator is None:
if not hasattr(self, 'metrics_calculator'):
self.metrics_calculator = NormMetricsCalculator()
if self.sparsity_allocator is None:
if not hasattr(self, 'sparsity_allocator'):
if self.mode == "normal":
self.sparsity_allocator = NormalSparsityAllocator(self)
elif self.mode == "balance":
......@@ -228,9 +309,9 @@ class NormPruner(BasicPruner):
"""
Parameters
----------
model : torch.nn.Module
model
Model to be pruned.
config_list : List[Dict]
config_list
Supported keys:
- sparsity : This is to specify the sparsity for each layer in this config to be compressed.
- sparsity_per_layer : Equals to sparsity.
......@@ -238,9 +319,9 @@ class NormPruner(BasicPruner):
- op_names : Operation names to be pruned.
- op_partial_names: Operation partial names to be pruned, will be autocompleted by NNI.
- exclude : Set True then the layers setting by op_types and op_names will be excluded from pruning.
p : int
p
The order of norm.
mode : str
mode
'normal' or 'dependency_aware'.
If prune the model in a dependency-aware way, this pruner will
prune the model according to the norm of weights and the channel-dependency or
......@@ -249,7 +330,7 @@ class NormPruner(BasicPruner):
harvest the speed benefit from the pruned model. Note that, if set 'dependency_aware'
, the dummy_input cannot be None, because the pruner needs a dummy input to trace the
dependency between the conv layers.
dummy_input : Optional[torch.Tensor]
dummy_input
The dummy input to analyze the topology constraints. Note that the dummy_input
should be on the same device as the model.
"""
......@@ -270,19 +351,20 @@ class NormPruner(BasicPruner):
schema.validate(config_list)
def reset_tools(self):
if self.data_collector is None:
self.data_collector = WeightDataCollector(self)
else:
self.data_collector.reset()
if self.metrics_calculator is None:
self.metrics_calculator = NormMetricsCalculator(p=self.p, scalers=Scaling(kernel_size=[1], kernel_padding_mode='back'))
if self.sparsity_allocator is None:
scalers = Scaling(kernel_size=[1], kernel_padding_mode='back')
if not hasattr(self, 'sparsity_allocator'):
if self.mode == 'normal':
self.sparsity_allocator = NormalSparsityAllocator(self, Scaling(kernel_size=[1], kernel_padding_mode='back'))
self.sparsity_allocator = NormalSparsityAllocator(self, scalers)
elif self.mode == 'dependency_aware':
self.sparsity_allocator = DependencyAwareAllocator(self, self.dummy_input, Scaling(kernel_size=[1], kernel_padding_mode='back'))
self.sparsity_allocator = DependencyAwareAllocator(self, self.dummy_input, scalers)
else:
raise NotImplementedError('Only support mode `normal` and `dependency_aware`')
if not hasattr(self, 'data_collector'):
self.data_collector = TargetDataCollector(self)
else:
self.data_collector.reset()
if not hasattr(self, 'metrics_calculator'):
self.metrics_calculator = NormMetricsCalculator(p=self.p, scalers=scalers)
class L1NormPruner(NormPruner):
......@@ -298,9 +380,9 @@ class L1NormPruner(NormPruner):
Parameters
----------
model : torch.nn.Module
model
Model to be pruned.
config_list : List[Dict]
config_list
Supported keys:
- sparsity : This is to specify the sparsity for each layer in this config to be compressed.
- sparsity_per_layer : Equals to sparsity.
......@@ -308,7 +390,7 @@ class L1NormPruner(NormPruner):
- op_names : Operation names to be pruned.
- op_partial_names: Operation partial names to be pruned, will be autocompleted by NNI.
- exclude : Set True then the layers setting by op_types and op_names will be excluded from pruning.
mode : str
mode
'normal' or 'dependency_aware'.
If prune the model in a dependency-aware way, this pruner will
prune the model according to the l1-norm of weights and the channel-dependency or
......@@ -317,7 +399,7 @@ class L1NormPruner(NormPruner):
harvest the speed benefit from the pruned model. Note that, if set 'dependency_aware'
, the dummy_input cannot be None, because the pruner needs a dummy input to trace the
dependency between the conv layers.
dummy_input : Optional[torch.Tensor]
dummy_input
The dummy input to analyze the topology constraints. Note that the dummy_input
should be on the same device as the model.
"""
......@@ -330,15 +412,16 @@ class L1NormPruner(NormPruner):
class L2NormPruner(NormPruner):
r"""
L2 norm pruner is a variant of L1 norm pruner.
The only different between L2 norm pruner and L1 norm pruner is L2 norm pruner prunes the weight with the smallest L2 norm of the weights.
The only difference between the L2 norm pruner and the L1 norm pruner is
that the L2 norm pruner prunes the weight with the smallest L2 norm of the weights.
L2 norm pruner also supports dependency-aware mode.
Parameters
----------
model : torch.nn.Module
model
Model to be pruned.
config_list : List[Dict]
config_list
Supported keys:
- sparsity : This is to specify the sparsity for each layer in this config to be compressed.
- sparsity_per_layer : Equals to sparsity.
......@@ -346,7 +429,7 @@ class L2NormPruner(NormPruner):
- op_names : Operation names to be pruned.
- op_partial_names: Operation partial names to be pruned, will be autocompleted by NNI.
- exclude : Set True then the layers setting by op_types and op_names will be excluded from pruning.
mode : str
mode
'normal' or 'dependency_aware'.
If prune the model in a dependency-aware way, this pruner will
prune the model according to the l2-norm of weights and the channel-dependency or
......@@ -355,7 +438,7 @@ class L2NormPruner(NormPruner):
harvest the speed benefit from the pruned model. Note that, if set 'dependency_aware'
, the dummy_input cannot be None, because the pruner needs a dummy input to trace the
dependency between the conv layers.
dummy_input : Optional[torch.Tensor]
dummy_input
The dummy input to analyze the topology constraints. Note that the dummy_input
should be on the same device as the model.
......@@ -367,7 +450,8 @@ class L2NormPruner(NormPruner):
>>> pruner = L2NormPruner(model, config_list)
>>> masked_model, masks = pruner.compress()
For detailed example please refer to :githublink:`examples/model_compress/pruning/norm_pruning_torch.py <examples/model_compress/pruning/norm_pruning_torch.py>`
For detailed example please refer to
:githublink:`examples/model_compress/pruning/norm_pruning_torch.py <examples/model_compress/pruning/norm_pruning_torch.py>`
"""
def __init__(self, model: Module, config_list: List[Dict],
......@@ -380,15 +464,16 @@ class FPGMPruner(BasicPruner):
FPGM pruner prunes the blocks of the weight on the first dimension with the smallest geometric median.
FPGM chooses the weight blocks with the most replaceable contribution.
For more details, please refer to `Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration <https://arxiv.org/abs/1811.00250>`__.
For more details, please refer to
`Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration <https://arxiv.org/abs/1811.00250>`__.
FPGM pruner also supports dependency-aware mode.
Parameters
----------
model : torch.nn.Module
model
Model to be pruned.
config_list : List[Dict]
config_list
Supported keys:
- sparsity : This is to specify the sparsity for each layer in this config to be compressed.
- sparsity_per_layer : Equals to sparsity.
......@@ -396,7 +481,7 @@ class FPGMPruner(BasicPruner):
- op_names : Operation names to be pruned.
- op_partial_names: Operation partial names to be pruned, will be autocompleted by NNI.
- exclude : Set True then the layers setting by op_types and op_names will be excluded from pruning.
mode : str
mode
'normal' or 'dependency_aware'.
If prune the model in a dependency-aware way, this pruner will
prune the model according to the FPGM of weights and the channel-dependency or
......@@ -405,7 +490,7 @@ class FPGMPruner(BasicPruner):
harvest the speed benefit from the pruned model. Note that, if set 'dependency_aware'
, the dummy_input cannot be None, because the pruner needs a dummy input to trace the
dependency between the conv layers.
dummy_input : Optional[torch.Tensor]
dummy_input
The dummy input to analyze the topology constraints. Note that the dummy_input
should be on the same device as the model.
......@@ -417,7 +502,8 @@ class FPGMPruner(BasicPruner):
>>> pruner = FPGMPruner(model, config_list)
>>> masked_model, masks = pruner.compress()
For detailed example please refer to :githublink:`examples/model_compress/pruning/fpgm_pruning_torch.py <examples/model_compress/pruning/fpgm_pruning_torch.py>`
For detailed example please refer to
:githublink:`examples/model_compress/pruning/fpgm_pruning_torch.py <examples/model_compress/pruning/fpgm_pruning_torch.py>`
"""
def __init__(self, model: Module, config_list: List[Dict],
......@@ -435,33 +521,33 @@ class FPGMPruner(BasicPruner):
schema.validate(config_list)
def reset_tools(self):
if self.data_collector is None:
self.data_collector = WeightDataCollector(self)
else:
self.data_collector.reset()
if self.metrics_calculator is None:
self.metrics_calculator = DistMetricsCalculator(p=2, scalers=Scaling(kernel_size=[1], kernel_padding_mode='back'))
if self.sparsity_allocator is None:
scalers = Scaling(kernel_size=[1], kernel_padding_mode='back')
if not hasattr(self, 'sparsity_allocator'):
if self.mode == 'normal':
self.sparsity_allocator = NormalSparsityAllocator(self, Scaling(kernel_size=[1], kernel_padding_mode='back'))
self.sparsity_allocator = NormalSparsityAllocator(self, scalers)
elif self.mode == 'dependency_aware':
self.sparsity_allocator = DependencyAwareAllocator(self, self.dummy_input, Scaling(kernel_size=[1], kernel_padding_mode='back'))
self.sparsity_allocator = DependencyAwareAllocator(self, self.dummy_input, scalers)
else:
raise NotImplementedError('Only support mode `normal` and `dependency_aware`')
if not hasattr(self, 'data_collector'):
self.data_collector = TargetDataCollector(self)
else:
self.data_collector.reset()
if not hasattr(self, 'metrics_calculator'):
self.metrics_calculator = DistMetricsCalculator(p=2, scalers=scalers)
class SlimPruner(BasicPruner):
r"""
Slim pruner adds sparsity regularization on the scaling factors of batch normalization (BN) layers during training to identify unimportant channels.
The channels with small scaling factor values will be pruned.
class SlimPruner(EvaluatorBasedPruner):
__doc__ = r"""Slim pruner adds sparsity regularization on the scaling factors of batch normalization (BN) layers during training
to identify unimportant channels. The channels with small scaling factor values will be pruned.
For more details, please refer to `Learning Efficient Convolutional Networks through Network Slimming <https://arxiv.org/abs/1708.06519>`__\.
Parameters
----------
model : torch.nn.Module
model
Model to be pruned.
config_list : List[Dict]
config_list
Supported keys:
- sparsity : This is to specify the sparsity for each layer in this config to be compressed.
- sparsity_per_layer : Equals to sparsity.
......@@ -471,68 +557,46 @@ class SlimPruner(BasicPruner):
- op_names : Operation names to be pruned.
- op_partial_names: Operation partial names to be pruned, will be autocompleted by NNI.
- exclude : Set True then the layers setting by op_types and op_names will be excluded from pruning.
trainer : Callable[[Module, Optimizer, Callable], None]
A callable function used to train model or just inference. Take model, optimizer, criterion as input.
The model will be trained or inferenced `training_epochs` epochs.
Example::
def trainer(model: Module, optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor]):
training = model.training
model.train(mode=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
# If you don't want to update the model, you can skip `optimizer.step()`, and set train mode False.
optimizer.step()
model.train(mode=training)
traced_optimizer : nni.common.serializer.Traceable(torch.optim.Optimizer)
The traced optimizer instance which the optimizer class is wrapped by nni.trace.
E.g. ``traced_optimizer = nni.trace(torch.nn.Adam)(model.parameters())``.
criterion : Callable[[Tensor, Tensor], Tensor]
The criterion function used in trainer. Take model output and target value as input, and return the loss.
training_epochs : int
evaluator
``evaluator`` is used to replace the previous ``trainer``, ``traced_optimizer`` and ``criterion`` API.
{evaluator_docstring}
The old API (``trainer``, ``traced_optimizer`` and ``criterion``) is still supported and will be deprecated in v3.0.
If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__.
training_epochs
The epoch number for training model to sparsify the BN weight.
scale : float
scale
Penalty parameter for sparsification, which could reduce overfitting.
mode : str
mode
'normal' or 'global'.
If prune the model in a global way, all layer weights with same config will be considered uniformly.
That means a single layer may not reach or exceed the sparsity setting in config,
but the total pruned weights meet the sparsity setting.
Examples
--------
>>> import nni
>>> from nni.compression.pytorch.pruning import SlimPruner
>>> model = ...
>>> # make sure you have used nni.trace to wrap the optimizer class before initialize
>>> traced_optimizer = nni.trace(torch.optim.Adam)(model.parameters())
>>> trainer = ...
>>> criterion = ...
>>> config_list = [{ 'sparsity': 0.8, 'op_types': ['BatchNorm2d'] }]
>>> pruner = SlimPruner(model, config_list, trainer, traced_optimizer, criterion, training_epochs=1)
>>> masked_model, masks = pruner.compress()
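With the new evaluator API added in this change, the trainer, traced optimizer and criterion are bundled into a single evaluator (a sketch; see the evaluator reference page for construction details):
>>> from nni.compression.pytorch import TorchEvaluator
>>> evaluator = TorchEvaluator(...)
>>> pruner = SlimPruner(model, config_list, evaluator, training_epochs=1)
>>> masked_model, masks = pruner.compress()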
Notes
-----
For detailed example please refer to :githublink:`examples/model_compress/pruning/slim_pruning_torch.py <examples/model_compress/pruning/slim_pruning_torch.py>`
"""
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
traced_optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor],
@overload
def __init__(self, model: Module, config_list: List[Dict], evaluator: LightningEvaluator | TorchEvaluator,
training_epochs: int, scale: float = 0.0001, mode='global'):
self.mode = mode
self.trainer = trainer
if isinstance(traced_optimizer, OptimizerConstructHelper):
self.optimizer_helper = traced_optimizer
else:
self.optimizer_helper = OptimizerConstructHelper.from_trace(model, traced_optimizer)
self.criterion = criterion
self.training_epochs = training_epochs
self._scale = scale
...
@overload
def __init__(self, model: Module, config_list: List[Dict], trainer: _LEGACY_TRAINER, traced_optimizer: Optimizer,
criterion: _LEGACY_CRITERION, training_epochs: int, scale: float = 0.0001, mode='global'):
...
def __init__(self, model: Module, config_list: List[Dict], *args, **kwargs):
# TODO: remove in nni v3.0. Fake overload.
new_api = ['evaluator', 'training_epochs', 'scale', 'mode']
old_api = ['trainer', 'traced_optimizer', 'criterion', 'training_epochs', 'scale', 'mode']
init_kwargs = {'scale': 0.0001, 'mode': 'global'}
init_kwargs = self._init_evaluator(model, new_api, old_api, init_kwargs, args, kwargs)
self.training_epochs, self._scale, self.mode = init_kwargs['training_epochs'], init_kwargs['scale'], init_kwargs['mode']
super().__init__(model, config_list)
def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
......@@ -549,26 +613,48 @@ class SlimPruner(BasicPruner):
schema.validate(config_list)
except SchemaError as e:
if "Missing key: 'total_sparsity'" in str(e):
_logger.error('`config_list` validation failed. If global mode is set in this pruner, `sparsity_per_layer` and `sparsity` are not supported, make sure `total_sparsity` is set in config_list.')
err_msg = '`config_list` validation failed. If global mode is set in this pruner, ' + \
'`sparsity_per_layer` and `sparsity` are not supported, make sure `total_sparsity` is set in config_list.'
_logger.error(err_msg)
raise e
# TODO: remove in nni v3.0.
def criterion_patch(self, criterion: Callable[[Tensor, Tensor], Tensor]) -> Callable[[Tensor, Tensor], Tensor]:
def patched_criterion(input_tensor: Tensor, target: Tensor):
sum_l1 = 0
for wrapper in self.get_modules_wrapper().values():
sum_l1 += torch.norm(wrapper.module.weight, p=1) # type: ignore
sum_l1 += torch.norm(wrapper.weight, p=1) # type: ignore
return criterion(input_tensor, target) + self._scale * sum_l1
return patched_criterion
def loss_patch(self, origin_loss: Tensor):
# additional weight norm loss in Slim, used to sparsify the weight values.
sum_l1 = 0
for wrapper in self.get_modules_wrapper().values():
target_name = 'weight'
sum_l1 += torch.norm(getattr(wrapper, target_name), p=1) # type: ignore
return self._scale * sum_l1 + origin_loss
def reset_tools(self):
if self.data_collector is None:
self.data_collector = WeightTrainerBasedDataCollector(self, self.trainer, self.optimizer_helper, self.criterion,
self.training_epochs, criterion_patch=self.criterion_patch)
if self.using_evaluator:
# TODO: move to other place in nni v3.0
self.evaluator.unbind_model()
self.evaluator.bind_model(self.bound_model, self.get_origin2wrapped_parameter_name_map()) # type: ignore
if not hasattr(self, 'data_collector'):
self.data_collector = EvaluatorBasedTargetDataCollector(self, self.evaluator, loss_patch=self.loss_patch,
max_epochs=self.training_epochs)
else:
self.data_collector.reset(loss_patch=self.loss_patch)
else:
self.data_collector.reset()
if self.metrics_calculator is None:
if not hasattr(self, 'data_collector'):
self.data_collector = WeightTrainerBasedDataCollector(self, self.trainer, self.optimizer_helper, self.criterion,
self.training_epochs, criterion_patch=self.criterion_patch)
else:
self.data_collector.reset()
if not hasattr(self, 'metrics_calculator'):
self.metrics_calculator = NormMetricsCalculator()
if self.sparsity_allocator is None:
if not hasattr(self, 'sparsity_allocator'):
if self.mode == 'normal':
self.sparsity_allocator = NormalSparsityAllocator(self)
elif self.mode == 'global':
......@@ -577,13 +663,12 @@ class SlimPruner(BasicPruner):
raise NotImplementedError('Only support mode `normal` and `global`')
class ActivationPruner(BasicPruner):
"""
Parameters
class ActivationPruner(EvaluatorBasedPruner):
__doc__ = r"""Parameters
----------
model : torch.nn.Module
model
Model to be pruned.
config_list : List[Dict]
config_list
Supported keys:
- sparsity : This is to specify the sparsity for each layer in this config to be compressed.
- sparsity_per_layer : Equals to sparsity.
......@@ -591,33 +676,15 @@ class ActivationPruner(BasicPruner):
- op_names : Operation names to be pruned.
- op_partial_names: Operation partial names to be pruned, will be autocompleted by NNI.
- exclude : Set True then the layers setting by op_types and op_names will be excluded from pruning.
trainer : Callable[[Module, Optimizer, Callable], None]
A callable function used to train model or just inference. Take model, optimizer, criterion as input.
The model will be trained or inferenced `training_epochs` epochs.
Example::
def trainer(model: Module, optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor]):
training = model.training
model.train(mode=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
# If you don't want to update the model, you can skip `optimizer.step()`, and set train mode False.
optimizer.step()
model.train(mode=training)
traced_optimizer : nni.common.serializer.Traceable(torch.optim.Optimizer)
The traced optimizer instance which the optimizer class is wrapped by nni.trace.
E.g. ``traced_optimizer = nni.trace(torch.nn.Adam)(model.parameters())``.
criterion : Callable[[Tensor, Tensor], Tensor]
The criterion function used in trainer. Take model output and target value as input, and return the loss.
training_batches
The batch number used to collect activations.
mode : str
evaluator
``evaluator`` is used to replace the previous ``trainer``, ``traced_optimizer`` and ``criterion`` API.
{evaluator_docstring}
The old API (``trainer``, ``traced_optimizer`` and ``criterion``) is still supported and will be deprecated in v3.0.
If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__.
training_steps
The step number used to collect activations.
mode
'normal' or 'dependency_aware'.
If prune the model in a dependency-aware way, this pruner will
prune the model according to the activation-based metrics and the channel-dependency or
......@@ -626,24 +693,34 @@ class ActivationPruner(BasicPruner):
harvest the speed benefit from the pruned model. Note that, if set 'dependency_aware'
, the dummy_input cannot be None, because the pruner needs a dummy input to trace the
dependency between the conv layers.
dummy_input : Optional[torch.Tensor]
dummy_input
The dummy input to analyze the topology constraints. Note that the dummy_input
should be on the same device as the model.
"""
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
traced_optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor], training_batches: int, activation: str = 'relu',
@overload
def __init__(self, model: Module, config_list: List[Dict], evaluator: LightningEvaluator | TorchEvaluator, training_steps: int,
activation: str = 'relu', mode: str = 'normal', dummy_input: Optional[Tensor] = None):
...
@overload
def __init__(self, model: Module, config_list: List[Dict], trainer: _LEGACY_TRAINER, traced_optimizer: Optimizer,
criterion: _LEGACY_CRITERION, training_batches: int, activation: str = 'relu',
mode: str = 'normal', dummy_input: Optional[Tensor] = None):
self.mode = mode
self.dummy_input = dummy_input
self.trainer = trainer
if isinstance(traced_optimizer, OptimizerConstructHelper):
self.optimizer_helper = traced_optimizer
else:
self.optimizer_helper = OptimizerConstructHelper.from_trace(model, traced_optimizer)
self.criterion = criterion
self.training_batches = training_batches
self._activation = self._choose_activation(activation)
...
def __init__(self, model: Module, config_list: List[Dict], *args, **kwargs):
# TODO: remove in nni v3.0. Fake overload.
new_api = ['evaluator', 'training_steps', 'activation', 'mode', 'dummy_input']
old_api = ['trainer', 'traced_optimizer', 'criterion', 'training_batches', 'activation', 'mode', 'dummy_input']
init_kwargs = {'activation': 'relu', 'mode': 'normal', 'dummy_input': None}
init_kwargs = self._init_evaluator(model, new_api, old_api, init_kwargs, args, kwargs)
self.training_steps: int = init_kwargs.get('training_steps', init_kwargs.get('training_batches'))
self._activation: Callable[[Tensor], Tensor] = self._choose_activation(init_kwargs['activation'])
self.mode: str = init_kwargs['mode']
self.dummy_input = init_kwargs['dummy_input']
super().__init__(model, config_list)
def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
......@@ -670,10 +747,11 @@ class ActivationPruner(BasicPruner):
buffer.append(0)
def collect_activation(_module: Module, _input: Tensor, output: Tensor):
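# buffer[0] counts the forward batches collected so far; buffer[1] (created lazily on
# the first call) accumulates the batch-averaged activation returned by
# `_activation_trans`. Collection stops once `training_steps` batches have been seen.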
activation = self._activation_trans(output)
if len(buffer) == 1:
buffer.append(torch.zeros_like(output))
if buffer[0] < self.training_batches:
buffer[1] += self._activation_trans(output)
buffer.append(torch.zeros_like(activation))
if buffer[0] < self.training_steps:
buffer[1] += activation
buffer[0] += 1
return collect_activation
......@@ -681,42 +759,60 @@ class ActivationPruner(BasicPruner):
raise NotImplementedError()
def reset_tools(self):
collector_info = HookCollectorInfo([layer_info for layer_info, _ in self._detect_modules_to_compress()], 'forward', self._collector)
if self.data_collector is None:
self.data_collector = SingleHookTrainerBasedDataCollector(self, self.trainer, self.optimizer_helper, self.criterion,
1, collector_infos=[collector_info])
else:
self.data_collector.reset(collector_infos=[collector_info]) # type: ignore
if self.metrics_calculator is None:
self.metrics_calculator = self._create_metrics_calculator()
if self.sparsity_allocator is None:
scalers = Scaling(kernel_size=[1], kernel_padding_mode='back')
if not hasattr(self, 'sparsity_allocator'):
if self.mode == 'normal':
self.sparsity_allocator = NormalSparsityAllocator(self, Scaling(kernel_size=[1], kernel_padding_mode='back'))
self.sparsity_allocator = NormalSparsityAllocator(self, scalers)
elif self.mode == 'dependency_aware':
self.sparsity_allocator = DependencyAwareAllocator(self, self.dummy_input, Scaling(kernel_size=[1], kernel_padding_mode='back'))
self.sparsity_allocator = DependencyAwareAllocator(self, self.dummy_input, scalers)
else:
raise NotImplementedError('Only support mode `normal` and `dependency_aware`')
if self.using_evaluator:
# TODO: move to other place in nni v3.0
self.evaluator.unbind_model()
self.evaluator.bind_model(self.bound_model, self.get_origin2wrapped_parameter_name_map()) # type: ignore
forward_hooks = {}
for module_name, wrapper in self.get_modules_wrapper().items():
target_name = 'weight'
forward_hooks[module_name] = {target_name: ForwardHook(wrapper, module_name, self._collector)}
if not hasattr(self, 'data_collector'):
self.data_collector = EvaluatorBasedHookDataCollector(self, self.evaluator, hooks=forward_hooks,
max_steps=self.training_steps)
else:
self.data_collector.reset(hooks=forward_hooks)
else:
collector_info = HookCollectorInfo([layer_info for layer_info, _ in self._detect_modules_to_compress()],
'forward', self._collector)
if not hasattr(self, 'data_collector'):
self.data_collector = SingleHookTrainerBasedDataCollector(self, self.trainer, self.optimizer_helper, self.criterion,
1, collector_infos=[collector_info])
else:
self.data_collector.reset([collector_info]) # type: ignore
if not hasattr(self, 'metrics_calculator'):
self.metrics_calculator = self._create_metrics_calculator()
def _create_metrics_calculator(self) -> MetricsCalculator:
raise NotImplementedError()
class ActivationAPoZRankPruner(ActivationPruner):
r"""
Activation APoZ rank pruner is a pruner which prunes on the first weight dimension,
__doc__ = r"""Activation APoZ rank pruner is a pruner which prunes on the first weight dimension,
with the smallest importance criterion ``APoZ`` calculated from the output activations of convolution layers to achieve a preset level of network sparsity.
The pruning criterion ``APoZ`` is explained in the paper `Network Trimming: A Data-Driven Neuron Pruning Approach towards Efficient Deep Architectures <https://arxiv.org/abs/1607.03250>`__.
The APoZ is defined as:
:math:`APoZ_{c}^{(i)} = APoZ\left(O_{c}^{(i)}\right)=\frac{\sum_{k}^{N} \sum_{j}^{M} f\left(O_{c, j}^{(i)}(k)=0\right)}{N \times M}`
""" + r"""
Activation APoZ rank pruner also supports dependency-aware mode.
Parameters
----------
model : torch.nn.Module
model
Model to be pruned.
config_list : List[Dict]
config_list
Supported keys:
- sparsity : This is to specify the sparsity for each layer in this config to be compressed.
- sparsity_per_layer : Equals to sparsity.
......@@ -724,33 +820,15 @@ class ActivationAPoZRankPruner(ActivationPruner):
- op_names : Operation names to be pruned.
- op_partial_names: Operation partial names to be pruned, will be autocompleted by NNI.
- exclude : If set True, the layers specified by op_types and op_names will be excluded from pruning.
trainer : Callable[[Module, Optimizer, Callable], None]
A callable function used to train model or just inference. Take model, optimizer, criterion as input.
The model will be trained or inferenced `training_epochs` epochs.
Example::
def trainer(model: Module, optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor]):
training = model.training
model.train(mode=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
# If you don't want to update the model, you can skip `optimizer.step()`, and set train mode False.
optimizer.step()
model.train(mode=training)
traced_optimizer : nni.common.serializer.Traceable(torch.optim.Optimizer)
The traced optimizer instance which the optimizer class is wrapped by nni.trace.
E.g. ``traced_optimizer = nni.trace(torch.optim.Adam)(model.parameters())``.
criterion : Callable[[Tensor, Tensor], Tensor]
The criterion function used in trainer. Take model output and target value as input, and return the loss.
training_batches
The batch number used to collect activations.
mode : str
evaluator
``evaluator`` is used to replace the previous ``trainer``, ``traced_optimizer`` and ``criterion`` API.
{evaluator_docstring}
The old API (``trainer``, ``traced_optimizer`` and ``criterion``) is still supported and will be deprecated in v3.0.
If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__.
training_steps
The step number used to collect activations.
mode
'normal' or 'dependency_aware'.
If pruning the model in a dependency-aware way, this pruner will
prune the model according to the activation-based metrics and the channel-dependency or
......@@ -759,35 +837,25 @@ class ActivationAPoZRankPruner(ActivationPruner):
harvest the speed benefit from the pruned model. Note that if 'dependency_aware' is set,
the dummy_input cannot be None, because the pruner needs a dummy input to trace the
dependency between the conv layers.
dummy_input : Optional[torch.Tensor]
dummy_input
The dummy input to analyze the topology constraints. Note that the dummy_input
should be on the same device as the model.
Examples
--------
>>> import nni
>>> from nni.compression.pytorch.pruning import ActivationAPoZRankPruner
>>> model = ...
>>> # make sure you have used nni.trace to wrap the optimizer class before initialize
>>> traced_optimizer = nni.trace(torch.optim.Adam)(model.parameters())
>>> trainer = ...
>>> criterion = ...
>>> config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
>>> pruner = ActivationAPoZRankPruner(model, config_list, trainer, traced_optimizer, criterion, training_batches=20)
>>> masked_model, masks = pruner.compress()
Notes
-----
For detailed example please refer to :githublink:`examples/model_compress/pruning/activation_pruning_torch.py <examples/model_compress/pruning/activation_pruning_torch.py>`
"""
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
def _activation_trans(self, output: Tensor) -> Tensor:
# return a matrix in which the positions of zeros in `output` are one, and the others are zero.
return torch.eq(self._activation(output.detach()), torch.zeros_like(output)).type_as(output)
return torch.eq(self._activation(output.detach()), torch.zeros_like(output)).type_as(output).mean(0)
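# Worked sketch of the APoZ criterion referenced in the docstring above: APoZ is the
# fraction of zero entries in the post-activation output, averaged over the collected
# samples and the positions of each activation map. Channels with a high APoZ (many
# zeros) are treated as less important. Standalone illustration, not pruner code.
def _apoz_example():
    import torch
    out = torch.relu(torch.randn(4, 2, 3, 3))            # N=4 samples, 2 channels, 3x3 maps
    zero_mask = torch.eq(out, torch.zeros_like(out)).float()
    apoz_per_channel = zero_mask.mean(dim=(0, 2, 3))      # average over samples and positions
    return apoz_per_channel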
def _create_metrics_calculator(self) -> MetricsCalculator:
return APoZRankMetricsCalculator(Scaling(kernel_size=[-1, 1], kernel_padding_mode='back'))
return APoZRankMetricsCalculator(Scaling(kernel_size=[1], kernel_padding_mode='back'))
class ActivationMeanRankPruner(ActivationPruner):
r"""
__doc__ = r"""
Activation mean rank pruner is a pruner which prunes on the first weight dimension,
with the smallest importance criterion ``mean activation`` calculated from the output activations of convolution layers to achieve a preset level of network sparsity.
......@@ -797,9 +865,9 @@ class ActivationMeanRankPruner(ActivationPruner):
Parameters
----------
model : torch.nn.Module
model
Model to be pruned.
config_list : List[Dict]
config_list
Supported keys:
- sparsity : This is to specify the sparsity for each layer in this config to be compressed.
- sparsity_per_layer : Equals to sparsity.
......@@ -807,33 +875,15 @@ class ActivationMeanRankPruner(ActivationPruner):
- op_names : Operation names to be pruned.
- op_partial_names: Operation partial names to be pruned, will be autocompleted by NNI.
- exclude : If set True, the layers specified by op_types and op_names will be excluded from pruning.
trainer : Callable[[Module, Optimizer, Callable], None]
A callable function used to train model or just inference. Take model, optimizer, criterion as input.
The model will be trained or inferenced `training_epochs` epochs.
Example::
def trainer(model: Module, optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor]):
training = model.training
model.train(mode=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
# If you don't want to update the model, you can skip `optimizer.step()`, and set train mode False.
optimizer.step()
model.train(mode=training)
traced_optimizer : nni.common.serializer.Traceable(torch.optim.Optimizer)
The traced optimizer instance which the optimizer class is wrapped by nni.trace.
E.g. ``traced_optimizer = nni.trace(torch.optim.Adam)(model.parameters())``.
criterion : Callable[[Tensor, Tensor], Tensor]
The criterion function used in trainer. Take model output and target value as input, and return the loss.
training_batches
The batch number used to collect activations.
mode : str
evaluator
``evaluator`` is used to replace the previous ``trainer``, ``traced_optimizer`` and ``criterion`` API.
{evaluator_docstring}
The old API (``trainer``, ``traced_optimizer`` and ``criterion``) is still supported and will be deprecated in v3.0.
If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__.
training_steps
The step number used to collect activations.
mode
'normal' or 'dependency_aware'.
If pruning the model in a dependency-aware way, this pruner will
prune the model according to the activation-based metrics and the channel-dependency or
......@@ -842,40 +892,31 @@ class ActivationMeanRankPruner(ActivationPruner):
harvest the speed benefit from the pruned model. Note that if 'dependency_aware' is set,
the dummy_input cannot be None, because the pruner needs a dummy input to trace the
dependency between the conv layers.
dummy_input : Optional[torch.Tensor]
dummy_input
The dummy input to analyze the topology constraints. Note that the dummy_input
should be on the same device as the model.
Examples
--------
>>> import nni
>>> from nni.compression.pytorch.pruning import ActivationMeanRankPruner
>>> model = ...
>>> # make sure you have used nni.trace to wrap the optimizer class before initialize
>>> traced_optimizer = nni.trace(torch.optim.Adam)(model.parameters())
>>> trainer = ...
>>> criterion = ...
>>> config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
>>> pruner = ActivationMeanRankPruner(model, config_list, trainer, traced_optimizer, criterion, training_batches=20)
>>> masked_model, masks = pruner.compress()
Notes
-----
For detailed example please refer to :githublink:`examples/model_compress/pruning/activation_pruning_torch.py <examples/model_compress/pruning/activation_pruning_torch.py>`
"""
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
def _activation_trans(self, output: Tensor) -> Tensor:
# return the activation of `output` directly.
return self._activation(output.detach())
return self._activation(output.detach()).mean(0)
def _create_metrics_calculator(self) -> MetricsCalculator:
return MeanRankMetricsCalculator(Scaling(kernel_size=[-1, 1], kernel_padding_mode='back'))
return MeanRankMetricsCalculator(Scaling(kernel_size=[1], kernel_padding_mode='back'))
class TaylorFOWeightPruner(BasicPruner):
r"""
class TaylorFOWeightPruner(EvaluatorBasedPruner):
__doc__ = r"""
Taylor FO weight pruner is a pruner which prunes on the first weight dimension,
based on estimated importance calculated from the first order taylor expansion on weights to achieve a preset level of network sparsity.
The estimated importance is defined in the paper `Importance Estimation for Neural Network Pruning <http://jankautz.com/publications/Importance4NNPruning_CVPR19.pdf>`__.
:math:`\widehat{\mathcal{I}}_{\mathcal{S}}^{(1)}(\mathbf{W}) \triangleq \sum_{s \in \mathcal{S}} \mathcal{I}_{s}^{(1)}(\mathbf{W})=\sum_{s \in \mathcal{S}}\left(g_{s} w_{s}\right)^{2}`
""" + r"""
Taylor FO weight pruner also supports dependency-aware mode.
......@@ -883,9 +924,9 @@ class TaylorFOWeightPruner(BasicPruner):
Parameters
----------
model : torch.nn.Module
model
Model to be pruned.
config_list : List[Dict]
config_list
Supported keys:
- sparsity : This is to specify the sparsity for each layer in this config to be compressed.
- sparsity_per_layer : Equals to sparsity.
......@@ -895,33 +936,15 @@ class TaylorFOWeightPruner(BasicPruner):
- op_names : Operation names to be pruned.
- op_partial_names: Operation partial names to be pruned, will be autocompleted by NNI.
- exclude : If set True, the layers specified by op_types and op_names will be excluded from pruning.
trainer : Callable[[Module, Optimizer, Callable]
A callable function used to train model or just inference. Take model, optimizer, criterion as input.
The model will be trained or inferenced `training_epochs` epochs.
Example::
def trainer(model: Module, optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor]):
training = model.training
model.train(mode=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
# If you don't want to update the model, you can skip `optimizer.step()`, and set train mode False.
optimizer.step()
model.train(mode=training)
traced_optimizer : nni.common.serializer.Traceable(torch.optim.Optimizer)
The traced optimizer instance which the optimizer class is wrapped by nni.trace.
E.g. ``traced_optimizer = nni.trace(torch.optim.Adam)(model.parameters())``.
criterion : Callable[[Tensor, Tensor], Tensor]
The criterion function used in trainer. Take model output and target value as input, and return the loss.
training_batches : int
The batch number used to collect activations.
mode : str
evaluator
``evaluator`` is used to replace the previous ``trainer``, ``traced_optimizer`` and ``criterion`` API.
{evaluator_docstring}
The old API (``trainer``, ``traced_optimizer`` and ``criterion``) is still supported and will be deprecated in v3.0.
If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__.
training_steps
The step number used to collect activations.
mode
'normal', 'dependency_aware' or 'global'.
If pruning the model in a dependency-aware way, this pruner will
......@@ -935,38 +958,36 @@ class TaylorFOWeightPruner(BasicPruner):
If pruning the model in a global way, all layer weights with the same config will be considered uniformly.
That means a single layer may not reach or exceed the sparsity setting in its config,
but the total pruned weights meet the sparsity setting.
dummy_input : Optional[torch.Tensor]
dummy_input
The dummy input to analyze the topology constraints. Note that the dummy_input
should be on the same device as the model.
Examples
--------
>>> import nni
>>> from nni.compression.pytorch.pruning import TaylorFOWeightPruner
>>> model = ...
>>> # make sure you have used nni.trace to wrap the optimizer class before initialize
>>> traced_optimizer = nni.trace(torch.optim.Adam)(model.parameters())
>>> trainer = ...
>>> criterion = ...
>>> config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
>>> pruner = TaylorFOWeightPruner(model, config_list, trainer, traced_optimizer, criterion, training_batches=20)
>>> masked_model, masks = pruner.compress()
Notes
-----
For detailed example please refer to :githublink:`examples/model_compress/pruning/taylorfo_pruning_torch.py <examples/model_compress/pruning/taylorfo_pruning_torch.py>`
"""
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
traced_optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor], training_batches: int,
@overload
def __init__(self, model: Module, config_list: List[Dict], evaluator: LightningEvaluator | TorchEvaluator, training_steps: int,
mode: str = 'normal', dummy_input: Optional[Tensor] = None):
self.mode = mode
self.dummy_input = dummy_input
self.trainer = trainer
if isinstance(traced_optimizer, OptimizerConstructHelper):
self.optimizer_helper = traced_optimizer
else:
self.optimizer_helper = OptimizerConstructHelper.from_trace(model, traced_optimizer)
self.criterion = criterion
self.training_batches = training_batches
...
@overload
def __init__(self, model: Module, config_list: List[Dict], trainer: _LEGACY_TRAINER, traced_optimizer: Optimizer,
criterion: _LEGACY_CRITERION, training_batches: int, mode: str = 'normal', dummy_input: Optional[Tensor] = None):
...
def __init__(self, model: Module, config_list: List[Dict], *args, **kwargs):
# TODO: remove in nni v3.0. Fake overload.
new_api = ['evaluator', 'training_steps', 'mode', 'dummy_input']
old_api = ['trainer', 'traced_optimizer', 'criterion', 'training_batches', 'mode', 'dummy_input']
init_kwargs = {'mode': 'normal', 'dummy_input': None}
init_kwargs = self._init_evaluator(model, new_api, old_api, init_kwargs, args, kwargs)
self.training_steps: int = init_kwargs.get('training_steps', init_kwargs.get('training_batches'))
self.mode: str = init_kwargs['mode']
self.dummy_input = init_kwargs['dummy_input']
super().__init__(model, config_list)
def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
......@@ -983,16 +1004,19 @@ class TaylorFOWeightPruner(BasicPruner):
schema.validate(config_list)
except SchemaError as e:
if "Missing key: 'total_sparsity'" in str(e):
_logger.error('`config_list` validation failed. If global mode is set in this pruner, `sparsity_per_layer` and `sparsity` are not supported, make sure `total_sparsity` is set in config_list.')
err_msg = '`config_list` validation failed. If global mode is set in this pruner, ' + \
'`sparsity_per_layer` and `sparsity` are not supported, make sure `total_sparsity` is set in config_list.'
_logger.error(err_msg)
raise e
def _collector(self, buffer: List, weight_tensor: Tensor) -> Callable[[Tensor], None]:
assert len(buffer) == 0, 'Buffer passed to taylor pruner collector is not empty.'
buffer.append(0)
buffer.append(torch.zeros_like(weight_tensor))
def collect_taylor(grad: Tensor):
if buffer[0] < self.training_batches:
if len(buffer) == 1:
buffer.append(torch.zeros_like(grad))
if buffer[0] < self.training_steps:
buffer[1] += self._calculate_taylor_expansion(weight_tensor, grad)
buffer[0] += 1
return collect_taylor
......@@ -1001,28 +1025,47 @@ class TaylorFOWeightPruner(BasicPruner):
return (weight_tensor.detach() * grad.detach()).data.pow(2)
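# Worked sketch of the first-order Taylor importance computed above: the per-step
# contribution of each weight is (gradient * weight) ** 2, and the hook accumulates
# these contributions over `training_steps` backward passes before output channels
# are ranked by their summed importance. Standalone illustration, not pruner code.
def _taylor_fo_example():
    import torch
    weight = torch.randn(8, 4, 3, 3, requires_grad=True)            # e.g. a conv weight
    loss = (weight ** 2).sum()                                       # stand-in for a real training loss
    loss.backward()
    per_weight = (weight.detach() * weight.grad.detach()).pow(2)     # (g_s * w_s) ** 2
    per_channel = per_weight.sum(dim=(1, 2, 3))                      # sum within each output-channel group
    return per_channel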
def reset_tools(self):
hook_targets = {name: wrapper.weight for name, wrapper in self.get_modules_wrapper().items()} # type: ignore
collector_info = HookCollectorInfo(hook_targets, 'tensor', self._collector) # type: ignore
if self.data_collector is None:
self.data_collector = SingleHookTrainerBasedDataCollector(self, self.trainer, self.optimizer_helper, self.criterion,
1, collector_infos=[collector_info])
else:
self.data_collector.reset(collector_infos=[collector_info]) # type: ignore
if self.metrics_calculator is None:
self.metrics_calculator = MultiDataNormMetricsCalculator(p=1, scalers=Scaling(kernel_size=[1], kernel_padding_mode='back'))
if self.sparsity_allocator is None:
scalers = Scaling(kernel_size=[1], kernel_padding_mode='back')
if not hasattr(self, 'sparsity_allocator'):
if self.mode == 'normal':
self.sparsity_allocator = NormalSparsityAllocator(self, Scaling(kernel_size=[1], kernel_padding_mode='back'))
self.sparsity_allocator = NormalSparsityAllocator(self, scalers)
elif self.mode == 'global':
self.sparsity_allocator = GlobalSparsityAllocator(self, Scaling(kernel_size=[1], kernel_padding_mode='back'))
self.sparsity_allocator = GlobalSparsityAllocator(self, scalers)
elif self.mode == 'dependency_aware':
self.sparsity_allocator = DependencyAwareAllocator(self, self.dummy_input, Scaling(kernel_size=[1], kernel_padding_mode='back'))
self.sparsity_allocator = DependencyAwareAllocator(self, self.dummy_input, scalers)
else:
raise NotImplementedError('Only support mode `normal`, `global` and `dependency_aware`')
if self.using_evaluator:
# TODO: move to other place in nni v3.0
self.evaluator.unbind_model()
self.evaluator.bind_model(self.bound_model, self.get_origin2wrapped_parameter_name_map()) # type: ignore
tensor_hooks = {}
for module_name, wrapper in self.get_modules_wrapper().items():
target_name = 'weight'
target = getattr(wrapper, target_name)
tensor_hooks[module_name] = {target_name: TensorHook(target, module_name,
functools.partial(self._collector, weight_tensor=target))}
if not hasattr(self, 'data_collector'):
self.data_collector = EvaluatorBasedHookDataCollector(self, self.evaluator, hooks=tensor_hooks,
max_steps=self.training_steps)
else:
self.data_collector.reset(hooks=tensor_hooks)
else:
hook_targets = {name: wrapper.weight for name, wrapper in self.get_modules_wrapper().items()} # type: ignore
collector_info = HookCollectorInfo(hook_targets, 'tensor', self._collector) # type: ignore
if not hasattr(self, 'data_collector'):
self.data_collector = SingleHookTrainerBasedDataCollector(self, self.trainer, self.optimizer_helper, self.criterion,
1, collector_infos=[collector_info])
else:
self.data_collector.reset([collector_info]) # type: ignore
class ADMMPruner(BasicPruner):
r"""
if not hasattr(self, 'metrics_calculator'):
self.metrics_calculator = HookDataNormMetricsCalculator(p=1, scalers=scalers)
class ADMMPruner(EvaluatorBasedPruner):
__doc__ = r"""
Alternating Direction Method of Multipliers (ADMM) is a mathematical optimization technique
that decomposes the original nonconvex problem into two subproblems which can be solved iteratively.
In the weight pruning problem, these two subproblems are solved via 1) a gradient descent algorithm and 2) Euclidean projection, respectively.
......@@ -1036,9 +1079,9 @@ class ADMMPruner(BasicPruner):
Parameters
----------
model : torch.nn.Module
model
Model to be pruned.
config_list : List[Dict]
config_list
Supported keys:
- sparsity : This is to specify the sparsity for each layer in this config to be compressed.
- sparsity_per_layer : Equals to sparsity.
......@@ -1047,77 +1090,60 @@ class ADMMPruner(BasicPruner):
- op_names : Operation names to be pruned.
- op_partial_names: Operation partial names to be pruned, will be autocompleted by NNI.
- exclude : If set True, the layers specified by op_types and op_names will be excluded from pruning.
trainer : Callable[[Module, Optimizer, Callable]
A callable function used to train model or just inference. Take model, optimizer, criterion as input.
The model will be trained or inferenced `training_epochs` epochs.
Example::
def trainer(model: Module, optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor]):
training = model.training
model.train(mode=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
# If you don't want to update the model, you can skip `optimizer.step()`, and set train mode False.
optimizer.step()
model.train(mode=training)
traced_optimizer : nni.common.serializer.Traceable(torch.optim.Optimizer)
The traced optimizer instance which the optimizer class is wrapped by nni.trace.
E.g. ``traced_optimizer = nni.trace(torch.optim.Adam)(model.parameters())``.
criterion : Callable[[Tensor, Tensor], Tensor]
The criterion function used in trainer. Take model output and target value as input, and return the loss.
iterations : int
evaluator
``evaluator`` is used to replace the previous ``trainer``, ``traced_optimizer`` and ``criterion`` API.
{evaluator_docstring}
The old API (``trainer``, ``traced_optimizer`` and ``criterion``) is still supported and will be deprecated in v3.0.
If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__.
iterations
The total iteration number in the ADMM pruning algorithm.
training_epochs : int
training_epochs
The epoch number for training model in each iteration.
granularity : str
granularity
'fine-grained' or 'coarse-grained'.
If 'coarse-grained' is set, the ADMM pruner will generate masks channel-wise on the output channels.
The original ADMM pruning paper implemented fine-grained ADMM pruning;
the AutoCompress paper used coarse-grained ADMM pruning.
Examples
--------
>>> import nni
>>> from nni.compression.pytorch.pruning import ADMMPruner
>>> model = ...
>>> # make sure you have used nni.trace to wrap the optimizer class before initialize
>>> traced_optimizer = nni.trace(torch.optim.Adam)(model.parameters())
>>> trainer = ...
>>> criterion = ...
>>> config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
>>> pruner = ADMMPruner(model, config_list, trainer, traced_optimizer, criterion, iterations=10, training_epochs=1)
>>> masked_model, masks = pruner.compress()
Notes
-----
For detailed example please refer to :githublink:`examples/model_compress/pruning/admm_pruning_torch.py <examples/model_compress/pruning/admm_pruning_torch.py>`
"""
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
def __init__(self, model: Optional[Module], config_list: Optional[List[Dict]], trainer: Callable[[Module, Optimizer, Callable], None],
traced_optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor], iterations: int,
@overload
def __init__(self, model: Module, config_list: List[Dict], evaluator: LightningEvaluator | TorchEvaluator, iterations: int,
training_epochs: int, granularity: str = 'fine-grained'):
self.trainer = trainer
if isinstance(traced_optimizer, OptimizerConstructHelper):
self.optimizer_helper = traced_optimizer
else:
assert model is not None, 'Model is required if traced_optimizer is provided.'
self.optimizer_helper = OptimizerConstructHelper.from_trace(model, traced_optimizer)
self.criterion = criterion
self.iterations = iterations
self.training_epochs = training_epochs
assert granularity in ['fine-grained', 'coarse-grained']
self.granularity = granularity
...
@overload
def __init__(self, model: Module, config_list: List[Dict], trainer: _LEGACY_TRAINER,
traced_optimizer: Optimizer, criterion: _LEGACY_CRITERION, iterations: int,
training_epochs: int, granularity: str = 'fine-grained'):
...
def __init__(self, model: Module, config_list: List[Dict], *args, **kwargs):
# TODO: remove in nni v3.0. Fake overload.
new_api = ['evaluator', 'iterations', 'training_epochs', 'granularity']
old_api = ['trainer', 'traced_optimizer', 'criterion', 'iterations', 'training_epochs', 'granularity']
init_kwargs = {'granularity': 'fine-grained'}
init_kwargs = self._init_evaluator(model, new_api, old_api, init_kwargs, args, kwargs)
self.iterations: int = init_kwargs['iterations']
self.training_epochs: int = init_kwargs['training_epochs']
assert init_kwargs['granularity'] in ['fine-grained', 'coarse-grained']
self.granularity: str = init_kwargs['granularity']
self.Z, self.U = {}, {}
super().__init__(model, config_list)
def reset(self, model: Module, config_list: List[Dict]):
super().reset(model, config_list)
self.Z = {name: wrapper.module.weight.data.clone().detach() for name, wrapper in self.get_modules_wrapper().items()} # type: ignore
self.U = {name: torch.zeros_like(z).to(z.device) for name, z in self.Z.items()}
# FIXME: Only support pruning 'weight' right now.
target_name = 'weight'
for module_name, wrapper in self.get_modules_wrapper().items():
self.Z[module_name] = {target_name: wrapper.weight.data.clone()} # type: ignore
self.U = {module_name: {target_name: torch.zeros_like(z[target_name])} for module_name, z in self.Z.items()}
def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
schema_list = [deepcopy(NORMAL_SCHEMA), deepcopy(INTERNAL_SCHEMA)]
......@@ -1127,53 +1153,72 @@ class ADMMPruner(BasicPruner):
schema = CompressorSchema(schema_list, model, _logger)
schema.validate(config_list)
# TODO: remove in nni v3.0.
def criterion_patch(self, origin_criterion: Callable[[Tensor, Tensor], Tensor]):
def patched_criterion(output: Tensor, target: Tensor):
penalty = torch.tensor(0.0).to(output.device)
for name, wrapper in self.get_modules_wrapper().items():
rho = wrapper.config.get('rho', 1e-4)
penalty += (rho / 2) * torch.sqrt(torch.norm(wrapper.module.weight - self.Z[name] + self.U[name]))
self.Z[name]['weight'] = self.Z[name]['weight'].to(wrapper.weight.device) # type: ignore
self.U[name]['weight'] = self.U[name]['weight'].to(wrapper.weight.device) # type: ignore
penalty += (rho / 2) * torch.sqrt(torch.norm(wrapper.weight - self.Z[name]['weight'] + self.U[name]['weight']))
return origin_criterion(output, target) + penalty
return patched_criterion
def loss_patch(self, origin_loss: Tensor):
penalty = 0
for name, wrapper in self.get_modules_wrapper().items():
rho = wrapper.config.get('rho', 1e-4)
self.Z[name]['weight'] = self.Z[name]['weight'].to(wrapper.weight.device) # type: ignore
self.U[name]['weight'] = self.U[name]['weight'].to(wrapper.weight.device) # type: ignore
penalty += (rho / 2) * torch.sqrt(torch.norm(wrapper.weight - self.Z[name]['weight'] + self.U[name]['weight']))
return origin_loss + penalty
def reset_tools(self):
if self.data_collector is None:
self.data_collector = WeightTrainerBasedDataCollector(self, self.trainer, self.optimizer_helper, self.criterion,
self.training_epochs, criterion_patch=self.criterion_patch)
if self.using_evaluator:
# TODO: move to other place in nni v3.0
self.evaluator.unbind_model()
self.evaluator.bind_model(self.bound_model, self.get_origin2wrapped_parameter_name_map()) # type: ignore
if not hasattr(self, 'data_collector'):
self.data_collector = EvaluatorBasedTargetDataCollector(self, self.evaluator, loss_patch=self.loss_patch,
max_epochs=self.training_epochs)
else:
self.data_collector.reset(loss_patch=self.loss_patch)
else:
self.data_collector.reset()
if self.metrics_calculator is None:
if not hasattr(self, 'data_collector'):
self.data_collector = WeightTrainerBasedDataCollector(self, self.trainer, self.optimizer_helper, self.criterion,
self.training_epochs, criterion_patch=self.criterion_patch)
else:
self.data_collector.reset()
if not hasattr(self, 'metrics_calculator'):
if self.granularity == 'fine-grained':
self.metrics_calculator = NormMetricsCalculator(p=1)
elif self.granularity == 'coarse-grained':
self.metrics_calculator = NormMetricsCalculator(p=1, scalers=Scaling(kernel_size=[1], kernel_padding_mode='back'))
if self.sparsity_allocator is None:
if not hasattr(self, 'sparsity_allocator'):
if self.granularity == 'fine-grained':
self.sparsity_allocator = NormalSparsityAllocator(self)
elif self.granularity == 'coarse-grained':
self.sparsity_allocator = NormalSparsityAllocator(self, Scaling(kernel_size=[1], kernel_padding_mode='back'))
def compress(self) -> Tuple[Module, Dict]:
"""
Returns
-------
Tuple[Module, Dict]
Return the wrapped model and mask.
"""
assert self.bound_model is not None
assert self.data_collector is not None and self.metrics_calculator is not None and self.sparsity_allocator is not None
for i in range(self.iterations):
_logger.info('======= ADMM Iteration %d Start =======', i)
data = self.data_collector.collect()
for name, weight in data.items():
self.Z[name] = weight + self.U[name]
for module_name, targets_data in data.items():
for target_name, target_data in targets_data.items():
self.U[module_name][target_name] = self.U[module_name][target_name].to(target_data.device)
self.Z[module_name][target_name] = target_data + self.U[module_name][target_name]
metrics = self.metrics_calculator.calculate_metrics(self.Z)
masks = self.sparsity_allocator.generate_sparsity(metrics)
for name, mask in masks.items():
self.Z[name] = self.Z[name].mul(mask['weight'])
self.U[name] = self.U[name] + data[name] - self.Z[name]
for module_name, targets_mask in masks.items():
target_name = 'weight'
self.Z[module_name][target_name] = self.Z[module_name][target_name].mul(targets_mask[target_name])
self.U[module_name][target_name] = self.U[module_name][target_name] + data[module_name][target_name] - \
self.Z[module_name][target_name]
self.Z, self.U = {}, {}
torch.cuda.empty_cache()
......@@ -1182,4 +1227,8 @@ class ADMMPruner(BasicPruner):
masks = self.sparsity_allocator.generate_sparsity(metrics)
self.load_masks(masks)
if self.using_evaluator:
self.evaluator.unbind_model()
return self.bound_model, masks
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from copy import deepcopy
from typing import Dict, List, Tuple, Callable, Optional, Union
import logging
from typing import Any, Dict, List, Tuple, Callable, Optional, Union, overload
import torch
from torch import Tensor
......@@ -12,9 +15,63 @@ from nni.algorithms.compression.v2.pytorch.base import Pruner, BasePruningSchedu
from nni.compression.pytorch.speedup import ModelSpeedup
from .tools import TaskGenerator
from ..utils import Evaluator, LightningEvaluator, TorchEvaluator
_logger = logging.getLogger(__name__)
_LEGACY_FINETUNER = Callable[[Module], None]
_LEGACY_EVALUATOR = Callable[[Module], float]
class PruningScheduler(BasePruningScheduler):
# TODO: remove in nni v3.0.
class EvaluatorBasedPruningScheduler(BasePruningScheduler):
evaluator: LightningEvaluator | TorchEvaluator
using_evaluator: bool
finetuner: _LEGACY_FINETUNER
_evaluator: _LEGACY_EVALUATOR
dummy_input: Any
def _init_evaluator(self, model: Module, new_api: List[str], new_init_kwargs: Dict, old_api: List[str],
old_init_kwargs: Dict, args: Tuple, kwargs: Dict) -> Dict:
# For the fake __init__ overload: parse args and kwargs,
# initialize the evaluator or [finetuner, evaluator, dummy_input], and return the remaining arguments.
if (len(args) > 0 and isinstance(args[0], Evaluator)) or \
(len(args) == 0 and isinstance(kwargs.get('evaluator', None), Evaluator)):
init_kwargs = self._parse_args(new_api, args, kwargs, new_init_kwargs)
self.evaluator: LightningEvaluator | TorchEvaluator = init_kwargs.pop('evaluator')
if not self.evaluator._initialization_complete:
self.evaluator._init_optimizer_helpers(model) # type: ignore
self.using_evaluator = True
else:
init_kwargs = self._parse_args(old_api, args, kwargs, old_init_kwargs)
self.finetuner: _LEGACY_FINETUNER = init_kwargs.pop('finetuner')
self._evaluator: _LEGACY_EVALUATOR = init_kwargs.pop('evaluator')
self.dummy_input = init_kwargs.pop('dummy_input')
self.using_evaluator = False
warn_msg = f'The old API ...{",".join(old_api)} will be deprecated after NNI v3.0, ' +\
f'please use the new one ...{",".join(new_api)}'
_logger.warning(warn_msg)
return init_kwargs
def _parse_args(self, arg_names: List, args: Tuple, kwargs: Dict, def_kwargs: Dict) -> Dict:
merged_kwargs = {arg_names[idx]: arg for idx, arg in enumerate(args)}
for key, value in kwargs.items():
if key in merged_kwargs:
raise TypeError(f"{self.__class__.__name__}.__init__() got multiple values for argument '{key}'")
merged_kwargs[key] = value
for key, value in def_kwargs.items():
if key not in merged_kwargs:
merged_kwargs[key] = value
diff = set(arg_names).difference(merged_kwargs.keys())
if diff:
raise TypeError(f"{self.__class__.__name__}.__init__() missing {len(diff)} required positional argument(s): {diff}")
diff = set(merged_kwargs.keys()).difference(arg_names)
if diff:
raise TypeError(f"{self.__class__.__name__}.__init__() got {len(diff)} unexpected keyword argument(s): {diff}")
return merged_kwargs
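# Illustration of how the fake overload resolves arguments (explanatory note only):
#   PruningScheduler(pruner, task_generator, evaluator, speedup=True)
# dispatches to the new API because the first extra positional argument is an Evaluator,
# and `_parse_args` yields {'evaluator': evaluator, 'speedup': True, 'reset_weight': False};
#   PruningScheduler(pruner, task_generator, finetuner=f, dummy_input=x)
# falls back to the old API, filling the remaining old keyword defaults.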
class PruningScheduler(EvaluatorBasedPruningScheduler):
"""
Parameters
----------
......@@ -25,7 +82,8 @@ class PruningScheduler(BasePruningScheduler):
Used to generate task for each iteration.
finetuner
The finetuner handles all finetune logic, taking a pytorch module as input.
It will be called at the end of each iteration if reset_weight is False, will be called at the beginning of each iteration otherwise.
It will be called at the end of each iteration if reset_weight is False,
and at the beginning of each iteration otherwise.
speedup
If set True, speedup the model at the end of each iteration to make the pruned model compact.
dummy_input
......@@ -36,16 +94,30 @@ class PruningScheduler(BasePruningScheduler):
reset_weight
If set True, the model weight will reset to the origin model weight at the end of each iteration step.
"""
def __init__(self, pruner: Pruner, task_generator: TaskGenerator, finetuner: Optional[Callable[[Module], None]] = None,
speedup: bool = False, dummy_input: Optional[Tensor] = None, evaluator: Optional[Callable[[Module], float]] = None,
@overload
def __init__(self, pruner: Pruner, task_generator: TaskGenerator, evaluator: LightningEvaluator | TorchEvaluator,
speedup: bool = False, reset_weight: bool = False):
...
@overload
def __init__(self, pruner: Pruner, task_generator: TaskGenerator, finetuner: _LEGACY_FINETUNER | None = None,
speedup: bool = False, dummy_input: Optional[Tensor] = None, evaluator: _LEGACY_EVALUATOR | None = None,
reset_weight: bool = False):
...
def __init__(self, pruner: Pruner, task_generator: TaskGenerator, *args, **kwargs) -> None:
# TODO: remove in nni v3.0. Fake overload.
new_api = ['evaluator', 'speedup', 'reset_weight']
new_init_kwargs = {'evaluator': None, 'speedup': False, 'reset_weight': False}
old_api = ['finetuner', 'speedup', 'dummy_input', 'evaluator', 'reset_weight']
old_init_kwargs = {'finetuner': None, 'evaluator': None, 'dummy_input': None, 'speedup': False, 'reset_weight': False}
init_kwargs = self._init_evaluator(None, new_api, new_init_kwargs, old_api, old_init_kwargs, args, kwargs) # type: ignore
self.pruner = pruner
self.task_generator = task_generator
self.finetuner = finetuner
self.speedup = speedup
self.dummy_input = dummy_input
self.evaluator = evaluator
self.reset_weight = reset_weight
self.speedup = init_kwargs['speedup']
self.reset_weight = init_kwargs['reset_weight']
def reset(self, model: Module, config_list: List[Dict], masks: Dict[str, Dict[str, Tensor]] = {}):
self.task_generator.reset(model, config_list, masks)
......@@ -61,6 +133,7 @@ class PruningScheduler(BasePruningScheduler):
generate masks -> speedup -> finetune -> evaluate
"""
model, masks, config_list = task.load_data()
self.pruner.reset(model, config_list)
self.pruner.load_masks(masks)
......@@ -74,28 +147,58 @@ class PruningScheduler(BasePruningScheduler):
# speedup
if self.speedup and task.speedup:
ModelSpeedup(compact_model, self.dummy_input, pruner_generated_masks).speedup_model()
compact_model_masks = {}
if self.using_evaluator:
ModelSpeedup(compact_model, self.evaluator.get_dummy_input(), pruner_generated_masks).speedup_model()
compact_model_masks = {}
else:
ModelSpeedup(compact_model, self.dummy_input, pruner_generated_masks).speedup_model()
compact_model_masks = {}
# finetune
if self.finetuner is not None and task.finetune:
if self.speedup:
self.finetuner(compact_model)
else:
self.pruner._wrap_model()
self.finetuner(compact_model)
self.pruner._unwrap_model()
if self.using_evaluator:
if task.finetune:
self.evaluator.bind_model(compact_model) # type: ignore
if self.speedup:
self.evaluator.finetune()
else:
self.pruner._wrap_model()
self.evaluator.finetune()
self.pruner._unwrap_model()
self.evaluator.unbind_model()
else:
if self.finetuner is not None and task.finetune:
if self.speedup:
self.finetuner(compact_model)
else:
self.pruner._wrap_model()
self.finetuner(compact_model)
self.pruner._unwrap_model()
# evaluate
if self.evaluator is not None and task.evaluate:
if self.speedup:
score = self.evaluator(compact_model)
if self.using_evaluator:
if task.evaluate:
self.evaluator.bind_model(compact_model) # type: ignore
# TODO: support saving customized score
if self.speedup:
score = self.evaluator.evaluate()
else:
self.pruner._wrap_model()
score = self.evaluator.evaluate()
self.pruner._unwrap_model()
score = score[0] if isinstance(score, tuple) else score
self.evaluator.unbind_model()
else:
self.pruner._wrap_model()
score = self.evaluator(compact_model)
self.pruner._unwrap_model()
score = None
else:
score = None
if self._evaluator is not None and task.evaluate:
if self.speedup:
score = self._evaluator(compact_model) # type: ignore
else:
self.pruner._wrap_model()
score = self._evaluator(compact_model) # type: ignore
self.pruner._unwrap_model()
else:
score = None
# clear model references
self.pruner.clear_model_references()
......@@ -107,13 +210,20 @@ class PruningScheduler(BasePruningScheduler):
finetune -> generate masks -> reset weight -> speedup -> evaluate
"""
model, masks, config_list = task.load_data()
checkpoint = deepcopy(model.state_dict())
self.pruner.reset(model, config_list)
self.pruner.load_masks(masks)
# finetune
if self.finetuner is not None and task.finetune:
self.finetuner(model)
if self.using_evaluator:
if task.finetune:
self.evaluator.bind_model(model) # type: ignore
self.evaluator.finetune()
self.evaluator.unbind_model()
else:
if self.finetuner is not None and task.finetune:
self.finetuner(model)
# pruning model
compact_model, pruner_generated_masks = self.pruner.compress()
......@@ -128,19 +238,38 @@ class PruningScheduler(BasePruningScheduler):
# speedup
if self.speedup and task.speedup:
ModelSpeedup(compact_model, self.dummy_input, pruner_generated_masks).speedup_model()
compact_model_masks = {}
if self.using_evaluator:
ModelSpeedup(compact_model, self.evaluator.get_dummy_input(), pruner_generated_masks).speedup_model()
compact_model_masks = {}
else:
ModelSpeedup(compact_model, self.dummy_input, pruner_generated_masks).speedup_model()
compact_model_masks = {}
# evaluate
if self.evaluator is not None and task.evaluate:
if self.speedup:
score = self.evaluator(compact_model)
if self.using_evaluator:
if task.evaluate:
self.evaluator.bind_model(compact_model) # type: ignore
# TODO: support saving customized score
if self.speedup:
score = self.evaluator.evaluate()
else:
self.pruner._wrap_model()
score = self.evaluator.evaluate()
self.pruner._unwrap_model()
score = score[0] if isinstance(score, tuple) else score
self.evaluator.unbind_model()
else:
self.pruner._wrap_model()
score = self.evaluator(compact_model)
self.pruner._unwrap_model()
score = None
else:
score = None
if self._evaluator is not None and task.evaluate:
if self.speedup:
score = self._evaluator(compact_model) # type: ignore
else:
self.pruner._wrap_model()
score = self._evaluator(compact_model) # type: ignore
self.pruner._unwrap_model()
else:
score = None
# clear model references
self.pruner.clear_model_references()
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import logging
from pathlib import Path
from typing import Dict, List, Callable, Optional, Union
from typing import Any, Dict, List, Optional, Union, overload
from torch import Tensor
from torch.nn import Module
from nni.algorithms.compression.v2.pytorch.utils import OptimizerConstructHelper
from .basic_pruner import (
LevelPruner,
L1NormPruner,
......@@ -21,13 +21,19 @@ from .basic_pruner import (
TaylorFOWeightPruner,
ADMMPruner
)
from .basic_scheduler import PruningScheduler
from .basic_scheduler import PruningScheduler, _LEGACY_FINETUNER, _LEGACY_EVALUATOR
from .tools import (
LinearTaskGenerator,
AGPTaskGenerator,
LotteryTicketTaskGenerator,
SimulatedAnnealingTaskGenerator
)
from ..utils import (
OptimizerConstructHelper,
LightningEvaluator,
TorchEvaluator
)
from ..utils.docstring import _EVALUATOR_DOCSTRING
_logger = logging.getLogger(__name__)
......@@ -71,55 +77,67 @@ class IterativePruner(PruningScheduler):
class LinearPruner(IterativePruner):
r"""
__doc__ = r"""
Linear pruner is an iterative pruner that increases sparsity evenly from scratch in each iteration.
For example, if the final sparsity is set to 0.5 and the iteration number is 5, the sparsity used in each iteration is ``[0, 0.1, 0.2, 0.3, 0.4, 0.5]``.
Parameters
----------
model : Module
model
The origin unwrapped pytorch model to be pruned.
config_list : List[Dict]
config_list
The origin config list provided by the user.
pruning_algorithm : str
pruning_algorithm
Supported pruning algorithm ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz', 'mean_activation', 'taylorfo', 'admm'].
This iterative pruner will use the chosen corresponding pruner to prune the model in each iteration.
total_iteration : int
total_iteration
The total iteration number.
log_dir : str
log_dir
The log directory used to save the results; you can find the best result under this folder.
keep_intermediate_result : bool
keep_intermediate_result
Whether to keep the intermediate results, including the intermediate model and masks generated during each iteration.
finetuner : Optional[Callable[[Module], None]]
The finetuner handled all finetune logic, use a pytorch module as input.
It will be called at the end of each iteration, usually for neutralizing the accuracy loss brought by the pruning in this iteration.
speedup : bool
evaluator
``evaluator`` is used to replace the previous ``finetuner``, ``dummy_input`` and old ``evaluator`` API.
{evaluator_docstring}
The old API (``finetuner``, ``dummy_input`` and old ``evaluator``) is still supported and will be deprecated in v3.0.
If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__.
speedup
If set True, speedup the model at the end of each iteration to make the pruned model compact.
dummy_input : Optional[torch.Tensor]
If `speedup` is True, `dummy_input` is required for tracing the model in speedup.
evaluator : Optional[Callable[[Module], float]]
Evaluate the pruned model and give a score.
If evaluator is None, the best result refers to the latest result.
pruning_params : Dict
pruning_params
If the chosen pruning_algorithm has extra parameters, put them as a dict to pass in.
Examples
--------
>>> from nni.compression.pytorch.pruning import LinearPruner
>>> config_list = [{'sparsity': 0.8, 'op_types': ['Conv2d']}]
>>> finetuner = ...
>>> pruner = LinearPruner(model, config_list, pruning_algorithm='l1', total_iteration=10, finetuner=finetuner)
>>> pruner.compress()
>>> _, model, masks, _, _ = pruner.get_best_result()
Notes
-----
For detailed example please refer to :githublink:`examples/model_compress/pruning/iterative_pruning_torch.py <examples/model_compress/pruning/iterative_pruning_torch.py>`
"""
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
@overload
def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
finetuner: Optional[Callable[[Module], None]] = None, speedup: bool = False, dummy_input: Optional[Tensor] = None,
evaluator: Optional[Callable[[Module], float]] = None, pruning_params: Dict = {}):
evaluator: LightningEvaluator | TorchEvaluator | None = None, speedup: bool = False,
pruning_params: Dict = {}):
...
@overload
def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
finetuner: _LEGACY_FINETUNER | None = None, speedup: bool = False, dummy_input: Any | None = None,
evaluator: _LEGACY_EVALUATOR | None = None, pruning_params: Dict = {}):
...
def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
*args, **kwargs):
new_api = ['evaluator', 'speedup', 'pruning_params']
new_init_kwargs = {'evaluator': None, 'speedup': False, 'pruning_params': {}}
old_api = ['finetuner', 'speedup', 'dummy_input', 'evaluator', 'pruning_params']
old_init_kwargs = {'finetuner': None, 'evaluator': None, 'dummy_input': None, 'speedup': False, 'pruning_params': {}}
init_kwargs = self._init_evaluator(model, new_api, new_init_kwargs, old_api, old_init_kwargs, args, kwargs)
speedup = init_kwargs['speedup']
pruning_params = init_kwargs['pruning_params']
task_generator = LinearTaskGenerator(total_iteration=total_iteration,
origin_model=model,
origin_config_list=config_list,
......@@ -128,63 +146,80 @@ class LinearPruner(IterativePruner):
if 'traced_optimizer' in pruning_params:
pruning_params['traced_optimizer'] = OptimizerConstructHelper.from_trace(model, pruning_params['traced_optimizer'])
pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
super().__init__(pruner, task_generator, finetuner=finetuner, speedup=speedup, dummy_input=dummy_input,
evaluator=evaluator, reset_weight=False)
if self.using_evaluator:
super().__init__(pruner, task_generator, evaluator=self.evaluator, speedup=speedup, reset_weight=False)
else:
super().__init__(pruner, task_generator, finetuner=self.finetuner, speedup=speedup, dummy_input=self.dummy_input,
evaluator=self._evaluator, reset_weight=False) # type: ignore
class AGPPruner(IterativePruner):
r"""
__doc__ = r"""
This is an iterative pruner, in which the sparsity is increased from an initial sparsity value :math:`s_{i}` (usually 0) to a final sparsity value :math:`s_{f}` over a span of :math:`n` pruning iterations,
starting at training step :math:`t_{0}` and with pruning frequency :math:`\Delta t`:
:math:`s_{t}=s_{f}+\left(s_{i}-s_{f}\right)\left(1-\frac{t-t_{0}}{n \Delta t}\right)^{3} \text { for } t \in\left\{t_{0}, t_{0}+\Delta t, \ldots, t_{0} + n \Delta t\right\}`
""" + r"""
For more details please refer to `To prune, or not to prune: exploring the efficacy of pruning for model compression <https://arxiv.org/abs/1710.01878>`__\.
Parameters
----------
model : Module
model
The origin unwrapped pytorch model to be pruned.
config_list : List[Dict]
config_list
The origin config list provided by the user.
pruning_algorithm : str
pruning_algorithm
Supported pruning algorithm ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz', 'mean_activation', 'taylorfo', 'admm'].
This iterative pruner will use the chosen corresponding pruner to prune the model in each iteration.
total_iteration : int
total_iteration
The total iteration number.
log_dir : str
log_dir
The log directory used to save the results; you can find the best result under this folder.
keep_intermediate_result : bool
keep_intermediate_result
Whether to keep the intermediate results, including the intermediate model and masks generated during each iteration.
finetuner : Optional[Callable[[Module], None]]
The finetuner handled all finetune logic, use a pytorch module as input.
It will be called at the end of each iteration, usually for neutralizing the accuracy loss brought by the pruning in this iteration.
speedup : bool
evaluator
``evaluator`` is used to replace the previous ``finetuner``, ``dummy_input`` and old ``evaluator`` API.
{evaluator_docstring}
The old API (``finetuner``, ``dummy_input`` and old ``evaluator``) is still supported and will be deprecated in v3.0.
If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__.
speedup
If set True, speedup the model at the end of each iteration to make the pruned model compact.
dummy_input : Optional[torch.Tensor]
If `speedup` is True, `dummy_input` is required for tracing the model in speedup.
evaluator : Optional[Callable[[Module], float]]
Evaluate the pruned model and give a score.
If evaluator is None, the best result refers to the latest result.
pruning_params : Dict
pruning_params
If the chosen pruning_algorithm has extra parameters, put them as a dict to pass in.
Examples
--------
>>> from nni.compression.pytorch.pruning import AGPPruner
>>> config_list = [{'sparsity': 0.8, 'op_types': ['Conv2d']}]
>>> finetuner = ...
>>> pruner = AGPPruner(model, config_list, pruning_algorithm='l1', total_iteration=10, finetuner=finetuner)
>>> pruner.compress()
>>> _, model, masks, _, _ = pruner.get_best_result()
Notes
-----
For detailed example please refer to :githublink:`examples/model_compress/pruning/iterative_pruning_torch.py <examples/model_compress/pruning/iterative_pruning_torch.py>`
"""
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
@overload
def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
evaluator: LightningEvaluator | TorchEvaluator | None = None, speedup: bool = False,
pruning_params: Dict = {}):
...
@overload
def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
finetuner: _LEGACY_FINETUNER | None = None, speedup: bool = False, dummy_input: Any | None = None,
evaluator: _LEGACY_EVALUATOR | None = None, pruning_params: Dict = {}):
...
def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
finetuner: Optional[Callable[[Module], None]] = None, speedup: bool = False, dummy_input: Optional[Tensor] = None,
evaluator: Optional[Callable[[Module], float]] = None, pruning_params: Dict = {}):
*args, **kwargs):
new_api = ['evaluator', 'speedup', 'pruning_params']
new_init_kwargs = {'evaluator': None, 'speedup': False, 'pruning_params': {}}
old_api = ['finetuner', 'speedup', 'dummy_input', 'evaluator', 'pruning_params']
old_init_kwargs = {'finetuner': None, 'evaluator': None, 'dummy_input': None, 'speedup': False, 'pruning_params': {}}
init_kwargs = self._init_evaluator(model, new_api, new_init_kwargs, old_api, old_init_kwargs, args, kwargs)
speedup = init_kwargs['speedup']
pruning_params = init_kwargs['pruning_params']
task_generator = AGPTaskGenerator(total_iteration=total_iteration,
origin_model=model,
origin_config_list=config_list,
......@@ -193,12 +228,16 @@ class AGPPruner(IterativePruner):
if 'traced_optimizer' in pruning_params:
pruning_params['traced_optimizer'] = OptimizerConstructHelper.from_trace(model, pruning_params['traced_optimizer'])
pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
super().__init__(pruner, task_generator, finetuner=finetuner, speedup=speedup, dummy_input=dummy_input,
evaluator=evaluator, reset_weight=False)
if self.using_evaluator:
super().__init__(pruner, task_generator, evaluator=self.evaluator, speedup=speedup, reset_weight=False)
else:
super().__init__(pruner, task_generator, finetuner=self.finetuner, speedup=speedup, dummy_input=self.dummy_input,
evaluator=self._evaluator, reset_weight=False) # type: ignore
class LotteryTicketPruner(IterativePruner):
r"""
__doc__ = r"""
`The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks <https://arxiv.org/abs/1803.03635>`__\ ,
authored by Jonathan Frankle and Michael Carbin, provides comprehensive measurement and analysis,
and articulates the *lottery ticket hypothesis*\ : dense, randomly-initialized, feed-forward networks contain subnetworks (*winning tickets*\ ) that
......@@ -216,55 +255,69 @@ class LotteryTicketPruner(IterativePruner):
If the configured final sparsity is P (e.g., 0.8) and there are n pruning iterations,
each iteration prunes 1-(1-P)^(1/n) of the weights that survive the previous round.
""" + r"""
Parameters
----------
model : Module
model
The origin unwrapped pytorch model to be pruned.
config_list : List[Dict]
config_list
The origin config list provided by the user.
pruning_algorithm : str
pruning_algorithm
Supported pruning algorithm ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz', 'mean_activation', 'taylorfo', 'admm'].
This iterative pruner will use the chosen corresponding pruner to prune the model in each iteration.
total_iteration : int
total_iteration
The total iteration number.
log_dir : str
log_dir
The log directory used to save the result; you can find the best result under this folder.
keep_intermediate_result : bool
keep_intermediate_result
If True, keep the intermediate results, including the intermediate model and masks during each iteration.
finetuner : Optional[Callable[[Module], None]]
The finetuner handles all fine-tuning logic and takes a PyTorch module as input.
It will be called at the end of each iteration if reset_weight is False, and at the beginning of each iteration otherwise.
speedup : bool
evaluator
``evaluator`` is used to replace the previous ``finetuner``, ``dummy_input`` and old ``evaluator`` API.
{evaluator_docstring}
The old API (``finetuner``, ``dummy_input`` and old ``evaluator``) is still supported and will be deprecated in v3.0.
If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__.
speedup
If set True, speed up the model at the end of each iteration to make the pruned model compact.
dummy_input : Optional[torch.Tensor]
If `speedup` is True, `dummy_input` is required for tracing the model in speedup.
evaluator : Optional[Callable[[Module], float]]
Evaluate the pruned model and give a score.
If evaluator is None, the best result refers to the latest result.
reset_weight : bool
reset_weight
If set True, the model weight will reset to the original model weight at the end of each iteration step.
pruning_params : Dict
pruning_params
If the chosen pruning_algorithm has extra parameters, pass them in as a dict.
Examples
--------
>>> from nni.compression.pytorch.pruning import LotteryTicketPruner
>>> config_list = [{'sparsity': 0.8, 'op_types': ['Conv2d']}]
>>> finetuner = ...
>>> pruner = LotteryTicketPruner(model, config_list, pruning_algorithm='l1', total_iteration=10, finetuner=finetuner, reset_weight=True)
>>> pruner.compress()
>>> _, model, masks, _, _ = pruner.get_best_result()
Notes
-----
For detailed example please refer to :githublink:`examples/model_compress/pruning/iterative_pruning_torch.py <examples/model_compress/pruning/iterative_pruning_torch.py>`
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
"""
@overload
def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
evaluator: LightningEvaluator | TorchEvaluator | None = None, speedup: bool = False,
reset_weight: bool = True, pruning_params: Dict = {}):
...
@overload
def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
finetuner: Optional[Callable[[Module], None]] = None, speedup: bool = False, dummy_input: Optional[Tensor] = None,
evaluator: Optional[Callable[[Module], float]] = None, reset_weight: bool = True,
finetuner: _LEGACY_FINETUNER | None = None, speedup: bool = False, dummy_input: Optional[Tensor] = None,
evaluator: _LEGACY_EVALUATOR | None = None, reset_weight: bool = True,
pruning_params: Dict = {}):
...
def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
*args, **kwargs):
new_api = ['evaluator', 'speedup', 'reset_weight', 'pruning_params']
new_init_kwargs = {'evaluator': None, 'speedup': False, 'reset_weight': True, 'pruning_params': {}}
old_api = ['finetuner', 'speedup', 'dummy_input', 'evaluator', 'reset_weight', 'pruning_params']
old_init_kwargs = {'finetuner': None, 'evaluator': None, 'dummy_input': None, 'speedup': False,
'reset_weight': True, 'pruning_params': {}}
init_kwargs = self._init_evaluator(model, new_api, new_init_kwargs, old_api, old_init_kwargs, args, kwargs)
speedup = init_kwargs['speedup']
reset_weight = init_kwargs['reset_weight']
pruning_params = init_kwargs['pruning_params']
task_generator = LotteryTicketTaskGenerator(total_iteration=total_iteration,
origin_model=model,
origin_config_list=config_list,
......@@ -273,12 +326,16 @@ class LotteryTicketPruner(IterativePruner):
if 'traced_optimizer' in pruning_params:
pruning_params['traced_optimizer'] = OptimizerConstructHelper.from_trace(model, pruning_params['traced_optimizer'])
pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
super().__init__(pruner, task_generator, finetuner=finetuner, speedup=speedup, dummy_input=dummy_input,
evaluator=evaluator, reset_weight=reset_weight)
if self.using_evaluator:
super().__init__(pruner, task_generator, evaluator=self.evaluator, speedup=speedup, reset_weight=reset_weight)
else:
super().__init__(pruner, task_generator, finetuner=self.finetuner, speedup=speedup, dummy_input=self.dummy_input,
evaluator=self._evaluator, reset_weight=reset_weight) # type: ignore
class SimulatedAnnealingPruner(IterativePruner):
"""
__doc__ = r"""
We implement a guided heuristic search method, the Simulated Annealing (SA) algorithm. As mentioned in the paper, this method is an enhanced guided search based on prior experience.
The enhanced SA technique is based on the observation that a DNN layer with a larger number of weights often tolerates a higher degree of compression with less impact on overall accuracy.
......@@ -294,54 +351,81 @@ class SimulatedAnnealingPruner(IterativePruner):
Parameters
----------
model : Optional[Module]
model
The origin unwrapped pytorch model to be pruned.
config_list : Optional[List[Dict]]
config_list
The origin config list provided by the user.
evaluator : Callable[[Module], float]
Evaluate the pruned model and give a score.
start_temperature : float
evaluator
``evaluator`` is used to replace the previous ``finetuner``, ``dummy_input`` and old ``evaluator`` API.
{evaluator_docstring}
The old API (``finetuner``, ``dummy_input`` and old ``evaluator``) is still supported and will be deprecated in v3.0.
If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__.
start_temperature
Start temperature of the simulated annealing process.
stop_temperature : float
stop_temperature
Stop temperature of the simulated annealing process.
cool_down_rate : float
cool_down_rate
Cool down rate of the temperature.
perturbation_magnitude : float
perturbation_magnitude
Initial perturbation magnitude to the sparsities. The magnitude decreases with current temperature.
pruning_algorithm : str
pruning_algorithm
Supported pruning algorithm ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz', 'mean_activation', 'taylorfo', 'admm'].
This iterative pruner will use the chosen corresponding pruner to prune the model in each iteration.
pruning_params : Dict
pruning_params
If the chosen pruning_algorithm has extra parameters, pass them in as a dict.
log_dir : Union[str, Path]
log_dir
The log directory used to save the result; you can find the best result under this folder.
keep_intermediate_result : bool
keep_intermediate_result
If True, keep the intermediate results, including the intermediate model and masks during each iteration.
finetuner : Optional[Callable[[Module], None]]
The finetuner handles all fine-tuning logic, takes a PyTorch module as input, and will be called in each iteration.
speedup : bool
speedup
If set True, speed up the model at the end of each iteration to make the pruned model compact.
dummy_input : Optional[torch.Tensor]
If `speedup` is True, `dummy_input` is required for tracing the model in speedup.
Examples
--------
>>> from nni.compression.pytorch.pruning import SimulatedAnnealingPruner
>>> model = ...
>>> config_list = [{'sparsity': 0.8, 'op_types': ['Conv2d']}]
>>> evaluator = ...
>>> finetuner = ...
>>> pruner = SimulatedAnnealingPruner(model, config_list, pruning_algorithm='l1', evaluator=evaluator, cool_down_rate=0.9, finetuner=finetuner)
>>> pruner.compress()
>>> _, model, masks, _, _ = pruner.get_best_result()
Notes
-----
For detailed example please refer to :githublink:`examples/model_compress/pruning/simulated_anealing_pruning_torch.py <examples/model_compress/pruning/simulated_anealing_pruning_torch.py>`
"""
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
@overload
def __init__(self, model: Module, config_list: List[Dict], evaluator: LightningEvaluator | TorchEvaluator,
start_temperature: float = 100, stop_temperature: float = 20, cool_down_rate: float = 0.9,
perturbation_magnitude: float = 0.35, pruning_algorithm: str = 'level', pruning_params: Dict = {},
log_dir: Union[str, Path] = '.', keep_intermediate_result: bool = False, speedup: bool = False):
...
@overload
def __init__(self, model: Module, config_list: List[Dict], evaluator: _LEGACY_EVALUATOR,
start_temperature: float = 100, stop_temperature: float = 20, cool_down_rate: float = 0.9,
perturbation_magnitude: float = 0.35, pruning_algorithm: str = 'level', pruning_params: Dict = {},
log_dir: Union[str, Path] = '.', keep_intermediate_result: bool = False,
finetuner: _LEGACY_FINETUNER | None = None, speedup: bool = False,
dummy_input: Optional[Tensor] = None):
...
def __init__(self, model: Module, config_list: List[Dict], *args, **kwargs):
new_api = ['evaluator', 'start_temperature', 'stop_temperature', 'cool_down_rate', 'perturbation_magnitude',
'pruning_algorithm', 'pruning_params', 'log_dir', 'keep_intermediate_result', 'speedup']
new_init_kwargs = {'start_temperature': 100, 'stop_temperature': 20, 'cool_down_rate': 0.9,
'perturbation_magnitude': 0.35, 'pruning_algorithm': 'level', 'pruning_params': {},
'log_dir': '.', 'keep_intermediate_result': False, 'speedup': False}
old_api = ['evaluator', 'start_temperature', 'stop_temperature', 'cool_down_rate', 'perturbation_magnitude',
'pruning_algorithm', 'pruning_params', 'log_dir', 'keep_intermediate_result', 'finetuner',
'speedup', 'dummy_input']
old_init_kwargs = {'start_temperature': 100, 'stop_temperature': 20, 'cool_down_rate': 0.9,
'perturbation_magnitude': 0.35, 'pruning_algorithm': 'level', 'pruning_params': {},
'log_dir': '.', 'keep_intermediate_result': False, 'finetuner': None, 'speedup': False,
'dummy_input': None}
init_kwargs = self._init_evaluator(model, new_api, new_init_kwargs, old_api, old_init_kwargs, args, kwargs)
start_temperature = init_kwargs['start_temperature']
stop_temperature = init_kwargs['stop_temperature']
cool_down_rate = init_kwargs['cool_down_rate']
perturbation_magnitude = init_kwargs['perturbation_magnitude']
pruning_algorithm = init_kwargs['pruning_algorithm']
pruning_params = init_kwargs['pruning_params']
log_dir = init_kwargs['log_dir']
keep_intermediate_result = init_kwargs['keep_intermediate_result']
speedup = init_kwargs['speedup']
def __init__(self, model: Optional[Module], config_list: Optional[List[Dict]], evaluator: Callable[[Module], float], start_temperature: float = 100,
stop_temperature: float = 20, cool_down_rate: float = 0.9, perturbation_magnitude: float = 0.35,
pruning_algorithm: str = 'level', pruning_params: Dict = {}, log_dir: Union[str, Path] = '.', keep_intermediate_result: bool = False,
finetuner: Optional[Callable[[Module], None]] = None, speedup: bool = False, dummy_input: Optional[Tensor] = None):
task_generator = SimulatedAnnealingTaskGenerator(origin_model=model,
origin_config_list=config_list,
start_temperature=start_temperature,
......@@ -351,7 +435,12 @@ class SimulatedAnnealingPruner(IterativePruner):
log_dir=log_dir,
keep_intermediate_result=keep_intermediate_result)
if 'traced_optimizer' in pruning_params:
pruning_params['traced_optimizer'] = OptimizerConstructHelper.from_trace(model, pruning_params['traced_optimizer']) # type: ignore
pruning_params['traced_optimizer'] = \
OptimizerConstructHelper.from_trace(model, pruning_params['traced_optimizer']) # type: ignore
pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
super().__init__(pruner, task_generator, finetuner=finetuner, speedup=speedup, dummy_input=dummy_input,
evaluator=evaluator, reset_weight=False)
if self.using_evaluator:
super().__init__(pruner, task_generator, evaluator=self.evaluator, speedup=speedup, reset_weight=False)
else:
super().__init__(pruner, task_generator, finetuner=self.finetuner, speedup=speedup,
dummy_input=self.dummy_input, evaluator=self._evaluator, reset_weight=False) # type: ignore
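# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this diff) of driving the evaluator-based
# iterative pruners above with the new ``TorchEvaluator`` API. The argument
# layout of ``TorchEvaluator`` (training function, traced optimizer, criterion,
# optional ``evaluating_func``) and the callback signatures below are
# assumptions -- check the evaluator reference added in this PR for the
# authoritative interface. ``MyModel`` and the loops are placeholders.
#
# import nni
# import torch
# import torch.nn.functional as F
# from nni.compression.pytorch import TorchEvaluator
# from nni.compression.pytorch.pruning import SimulatedAnnealingPruner
#
# model = MyModel()
# config_list = [{'sparsity': 0.8, 'op_types': ['Conv2d']}]
#
# def training_func(model, optimizers, criterion, lr_schedulers=None, max_steps=None, max_epochs=None):
#     # fine-tune ``model`` for the requested number of steps/epochs
#     ...
#
# def evaluating_func(model):
#     # return a float score, e.g. validation accuracy
#     ...
#
# traced_optimizer = nni.trace(torch.optim.Adam)(model.parameters())
# evaluator = TorchEvaluator(training_func, traced_optimizer, F.cross_entropy, evaluating_func=evaluating_func)
# pruner = SimulatedAnnealingPruner(model, config_list, evaluator=evaluator, cool_down_rate=0.9)
# pruner.compress()
# _, model, masks, _, _ = pruner.get_best_result()
# ---------------------------------------------------------------------------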
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from copy import deepcopy
import logging
from typing import Dict, List, Tuple, Callable
from typing import Dict, List, Tuple, Callable, overload
import torch
from torch import autograd, Tensor
......@@ -12,17 +14,23 @@ from torch.nn.parameter import Parameter
from torch.optim import Optimizer, Adam
from nni.algorithms.compression.v2.pytorch.base import PrunerModuleWrapper, LayerInfo
from nni.algorithms.compression.v2.pytorch.pruning.basic_pruner import BasicPruner, NORMAL_SCHEMA, EXCLUDE_SCHEMA, INTERNAL_SCHEMA
from nni.algorithms.compression.v2.pytorch.utils import CompressorSchema, OptimizerConstructHelper
from nni.common.serializer import Traceable
from nni.algorithms.compression.v2.pytorch.pruning.basic_pruner import EvaluatorBasedPruner, NORMAL_SCHEMA, EXCLUDE_SCHEMA, INTERNAL_SCHEMA
from nni.algorithms.compression.v2.pytorch.utils import CompressorSchema
from .tools.base import TrainerBasedDataCollector
from .tools.base import EvaluatorBasedDataCollector, TrainerBasedDataCollector
from .tools import (
StraightMetricsCalculator,
NormalSparsityAllocator
NormalSparsityAllocator,
StraightMetricsCalculator
)
from ..utils import (
LightningEvaluator,
TorchEvaluator
)
from ..utils.docstring import _EVALUATOR_DOCSTRING
_logger = logging.getLogger(__name__)
......@@ -47,8 +55,7 @@ class PrunerScoredModuleWrapper(PrunerModuleWrapper):
def forward(self, *inputs):
# apply mask to weight, bias
# NOTE: I don't know why training gets slower and slower if only `self.weight_mask` is used without `detach()`
self.module.weight = torch.mul(self.weight, _StraightThrough.apply(self.weight_score, self.weight_mask.detach())) # type: ignore
self.module.weight = torch.mul(self.weight, _StraightThrough.apply(self.weight_score, self.weight_mask)) # type: ignore
if hasattr(self.module, 'bias') and self.module.bias is not None:
self.module.bias = torch.mul(self.bias, self.bias_mask) # type: ignore
return self.module(*inputs)
......@@ -77,13 +84,30 @@ class WeightScoreTrainerBasedDataCollector(TrainerBasedDataCollector):
self.trainer(self.compressor.bound_model, self.optimizer, self.criterion)
data = {}
target_name = 'weight'
for _, wrapper in self.compressor.get_modules_wrapper().items():
data[wrapper.name] = wrapper.weight_score.data # type: ignore
data[wrapper.name] = {target_name: wrapper.weight_score.data} # type: ignore
return data
class MovementPruner(BasicPruner):
r"""
class EvaluatorBasedScoreDataCollector(EvaluatorBasedDataCollector):
"""
Collect all weight_score in wrappers as data used to calculate metrics.
"""
def collect(self) -> Dict[str, Tensor]:
assert self.compressor.bound_model is not None
self.evaluator.train(max_steps=self.max_steps, max_epochs=self.max_epochs)
data = {}
target_name = 'weight'
for module_name, wrapper in self.compressor.get_modules_wrapper().items():
target_score: Tensor = getattr(wrapper, f'{target_name}_score')
data[module_name] = {target_name: target_score.data.clone()}
return data
class MovementPruner(EvaluatorBasedPruner):
__doc__ = r"""
Movement pruner is an implementation of movement pruning.
This is a "fine-pruning" algorithm, which means the masks may change during each fine-tuning step.
Each weight element will be scored by the negative of the accumulated sum of the products of the weight and its gradient over the training steps.
......@@ -110,30 +134,12 @@ class MovementPruner(BasicPruner):
- op_names : Operation names to be pruned.
- op_partial_names: Operation partial names to be pruned, will be autocompleted by NNI.
- exclude : Set True then the layers setting by op_types and op_names will be excluded from pruning.
trainer : Callable[[Module, Optimizer, Callable]
A callable function used to train model or just inference. Take model, optimizer, criterion as input.
The model will be trained or inferenced `training_epochs` epochs.
Example::
def trainer(model: Module, optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor]):
training = model.training
model.train(mode=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
# If you don't want to update the model, you can skip `optimizer.step()`, and set train mode False.
optimizer.step()
model.train(mode=training)
traced_optimizer : nni.common.serializer.Traceable(torch.optim.Optimizer)
The traced optimizer instance which the optimizer class is wrapped by nni.trace.
E.g. ``traced_optimizer = nni.trace(torch.nn.Adam)(model.parameters())``.
criterion : Callable[[Tensor, Tensor], Tensor]
The criterion function used in trainer. Take model output and target value as input, and return the loss.
evaluator
``evaluator`` is used to replace the previous ``trainer``, ``traced_optimizer`` and ``criterion`` API.
{evaluator_docstring}
The old API (``trainer``, ``traced_optimizer`` and ``criterion``) is still supported and will be deprecated in v3.0.
If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__.
training_epochs : int
The total number of epochs for training the model.
Make sure the total number of `optimizer.step()` calls within `training_epochs` is larger than `cool_down_beginning_step`.
......@@ -145,33 +151,31 @@ class MovementPruner(BasicPruner):
The sparsity after each `optimizer.step()` is:
total_sparsity * (1 - (1 - (current_step - warm_up_step) / (cool_down_beginning_step - warm_up_step)) ** 3).
Examples
--------
>>> import nni
>>> from nni.compression.pytorch.pruning import MovementPruner
>>> model = ...
>>> # make sure you have used nni.trace to wrap the optimizer class before initialize
>>> traced_optimizer = nni.trace(torch.optim.Adam)(model.parameters())
>>> trainer = ...
>>> criterion = ...
>>> config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
>>> pruner = MovementPruner(model, config_list, trainer, traced_optimizer, criterion, 10, 3000, 27000)
>>> masked_model, masks = pruner.compress()
Notes
-----
For detailed example please refer to :githublink:`examples/model_compress/pruning/movement_pruning_glue.py <examples/model_compress/pruning/movement_pruning_glue.py>`
"""
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
@overload
def __init__(self, model: Module, config_list: List[Dict], evaluator: LightningEvaluator | TorchEvaluator, training_epochs: int,
warm_up_step: int, cool_down_beginning_step: int):
...
@overload
def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
traced_optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor], training_epochs: int, warm_up_step: int,
cool_down_beginning_step: int):
self.trainer = trainer
if isinstance(traced_optimizer, OptimizerConstructHelper):
self.optimizer_helper = traced_optimizer
else:
self.optimizer_helper = OptimizerConstructHelper.from_trace(model, traced_optimizer)
self.criterion = criterion
self.training_epochs = training_epochs
self.warm_up_step = warm_up_step
self.cool_down_beginning_step = cool_down_beginning_step
...
def __init__(self, model: Module, config_list: List[Dict], *args, **kwargs):
# TODO: remove in nni v3.0. Fake overload.
new_api = ['evaluator', 'training_epochs', 'warm_up_step', 'cool_down_beginning_step']
old_api = ['trainer', 'traced_optimizer', 'criterion', 'training_epochs', 'warm_up_step', 'cool_down_beginning_step']
init_kwargs = self._init_evaluator(model, new_api, old_api, {}, args, kwargs)
self.training_epochs: int = init_kwargs['training_epochs']
self.warm_up_step: int = init_kwargs['warm_up_step']
self.cool_down_beginning_step: int = init_kwargs['cool_down_beginning_step']
assert self.warm_up_step < self.cool_down_beginning_step, '`warm_up_step` should be smaller than `cool_down_beginning_step`'
super().__init__(model, config_list)
......@@ -184,14 +188,16 @@ class MovementPruner(BasicPruner):
if self.warm_up_step < current_step <= self.cool_down_beginning_step:
wrapper_dict = self.get_modules_wrapper()
for config in self.config_list:
current_sparsity = config['total_sparsity'] * (1 - (1 - (current_step - self.warm_up_step) / (self.cool_down_beginning_step - self.warm_up_step)) ** 3)
scale = 1 - (1 - (current_step - self.warm_up_step) / (self.cool_down_beginning_step - self.warm_up_step)) ** 3
current_sparsity = config['total_sparsity'] * scale
for op_name in config['op_names']:
wrapper_dict[op_name].config['total_sparsity'] = current_sparsity
wrapper = wrapper_dict[op_name]
wrapper.config['total_sparsity'] = current_sparsity
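# Worked example of the cubic schedule above (values are illustrative, taken
# from the docstring example): with total_sparsity=0.8, warm_up_step=3000,
# cool_down_beginning_step=27000 and current_step=15000, the progress is
# (15000 - 3000) / (27000 - 3000) = 0.5, so scale = 1 - (1 - 0.5) ** 3 = 0.875
# and current_sparsity = 0.8 * 0.875 = 0.7.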
def reset_tools(self):
if self.metrics_calculator is None:
if not hasattr(self, 'metrics_calculator'):
self.metrics_calculator = StraightMetricsCalculator()
if self.sparsity_allocator is None:
if not hasattr(self, 'sparsity_allocator'):
self.sparsity_allocator = NormalSparsityAllocator(self, continuous_mask=False)
# use Adam to update the weight_score
......@@ -208,16 +214,30 @@ class MovementPruner(BasicPruner):
if self.step_counter > self.warm_up_step:
self.cubic_schedule(self.step_counter)
data = {}
target_name = 'weight'
for wrapper in self.get_modules_wrapper().values():
data[wrapper.name] = wrapper.weight_score.data
data[wrapper.name] = {target_name: wrapper.weight_score.data}
metrics = self.metrics_calculator.calculate_metrics(data) # type: ignore
masks = self.sparsity_allocator.generate_sparsity(metrics) # type: ignore
self.load_masks(masks)
if self.data_collector is None:
self.data_collector = WeightScoreTrainerBasedDataCollector(self, self.trainer, self.optimizer_helper, self.criterion, self.training_epochs, opt_after_tasks=[_optimizer_patch])
if self.using_evaluator:
# TODO: move to other place in nni v3.0
self.evaluator.unbind_model()
self.evaluator.bind_model(self.bound_model, self.get_origin2wrapped_parameter_name_map()) # type: ignore
if not hasattr(self, 'data_collector'):
self.data_collector = EvaluatorBasedScoreDataCollector(self, self.evaluator,
after_opt_step_tasks=[_optimizer_patch],
max_epochs=self.training_epochs)
else:
self.data_collector.reset(after_opt_step_tasks=[_optimizer_patch])
else:
self.data_collector.reset()
if not hasattr(self, 'data_collector'):
self.data_collector = WeightScoreTrainerBasedDataCollector(self, self.trainer, self.optimizer_helper,
self.criterion, self.training_epochs,
opt_after_tasks=[_optimizer_patch])
else:
self.data_collector.reset()
def _wrap_modules(self, layer: LayerInfo, config: Dict):
"""
......@@ -243,7 +263,6 @@ class MovementPruner(BasicPruner):
for wrapper in self.get_modules_wrapper().values():
wrapper.config['total_sparsity'] = 0
result = super().compress()
# del weight_score
for wrapper in self.get_modules_wrapper().values():
wrapper.weight_score = None
if self.using_evaluator:
self.evaluator.unbind_model()
return result
......@@ -8,6 +8,12 @@ from .base import (
SparsityAllocator,
TaskGenerator
)
from .data_collector import (
TargetDataCollector,
EvaluatorBasedTargetDataCollector,
EvaluatorBasedHookDataCollector
)
# TODO: remove in nni v3.0.
from .data_collector import (
WeightDataCollector,
WeightTrainerBasedDataCollector,
......@@ -16,7 +22,7 @@ from .data_collector import (
from .metrics_calculator import (
StraightMetricsCalculator,
NormMetricsCalculator,
MultiDataNormMetricsCalculator,
HookDataNormMetricsCalculator,
DistMetricsCalculator,
APoZRankMetricsCalculator,
MeanRankMetricsCalculator
......
......@@ -6,7 +6,7 @@ from datetime import datetime
import logging
from pathlib import Path
import types
from typing import List, Dict, Tuple, Optional, Callable, Union
from typing import List, Dict, Literal, Tuple, Optional, Callable, Union
import json_tricks
import torch
......@@ -15,7 +15,7 @@ from torch.nn import Module
from torch.optim import Optimizer
from ...base import Pruner, LayerInfo, Task, TaskResult
from ...utils import OptimizerConstructHelper, Scaling
from ...utils import Evaluator, Hook, OptimizerConstructHelper, Scaling
_logger = logging.getLogger(__name__)
......@@ -45,7 +45,7 @@ class DataCollector:
def __init__(self, compressor: Pruner):
self.compressor = compressor
def reset(self):
def reset(self, *args, **kwargs):
"""
Reset the `DataCollector`.
"""
......@@ -63,9 +63,12 @@ class DataCollector:
raise NotImplementedError()
# TODO: remove in nni v3.0.
COLLECTOR_TYPE = Union[Callable[[List, Tensor], Callable[[Tensor], None]], Callable[[List], Callable[[Module, Tensor, Tensor], None]]]
class HookCollectorInfo:
def __init__(self, targets: Union[Dict[str, Tensor], List[LayerInfo]], hook_type: str,
collector: Union[Callable[[List, Tensor], Callable[[Tensor], None]], Callable[[List], Callable[[Module, Tensor, Tensor], None]]]):
collector: COLLECTOR_TYPE):
"""
This class used to aggregate the information of what kind of hook is placed on which layers.
......@@ -76,23 +79,24 @@ class HookCollectorInfo:
hook_type
'forward' or 'backward'.
collector
A hook function generator, the input is a buffer (empty list) or a buffer (empty list) and tensor, the output is a hook function.
The buffer is used to store the data wanted to hook.
A hook function generator; the input is a buffer (an empty list), or a buffer and a tensor,
and the output is a hook function. The buffer is used to store the data collected by the hook.
"""
self.targets = targets
self.hook_type = hook_type
self.collector = collector
# TODO: remove in nni v3.0.
class TrainerBasedDataCollector(DataCollector):
"""
This class includes some trainer-based util functions, e.g., patching the optimizer or criterion, adding hooks.
"""
def __init__(self, compressor: Pruner, trainer: Callable[[Module, Optimizer, Callable], None], optimizer_helper: OptimizerConstructHelper,
criterion: Callable[[Tensor, Tensor], Tensor], training_epochs: int,
opt_before_tasks: List = [], opt_after_tasks: List = [],
collector_infos: List[HookCollectorInfo] = [], criterion_patch: Optional[Callable[[Callable], Callable]] = None):
def __init__(self, compressor: Pruner, trainer: Callable[[Module, Optimizer, Callable], None],
optimizer_helper: OptimizerConstructHelper, criterion: Callable[[Tensor, Tensor], Tensor], training_epochs: int,
opt_before_tasks: List = [], opt_after_tasks: List = [], collector_infos: List[HookCollectorInfo] = [],
criterion_patch: Optional[Callable[[Callable], Callable]] = None):
"""
Parameters
----------
......@@ -252,6 +256,47 @@ class TrainerBasedDataCollector(DataCollector):
self._remove_hook(hook_id)
class EvaluatorBasedDataCollector(DataCollector):
"""
This data collector is the base class for data collectors that use ``Evaluator`` to train or run inference.
Three main usages are supported in this data collector:
1. Doing something before and after ``optimizer.step()``. ``before_opt_step_tasks`` is a list of task functions
that will execute before ``optimizer.step()``; ``after_opt_step_tasks`` is a list of task functions that will execute after
``optimizer.step()``. The task functions in these lists should take no input arguments; a return value is allowed,
but ``Evaluator`` will not catch it.
2. Patching or modifying the training loss. ``loss_patch`` is a function whose input is the original loss and whose output is the modified loss.
3. Adding hooks on ``torch.nn.Module``, ``Parameter`` or ``Buffer``. Three kinds of hook are supported: ``TensorHook``, ``ForwardHook``
and ``BackwardHook``. To initialize a ``Hook``, a hook function factory is needed; the factory function's input is an empty list,
and the output is a hook function defined by PyTorch (a minimal sketch of such a factory follows the ``reset`` method below).
Please refer `register_hook <https://pytorch.org/docs/stable/generated/torch.Tensor.register_hook.html>`_,
`register_forward_hook <https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_forward_hook>`_,
`register_backward_hook <https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_backward_hook>`_.
"""
def __init__(self, compressor: Pruner, evaluator: Evaluator, before_opt_step_tasks: List[Callable] | None = None,
after_opt_step_tasks: List[Callable] | None = None, loss_patch: Callable[[Tensor], Tensor] | None = None,
hooks: Dict[str, Dict[str, Hook]] | None = None, max_steps: int | None = None, max_epochs: int | None = None):
super().__init__(compressor)
self.evaluator = evaluator
self.max_steps = max_steps
self.max_epochs = max_epochs
self.reset(before_opt_step_tasks, after_opt_step_tasks, loss_patch, hooks)
def reset(self, before_opt_step_tasks: List[Callable] | None = None, after_opt_step_tasks: List[Callable] | None = None,
loss_patch: Callable[[Tensor], Tensor] | None = None, hooks: Dict[str, Dict[str, Hook]] | None = None):
if before_opt_step_tasks or after_opt_step_tasks:
before_opt_step_tasks = before_opt_step_tasks if before_opt_step_tasks else []
after_opt_step_tasks = after_opt_step_tasks if after_opt_step_tasks else []
self.evaluator.patch_optimizer_step(before_opt_step_tasks, after_opt_step_tasks)
if loss_patch:
self.evaluator.patch_loss(loss_patch)
if hooks:
self._hooks = hooks
hook_list = [hook for _ in hooks.values() for hook in _.values()]
self.evaluator.register_hooks(hook_list)
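# Minimal sketch (illustrative, not part of this diff) of the hook function
# factory shape described in the class docstring above: the factory receives an
# empty buffer list and returns a hook with the standard PyTorch
# ``register_forward_hook`` signature; collected tensors are appended to the
# buffer, which the data collector later reads back through ``Hook.buffer``.
#
# def forward_hook_factory(buffer: List[Tensor]):
#     def hook(module: Module, inputs, output):
#         # keep a detached copy so the buffer does not retain the autograd graph
#         buffer.append(output.detach().clone())
#     return hook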
class MetricsCalculator:
"""
An abstract class for calculating a kind of metric from the given data.
......@@ -260,7 +305,8 @@ class MetricsCalculator:
----------
scalers
Scaler is used to scale the metrics' size. It scales the metric to the same size as the shrinked mask in the sparsity allocator.
If you want to use different scalers for different pruning targets in different modules, please use a dict `{module_name: {target_name: scaler}}`.
If you want to use different scalers for different pruning targets in different modules,
please use a dict `{module_name: {target_name: scaler}}`.
If allocator meets an unspecified module name, it will try to use `scalers['_default'][target_name]` to scale its mask.
If allocator meets an unspecified target name, it will try to use `scalers[module_name]['_default']` to scale its mask.
Passing in a scaler instead of a `dict` of scalers will be treated as passed in `{'_default': {'_default': scalers}}`.
......@@ -268,7 +314,8 @@ class MetricsCalculator:
"""
def __init__(self, scalers: Dict[str, Dict[str, Scaling]] | Scaling | None = None):
self.scalers: Dict[str, Dict[str, Scaling]] | None = scalers if isinstance(scalers, (dict, type(None))) else {'_default': {'_default': scalers}} # type: ignore
self.scalers: Dict[str, Dict[str, Scaling]] | None = scalers \
if isinstance(scalers, (dict, type(None))) else {'_default': {'_default': scalers}} # type: ignore
def _get_scaler(self, module_name: str, target_name: str) -> Scaling:
scaler = _get_scaler(self.scalers, module_name, target_name)
......@@ -301,7 +348,8 @@ class SparsityAllocator:
scalers
Scaler is used to scale the masks' size. It shrinks the mask of the same size as the pruning target to the same size as the metric,
or expands the mask of the same size as the metric to the same size as the pruning target.
If you want to use different scalers for different pruning targets in different modules, please use a dict `{module_name: {target_name: scaler}}`.
If you want to use different scalers for different pruning targets in different modules,
please use a dict `{module_name: {target_name: scaler}}`.
If allocator meets an unspecified module name, it will try to use `scalers['_default'][target_name]` to scale its mask.
If allocator meets an unspecified target name, it will try to use `scalers[module_name]['_default']` to scale its mask.
Passing in a scaler instead of a `dict` of scalers will be treated as passed in `{'_default': {'_default': scalers}}`.
......@@ -313,7 +361,8 @@ class SparsityAllocator:
def __init__(self, pruner: Pruner, scalers: Dict[str, Dict[str, Scaling]] | Scaling | None = None, continuous_mask: bool = True):
self.pruner = pruner
self.scalers: Dict[str, Dict[str, Scaling]] | None = scalers if isinstance(scalers, (dict, type(None))) else {'_default': {'_default': scalers}} # type: ignore
self.scalers: Dict[str, Dict[str, Scaling]] | None = scalers \
if isinstance(scalers, (dict, type(None))) else {'_default': {'_default': scalers}} # type: ignore
self.continuous_mask = continuous_mask
def _get_scaler(self, module_name: str, target_name: str) -> Scaling | None:
......@@ -335,25 +384,39 @@ class SparsityAllocator:
mask = (scaler.shrink(mask) != 0).type_as(mask)
return mask
def _continuous_mask(self, new_masks: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
def _mask_metric(self, metrics: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
# Set the already masked part in the metric to the minimum value.
target_name = 'weight'
for module_name, targets_metric in metrics.items():
wrapper = self.pruner.get_modules_wrapper()[module_name]
old_target_mask: Tensor = getattr(wrapper, f'{target_name}_mask')
shrinked_target_mask = self._shrink_mask(module_name, target_name, old_target_mask)
# make sure the masked position has the minimum metric
targets_metric[target_name] = targets_metric[target_name].to(shrinked_target_mask.device)
min_value = targets_metric[target_name].min() - 1
targets_metric[target_name] = torch.where(shrinked_target_mask != 0, targets_metric[target_name], min_value)
return metrics
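# Worked example (illustrative): for a shrinked metric [0.2, 0.5, 0.1] with an
# old shrinked mask [1, 0, 1], min_value = 0.1 - 1 = -0.9 and the metric
# becomes [0.2, -0.9, 0.1], so the already-masked position can never be
# selected again by the sparsity allocator.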
def _continuous_mask(self, new_masks: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
# Set the already masked part to zero in the new_masks.
target_name = 'weight'
for module_name, target_mask in new_masks.items():
wrapper = self.pruner.get_modules_wrapper()[module_name]
old_target_mask = getattr(wrapper, f'{target_name}_mask', None)
old_target_mask: Tensor | None = getattr(wrapper, f'{target_name}_mask', None)
if old_target_mask is not None:
new_masks[module_name][target_name] = torch.min(target_mask[target_name], old_target_mask)
new_masks[module_name][target_name] = torch.min(target_mask[target_name],
old_target_mask.to(target_mask[target_name].device))
return new_masks
def common_target_masks_generation(self, metrics: Dict[str, Tensor]) -> Dict[str, Dict[str, Tensor]]:
def common_target_masks_generation(self, metrics: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
"""
Generate masks for metrics-dependent targets.
Parameters
----------
metrics
The format is {module_name: weight_metric}.
The metric of `weight` usually has the same size with shrinked mask.
The format is {module_name: {target_name: target_metric}}.
The metric usually has the same size as the shrinked mask.
Return
------
......@@ -384,7 +447,7 @@ class SparsityAllocator:
reduce_dims = [reduce_dim for reduce_dim in range(1, len(weight_mask.shape))]
# count unmasked number of values on dim 0 (output channel) of weight
unmasked_num_on_dim0 = weight_mask.sum(reduce_dims) if reduce_dims else weight_mask
module_masks['bias'] = (unmasked_num_on_dim0 != 0).type_as(old_bias_mask)
module_masks['bias'] = (unmasked_num_on_dim0 != 0).type_as(weight_mask)
return masks
def generate_sparsity(self, metrics: Dict) -> Dict[str, Dict[str, Tensor]]:
......@@ -401,6 +464,8 @@ class SparsityAllocator:
Dict[str, Dict[str, Tensor]]
The masks format is {module_name: {target_name: mask}}.
"""
if self.continuous_mask:
metrics = self._mask_metric(metrics)
masks = self.common_target_masks_generation(metrics)
masks = self.special_target_masks_generation(masks)
if self.continuous_mask:
......@@ -425,11 +490,22 @@ class TaskGenerator:
The log directory used to save the task generator log.
keep_intermediate_result
If True, keep the intermediate results, including the intermediate model and masks during each iteration.
best_result_mode
The way to decide which one is the best result. Three modes are supported.
If the task results don't contain scores (task_result.score is None), it will fall back to ``latest``.
1. latest: The newest received result is the best result.
2. maximize: The one with the largest task result score is the best result.
3. minimize: The one with the smallest task result score is the best result.
"""
def __init__(self, origin_model: Optional[Module], origin_masks: Optional[Dict[str, Dict[str, Tensor]]] = {},
origin_config_list: Optional[List[Dict]] = [], log_dir: Union[str, Path] = '.', keep_intermediate_result: bool = False):
origin_config_list: Optional[List[Dict]] = [], log_dir: Union[str, Path] = '.', keep_intermediate_result: bool = False,
best_result_mode: Literal['latest', 'maximize', 'minimize'] = 'maximize'):
self._log_dir = log_dir
self._keep_intermediate_result = keep_intermediate_result
assert best_result_mode in ['latest', 'maximize', 'minimize'], f'Unsupported best_result_mode value: {best_result_mode}'
self._best_result_mode = best_result_mode
if origin_model is not None and origin_config_list is not None and origin_masks is not None:
self.reset(origin_model, origin_config_list, origin_masks)
......@@ -472,13 +548,24 @@ class TaskGenerator:
json_tricks.dump(config_list, f, indent=4)
def update_best_result(self, task_result: TaskResult):
score = task_result.score
task_id = task_result.task_id
task = self._tasks[task_id]
task.score = score
if self._best_score is None or (score is not None and score > self._best_score):
self._best_score = score
self._best_task_id = task_id
save_as_best_result = False
task = self._tasks[task_result.task_id]
task.score = task_result.score
if self._best_result_mode == 'latest':
self._best_task_id, save_as_best_result = task_result.task_id, True
if self._best_result_mode == 'maximize':
if self._best_score is None or (task.score is not None and task.score > self._best_score):
self._best_score = task.score
self._best_task_id, save_as_best_result = task_result.task_id, True
if self._best_result_mode == 'minimize':
if self._best_score is None or (task.score is not None and task.score < self._best_score):
self._best_score = task.score
self._best_task_id, save_as_best_result = task_result.task_id, True
if save_as_best_result:
with Path(task.config_list_path).open('r') as fr:
best_config_list = json_tricks.load(fr)
self._save_data('best_result', task_result.compact_model, task_result.compact_model_masks, best_config_list)
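# Example (illustrative): with best_result_mode='minimize' and task scores
# [0.32, 0.28, 0.35] arriving in that order, the second task (score 0.28) is
# kept as best_result; with 'maximize' the third task (0.35) would win; with
# 'latest' the most recent task is always kept, regardless of its score.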
......
......@@ -6,13 +6,16 @@ from typing import Dict, List
from torch import Tensor
from .base import DataCollector, TrainerBasedDataCollector
from .base import DataCollector, EvaluatorBasedDataCollector
from .base import TrainerBasedDataCollector
_logger = logging.getLogger(__name__)
__all__ = ['WeightDataCollector', 'WeightTrainerBasedDataCollector', 'SingleHookTrainerBasedDataCollector']
__all__ = ['TargetDataCollector', 'EvaluatorBasedTargetDataCollector', 'EvaluatorBasedHookDataCollector',
'WeightDataCollector', 'WeightTrainerBasedDataCollector', 'SingleHookTrainerBasedDataCollector'] # TODO: remove in nni v3.0.
# TODO: remove in nni v3.0.
class WeightDataCollector(DataCollector):
"""
Collect all wrapper weights.
......@@ -21,40 +24,102 @@ class WeightDataCollector(DataCollector):
def reset(self):
pass
def collect(self) -> Dict[str, Tensor]:
def collect(self) -> Dict[str, Dict[str, Tensor]]:
data = {}
for _, wrapper in self.compressor.get_modules_wrapper().items():
data[wrapper.name] = wrapper.weight.data
target_name = 'weight'
for module_name, wrapper in self.compressor.get_modules_wrapper().items():
target: Tensor = getattr(wrapper, target_name)
data[module_name] = {target_name: target.data.clone()}
return data
# TODO: remove in nni v3.0.
class WeightTrainerBasedDataCollector(TrainerBasedDataCollector):
"""
Collect all wrapper weights after training or inference.
"""
def collect(self) -> Dict[str, Tensor]:
def collect(self) -> Dict[str, Dict[str, Tensor]]:
assert self.compressor.bound_model is not None
for _ in range(self.training_epochs):
self.trainer(self.compressor.bound_model, self.optimizer, self.criterion)
data = {}
for _, wrapper in self.compressor.get_modules_wrapper().items():
data[wrapper.name] = wrapper.weight.data
target_name = 'weight'
for module_name, wrapper in self.compressor.get_modules_wrapper().items():
target: Tensor = getattr(wrapper, target_name)
data[module_name] = {target_name: target.data.clone()}
return data
# TODO: remove in nni v3.0.
class SingleHookTrainerBasedDataCollector(TrainerBasedDataCollector):
"""
Add hooks and collect data during training or inference.
Single means each wrapper only has one hook to collect data.
"""
def collect(self) -> Dict[str, List[Tensor]]:
def collect(self) -> Dict[str, Dict[str, List[Tensor]]]:
assert self.compressor.bound_model is not None
for _ in range(self.training_epochs):
self.trainer(self.compressor.bound_model, self.optimizer, self.criterion)
data = {}
[data.update(buffer_dict) for _, buffer_dict in self._hook_buffer.items()]
target_name = 'weight'
for _, buffer_dict in self._hook_buffer.items():
for module_name, target_data in buffer_dict.items():
data[module_name] = {target_name: target_data}
return data
class TargetDataCollector(DataCollector):
"""
Collect all wrapper targets.
"""
def reset(self):
# No need to reset anything in this data collector.
pass
def collect(self) -> Dict[str, Dict[str, Tensor]]:
data = {}
target_name = 'weight'
for module_name, wrapper in self.compressor.get_modules_wrapper().items():
target: Tensor = getattr(wrapper, target_name)
data[module_name] = {target_name: target.data.clone()}
return data
class EvaluatorBasedTargetDataCollector(EvaluatorBasedDataCollector):
"""
Collect all wrapper pruning targets after training or inference.
"""
def collect(self) -> Dict[str, Dict[str, Tensor]]:
assert self.compressor.bound_model is not None
self.evaluator.train(max_steps=self.max_steps, max_epochs=self.max_epochs)
data = {}
target_name = 'weight'
for module_name, wrapper in self.compressor.get_modules_wrapper().items():
target: Tensor = getattr(wrapper, target_name)
data[module_name] = {target_name: target.data.clone()}
return data
class EvaluatorBasedHookDataCollector(EvaluatorBasedDataCollector):
"""
Add hooks and collect data during training or inference.
NOTE: Only one hook per target is supported right now.
"""
def collect(self) -> Dict[str, Dict[str, List]]:
assert self.compressor.bound_model is not None
self.evaluator.train(max_steps=self.max_steps, max_epochs=self.max_epochs)
data = {}
for module_name, hooks in self._hooks.items():
data[module_name] = {}
for target_name, hook in hooks.items():
data[module_name][target_name] = hook.buffer
return data
......@@ -11,7 +11,7 @@ from torch import Tensor
from .base import MetricsCalculator
from ...utils import Scaling
__all__ = ['NormMetricsCalculator', 'MultiDataNormMetricsCalculator', 'DistMetricsCalculator',
__all__ = ['NormMetricsCalculator', 'HookDataNormMetricsCalculator', 'DistMetricsCalculator',
'APoZRankMetricsCalculator', 'MeanRankMetricsCalculator', 'StraightMetricsCalculator']
......@@ -19,11 +19,12 @@ class StraightMetricsCalculator(MetricsCalculator):
"""
This metrics calculator directly returns a copy of data as metrics.
"""
def calculate_metrics(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
def calculate_metrics(self, data: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
metrics = {}
for name, tensor in data.items():
# use inplace detach `detach_` here to avoid creating a new tensor
metrics[name] = tensor.clone().detach_()
for module_name, targets_data in data.items():
metrics[module_name] = {}
for target_name, target_data in targets_data.items():
metrics[module_name][target_name] = target_data.clone().detach()
return metrics
......@@ -44,27 +45,32 @@ class NormMetricsCalculator(MetricsCalculator):
super().__init__(scalers=scalers)
self.p = p if p is not None else 'fro'
def calculate_metrics(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
def calculate_metrics(self, data: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
def reduce_func(t: Tensor) -> Tensor:
return t.norm(p=self.p, dim=-1) # type: ignore
metrics = {}
target_name = 'weight'
for module_name, target_data in data.items():
scaler = self._get_scaler(module_name, target_name)
metrics[module_name] = scaler.shrink(target_data, reduce_func)
for module_name, targets_data in data.items():
metrics[module_name] = {}
for target_name, target_data in targets_data.items():
scaler = self._get_scaler(module_name, target_name)
metrics[module_name][target_name] = scaler.shrink(target_data, reduce_func)
return metrics
class MultiDataNormMetricsCalculator(NormMetricsCalculator):
class HookDataNormMetricsCalculator(NormMetricsCalculator):
"""
The data value format is a two-element list [batch_number, cumulative_data].
The hook data value format is a two-element list [batch_number, cumulative_data].
Directly use the cumulative_data as new_data to calculate norm metric.
TaylorFO pruner uses this to calculate metric.
"""
def calculate_metrics(self, data: Dict[str, List[Tensor]]) -> Dict[str, Tensor]:
new_data = {name: buffer[1] for name, buffer in data.items()}
def calculate_metrics(self, data: Dict[str, Dict[str, List[Tensor]]]) -> Dict[str, Dict[str, Tensor]]:
new_data = {}
for module_name, targets_data in data.items():
new_data[module_name] = {}
for target_name, (_, target_data) in targets_data.items():
new_data[module_name][target_name] = target_data
return super().calculate_metrics(new_data)
......@@ -85,7 +91,7 @@ class DistMetricsCalculator(MetricsCalculator):
super().__init__(scalers=scalers)
self.p = p if p is not None else 'fro'
def calculate_metrics(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
def calculate_metrics(self, data: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
def reduce_func(t: Tensor) -> Tensor:
reshape_data = t.reshape(-1, t.shape[-1])
metric = torch.zeros(reshape_data.shape[0], device=reshape_data.device)
......@@ -94,10 +100,11 @@ class DistMetricsCalculator(MetricsCalculator):
return metric.reshape(t.shape[:-1])
metrics = {}
target_name = 'weight'
for module_name, target_data in data.items():
scaler = self._get_scaler(module_name, target_name)
metrics[module_name] = scaler.shrink(target_data, reduce_func)
for module_name, targets_data in data.items():
metrics[module_name] = {}
for target_name, target_data in targets_data.items():
scaler = self._get_scaler(module_name, target_name)
metrics[module_name][target_name] = scaler.shrink(target_data, reduce_func)
return metrics
......@@ -108,16 +115,18 @@ class APoZRankMetricsCalculator(MetricsCalculator):
Note that the metric we return is (1 - apoz), because we assume a higher metric value has higher importance.
APoZRank pruner uses this to calculate metric.
"""
def calculate_metrics(self, data: Dict[str, List]) -> Dict[str, Tensor]:
def calculate_metrics(self, data: Dict[str, Dict[str, List[Tensor]]]) -> Dict[str, Dict[str, Tensor]]:
def reduce_func(t: Tensor) -> Tensor:
return 1 - t.mean(dim=-1)
metrics = {}
target_name = 'weight'
for module_name, target_data in data.items():
target_data = target_data[1] / target_data[0]
scaler = self._get_scaler(module_name, target_name)
metrics[module_name] = scaler.shrink(target_data, reduce_func)
for module_name, targets_data in data.items():
metrics[module_name] = {}
for target_name, target_data in targets_data.items():
target_data = target_data[1] / target_data[0]
scaler = self._get_scaler(module_name, target_name)
metrics[module_name][target_name] = scaler.shrink(target_data, reduce_func)
return metrics
......@@ -127,14 +136,15 @@ class MeanRankMetricsCalculator(MetricsCalculator):
This metric simply calculates the average on `self.dim`, then divides by the batch_number.
MeanRank pruner uses this to calculate metric.
"""
def calculate_metrics(self, data: Dict[str, List]) -> Dict[str, Tensor]:
def calculate_metrics(self, data: Dict[str, Dict[str, List[Tensor]]]) -> Dict[str, Dict[str, Tensor]]:
def reduce_func(t: Tensor) -> Tensor:
return t.mean(dim=-1)
metrics = {}
target_name = 'weight'
for module_name, target_data in data.items():
target_data = target_data[1] / target_data[0]
scaler = self._get_scaler(module_name, target_name)
metrics[module_name] = scaler.shrink(target_data, reduce_func)
for module_name, targets_data in data.items():
metrics[module_name] = {}
for target_name, target_data in targets_data.items():
target_data = target_data[1] / target_data[0]
scaler = self._get_scaler(module_name, target_name)
metrics[module_name][target_name] = scaler.shrink(target_data, reduce_func)
return metrics
......@@ -17,7 +17,8 @@ _logger = logging.getLogger(__name__)
class AMCEnv:
def __init__(self, model: Module, config_list: List[Dict], dummy_input: Tensor, total_sparsity: float, max_sparsity_per_layer: Dict[str, float], target: str = 'flops'):
def __init__(self, model: Module, config_list: List[Dict], dummy_input: Tensor, total_sparsity: float,
max_sparsity_per_layer: Dict[str, float], target: str = 'flops'):
pruning_op_names = []
[pruning_op_names.extend(config['op_names']) for config in config_list_canonical(model, config_list)]
self.pruning_ops = OrderedDict()
......@@ -26,7 +27,10 @@ class AMCEnv:
if name in pruning_op_names:
op_type = type(layer).__name__
stride = np.power(np.prod(layer.stride), 1 / len(layer.stride)) if hasattr(layer, 'stride') else 0 # type: ignore
kernel_size = np.power(np.prod(layer.kernel_size), 1 / len(layer.kernel_size)) if hasattr(layer, 'kernel_size') else 1 # type: ignore
if hasattr(layer, 'kernel_size'):
kernel_size = np.power(np.prod(layer.kernel_size), 1 / len(layer.kernel_size)) # type: ignore
else:
kernel_size = 1
self.pruning_ops[name] = (i, op_type, stride, kernel_size)
self.pruning_types.append(op_type)
self.pruning_types = list(set(self.pruning_types))
......@@ -60,15 +64,18 @@ class AMCEnv:
total_current_target = sum([current_statistics[name][self.target] for name in self.pruning_op_names])
previous_pruning_target = self.under_pruning_target - total_current_target
max_rest_pruning_target = sum([current_statistics[name][self.target] * self.max_sparsity_per_layer[name] for name in self.pruning_op_names[index + 1:]])
max_rest_pruning_target = sum([current_statistics[name][self.target] * self.max_sparsity_per_layer[name] \
for name in self.pruning_op_names[index + 1:]])
min_current_pruning_target = self.excepted_pruning_target - previous_pruning_target - max_rest_pruning_target
max_current_pruning_target_1 = self.origin_statistics[op_name][self.target] * self.max_sparsity_per_layer[op_name] - (self.origin_statistics[op_name][self.target] - current_statistics[op_name][self.target])
max_current_pruning_target_1 = self.origin_statistics[op_name][self.target] * self.max_sparsity_per_layer[op_name] - \
(self.origin_statistics[op_name][self.target] - current_statistics[op_name][self.target])
max_current_pruning_target_2 = self.excepted_pruning_target - previous_pruning_target
max_current_pruning_target = min(max_current_pruning_target_1, max_current_pruning_target_2)
min_action = min_current_pruning_target / current_statistics[op_name][self.target]
max_action = max_current_pruning_target / current_statistics[op_name][self.target]
if min_action > self.max_sparsity_per_layer[op_name]:
_logger.warning('[%s] min action > max sparsity per layer: %f > %f', op_name, min_action, self.max_sparsity_per_layer[op_name])
warn_msg = f'[{op_name}] min action > max sparsity per layer: {min_action} > {self.max_sparsity_per_layer[op_name]}'
_logger.warning(warn_msg)
action = max(0., min(max_action, max(min_action, action)))
self.current_op_name = op_name
......
......@@ -4,7 +4,7 @@
from __future__ import annotations
import itertools
from typing import Any, Dict, List, Union
from typing import Any, Dict
import numpy as np
import torch
......@@ -23,22 +23,22 @@ class NormalSparsityAllocator(SparsityAllocator):
This allocator directly masks the locations of each pruning target with lower metric values.
"""
def common_target_masks_generation(self, metrics: Dict[str, Tensor]) -> Dict[str, Dict[str, Tensor]]:
def common_target_masks_generation(self, metrics: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
masks = {}
# TODO: Support more target type in wrapper & config list refactor
target_name = 'weight'
for module_name, target_metric in metrics.items():
for module_name, targets_metric in metrics.items():
masks[module_name] = {}
wrapper = self.pruner.get_modules_wrapper()[module_name]
sparsity_rate = wrapper.config['total_sparsity']
prune_num = int(sparsity_rate * target_metric.numel())
if prune_num != 0:
threshold = torch.topk(target_metric.reshape(-1), prune_num, largest=False)[0].max()
shrinked_mask = torch.gt(target_metric, threshold).type_as(target_metric)
else:
# target_metric should have the same size as shrinked_mask
shrinked_mask = torch.ones_like(target_metric)
masks[module_name][target_name] = self._expand_mask(module_name, target_name, shrinked_mask)
for target_name, target_metric in targets_metric.items():
sparsity_rate = wrapper.config['total_sparsity']
prune_num = int(sparsity_rate * target_metric.numel())
if prune_num != 0:
threshold = torch.topk(target_metric.reshape(-1), prune_num, largest=False)[0].max()
shrinked_mask = torch.gt(target_metric, threshold).type_as(target_metric)
else:
# target_metric should have the same size as shrinked_mask
shrinked_mask = torch.ones_like(target_metric)
masks[module_name][target_name] = self._expand_mask(module_name, target_name, shrinked_mask)
return masks
......@@ -46,7 +46,7 @@ class BankSparsityAllocator(SparsityAllocator):
"""
In the bank pruner, all values in the weight are divided into sub blocks whose shapes are
aligned with balance_gran. Each sub block has the same sparsity, which equals the overall sparsity.
This allocator pruned the weight in the granularity of block.
This allocator prunes the weight at the granularity of blocks.
"""
def __init__(self, pruner: Pruner, balance_gran: list):
......@@ -56,101 +56,108 @@ class BankSparsityAllocator(SparsityAllocator):
assert isinstance(gran, int) and gran > 0, 'All values in list balance_gran \
should be type int and bigger than zero'
def common_target_masks_generation(self, metrics: Dict[str, Tensor]) -> Dict[str, Dict[str, Tensor]]:
def common_target_masks_generation(self, metrics: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
masks = {}
# TODO: Support more target type in wrapper & config list refactor
target_name = 'weight'
for module_name, target_metric in metrics.items():
for module_name, targets_metric in metrics.items():
masks[module_name] = {}
wrapper = self.pruner.get_modules_wrapper()[module_name]
sparsity_rate = wrapper.config['total_sparsity']
n_dim = len(target_metric.shape)
assert n_dim >= len(self.balance_gran), 'Dimension of balance_gran should be smaller than metric'
# make up for balance_gran
balance_gran = [1] * (n_dim - len(self.balance_gran)) + self.balance_gran
for i, j in zip(target_metric.shape, balance_gran):
assert i % j == 0, 'Length of {} {} is not aligned with balance granularity'.format(module_name, target_name)
# FIXME: The following code need refactor, do it after scaling refactor is done.
shrinked_mask = torch.ones(target_metric.shape).type_as(target_metric)
loop_iters = [range(int(i / j)) for i, j in zip(target_metric.shape, balance_gran)]
for iter_params in itertools.product(*loop_iters):
index_str_list = [f"{iter_param * gran}:{(iter_param+1) * gran}"\
for iter_param, gran in zip(iter_params, balance_gran)]
index_str = ",".join(index_str_list)
sub_metric_str = "target_metric[{}]".format(index_str)
sub_mask_str = "shrinked_mask[{}] = mask_bank".format(index_str)
metric_bank: Tensor = eval(sub_metric_str)
prune_num = int(sparsity_rate * metric_bank.numel())
# mask_bank will be used in exec(sub_mask_str)
if prune_num != 0:
threshold = torch.topk(metric_bank.reshape(-1), prune_num, largest=False)[0].max()
mask_bank = torch.gt(metric_bank, threshold).type_as(metric_bank)
else:
mask_bank = torch.ones_like(metric_bank)
exec(sub_mask_str)
masks[module_name][target_name] = self._expand_mask(module_name, target_name, shrinked_mask)
for target_name, target_metric in targets_metric.items():
sparsity_rate = wrapper.config['total_sparsity']
n_dim = len(target_metric.shape)
                assert n_dim >= len(self.balance_gran), 'The number of dimensions of balance_gran should not exceed that of the metric'
# make up for balance_gran
balance_gran = [1] * (n_dim - len(self.balance_gran)) + self.balance_gran
for i, j in zip(target_metric.shape, balance_gran):
assert i % j == 0, 'Length of {} {} is not aligned with balance granularity'.format(module_name, target_name)
                # FIXME: The following code needs refactoring; do it after the scaling refactor is done.
shrinked_mask = torch.ones(target_metric.shape).type_as(target_metric)
loop_iters = [range(int(i / j)) for i, j in zip(target_metric.shape, balance_gran)]
for iter_params in itertools.product(*loop_iters):
index_str_list = [f"{iter_param * gran}:{(iter_param+1) * gran}"\
for iter_param, gran in zip(iter_params, balance_gran)]
index_str = ",".join(index_str_list)
sub_metric_str = "target_metric[{}]".format(index_str)
sub_mask_str = "shrinked_mask[{}] = mask_bank".format(index_str)
metric_bank: Tensor = eval(sub_metric_str)
prune_num = int(sparsity_rate * metric_bank.numel())
# mask_bank will be used in exec(sub_mask_str)
if prune_num != 0:
threshold = torch.topk(metric_bank.reshape(-1), prune_num, largest=False)[0].max()
mask_bank = torch.gt(metric_bank, threshold).type_as(metric_bank) # type: ignore
else:
mask_bank = torch.ones_like(metric_bank) # type: ignore
                        mask_bank = mask_bank  # `type: ignore` does not suppress the unused-variable error; this line is a workaround
exec(sub_mask_str)
masks[module_name][target_name] = self._expand_mask(module_name, target_name, shrinked_mask)
return masks
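The bank allocator applies the same thresholding inside every balance_gran-shaped sub-block, so each bank keeps the same fraction of weights. A minimal sketch of the same idea for a 2-D metric with balance_gran = [1, 4], written with plain slice objects instead of the eval/exec strings above (tensor shapes and values are illustrative only):

import itertools
import torch

metric = torch.rand(2, 8)          # toy 2-D metric
balance_gran = [1, 4]              # each bank is 1 row x 4 columns
sparsity_rate = 0.5
shrinked_mask = torch.ones_like(metric)
loop_iters = [range(s // g) for s, g in zip(metric.shape, balance_gran)]
for iter_params in itertools.product(*loop_iters):
    # build the index of one bank, e.g. metric[0:1, 4:8]
    idx = tuple(slice(i * g, (i + 1) * g) for i, g in zip(iter_params, balance_gran))
    metric_bank = metric[idx]
    prune_num = int(sparsity_rate * metric_bank.numel())
    if prune_num != 0:
        threshold = torch.topk(metric_bank.reshape(-1), prune_num, largest=False)[0].max()
        shrinked_mask[idx] = torch.gt(metric_bank, threshold).type_as(metric_bank)
# every 1x4 bank now has (up to ties) exactly 50% of its entries masked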
class GlobalSparsityAllocator(SparsityAllocator):
"""
    This allocator sorts all metrics as a whole and masks the locations of pruning targets with lower metric values.
    By default, this allocator prevents each module from being over-pruned by capping its sparsity at 0.99.
"""
def common_target_masks_generation(self, metrics: Dict[str, Tensor]) -> Dict[str, Dict[str, Tensor]]:
def common_target_masks_generation(self, metrics: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
masks = {}
if not metrics:
return masks
# TODO: support more target type in wrapper & config list refactor
target_name = 'weight'
# validate all wrapper setting the same sparsity
        # validate that all wrappers are set to the same sparsity
# TODO: move validation logic to pruner
global_sparsity_rate = self.pruner.get_modules_wrapper()[list(metrics.keys())[0]].config['total_sparsity']
for module_name, target_metric in metrics.items():
for module_name in metrics.keys():
wrapper = self.pruner.get_modules_wrapper()[module_name]
assert global_sparsity_rate == wrapper.config['total_sparsity']
# find the largest metric value among all metrics
max_metric_value = list(metrics.values())[0].max()
for module_name, target_metric in metrics.items():
max_metric_value = max_metric_value if max_metric_value >= target_metric.max() else target_metric.max()
max_metric_value = list(list(metrics.values())[0].values())[0].max()
for targets_metric in metrics.values():
for target_metric in targets_metric.values():
max_metric_value = max_metric_value if max_metric_value >= target_metric.max() else target_metric.max()
        # prevent each module from being over-pruned; the cap ratio is 'max_sparsity_per_layer'
for module_name, target_metric in metrics.items():
for module_name, targets_metric in metrics.items():
wrapper = self.pruner.get_modules_wrapper()[module_name]
max_sparsity = wrapper.config.get('max_sparsity_per_layer', {}).get(module_name, 0.99)
assert 0 <= max_sparsity <= 1
old_target_mask: Tensor = getattr(wrapper, f'{target_name}_mask')
expand_times = old_target_mask.numel() // target_metric.numel()
max_pruning_numel = int(max_sparsity * target_metric.numel()) * expand_times
threshold = torch.topk(target_metric.reshape(-1), max_pruning_numel, largest=False)[0].max()
metrics[module_name] = torch.where(target_metric <= threshold, target_metric, max_metric_value)
for target_name, target_metric in targets_metric.items():
max_sparsity = wrapper.config.get('max_sparsity_per_layer', {}).get(module_name, 0.99)
assert 0 <= max_sparsity <= 1
old_target_mask: Tensor = getattr(wrapper, f'{target_name}_mask')
expand_times = old_target_mask.numel() // target_metric.numel()
max_pruning_numel = int(max_sparsity * target_metric.numel()) * expand_times
threshold = torch.topk(target_metric.reshape(-1), max_pruning_numel, largest=False)[0].max()
metrics[module_name][target_name] = torch.where(target_metric <= threshold, target_metric, max_metric_value)
        # build the global metric & calculate the global threshold
metric_list = []
for module_name, target_metric in metrics.items():
for module_name, targets_metric in metrics.items():
wrapper = self.pruner.get_modules_wrapper()[module_name]
old_target_mask: Tensor = getattr(wrapper, f'{target_name}_mask')
expand_times = old_target_mask.numel() // target_metric.numel()
metric_list.append(target_metric.reshape(-1).unsqueeze(0).expand(expand_times, -1).reshape(-1))
for target_name, target_metric in targets_metric.items():
old_target_mask: Tensor = getattr(wrapper, f'{target_name}_mask')
expand_times = old_target_mask.numel() // target_metric.numel()
metric_list.append(target_metric.reshape(-1).repeat_interleave(expand_times))
global_metric = torch.cat(metric_list)
max_pruning_num = int((global_metric != max_metric_value).sum().item())
total_pruning_num = min(int(global_sparsity_rate * global_metric.numel()), max_pruning_num)
global_threshold = torch.topk(global_metric.reshape(-1), total_pruning_num, largest=False)[0].max()
# generate masks for each target
for module_name, target_metric in metrics.items():
for module_name, targets_metric in metrics.items():
masks[module_name] = {}
wrapper = self.pruner.get_modules_wrapper()[module_name]
shrinked_mask = torch.gt(target_metric, global_threshold).type_as(target_metric)
masks[module_name][target_name] = self._expand_mask(module_name, target_name, shrinked_mask)
for target_name, target_metric in targets_metric.items():
wrapper = self.pruner.get_modules_wrapper()[module_name]
shrinked_mask = torch.gt(target_metric, global_threshold).type_as(target_metric)
masks[module_name][target_name] = self._expand_mask(module_name, target_name, shrinked_mask)
return masks
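GlobalSparsityAllocator ranks the metrics of all layers together and cuts them with one shared threshold, so layers with globally small metric values end up sparser. A minimal two-layer sketch of that final thresholding step, ignoring the max_sparsity_per_layer cap and the expand_times weighting for brevity (layer names and values are illustrative):

import torch

metrics = {
    'layer1': torch.tensor([0.1, 0.5, 0.9]),
    'layer2': torch.tensor([0.2, 0.3, 0.8]),
}
global_sparsity_rate = 0.5
global_metric = torch.cat([m.reshape(-1) for m in metrics.values()])
total_pruning_num = int(global_sparsity_rate * global_metric.numel())   # 3 of 6 entries
global_threshold = torch.topk(global_metric, total_pruning_num, largest=False)[0].max()
masks = {name: torch.gt(m, global_threshold).type_as(m) for name, m in metrics.items()}
# masks['layer1'] -> tensor([0., 1., 1.]), masks['layer2'] -> tensor([0., 0., 1.])
# layer2 is pruned more than layer1 because its metric values are globally smaller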
class DependencyAwareAllocator(NormalSparsityAllocator):
# TODO: This allocator traces the model, which means the model runs inference during initialization;
# sometimes we may not be aware of this inference and it may lead to errors.
class DependencyAwareAllocator(SparsityAllocator):
"""
    A dependency-aware allocator for Conv2d & Linear modules.
    It generates a public mask for the modules that have dependencies,
......@@ -170,52 +177,79 @@ class DependencyAwareAllocator(NormalSparsityAllocator):
# group dependency format: {module_name: group_num}
self.pruner._unwrap_model()
graph = TorchModuleGraph(model=self.pruner.bound_model, dummy_input=dummy_input)
channel_dependency = ChannelDependency(model=self.pruner.bound_model, dummy_input=dummy_input, traced_model=graph.trace).dependency_sets
group_dependency = GroupDependency(model=self.pruner.bound_model, dummy_input=dummy_input, traced_model=graph.trace).dependency_sets
channel_dependency = ChannelDependency(model=self.pruner.bound_model, dummy_input=dummy_input,
traced_model=graph.trace).dependency_sets
group_dependency = GroupDependency(model=self.pruner.bound_model, dummy_input=dummy_input,
traced_model=graph.trace).dependency_sets
self.pruner._wrap_model()
return channel_dependency, group_dependency
def _metric_fuse(self, metrics: Union[Dict[str, Tensor], List[Tensor]]) -> Tensor:
def _metric_fuse(self, metrics: Dict[str, Dict[str, Tensor]]) -> Dict[str, Tensor]:
        # Sum all metric values at the same position.
metrics = list(metrics.values()) if isinstance(metrics, dict) else metrics
assert all(metrics[0].size() == metric.size() for metric in metrics), 'Metrics size do not match.'
fused_metric = torch.zeros_like(metrics[0])
for metric in metrics:
fused_metric += metric
return fused_metric
def common_target_masks_generation(self, metrics: Dict[str, Tensor]) -> Dict[str, Dict[str, Tensor]]:
fused_metrics = {}
for targets_metric in metrics.values():
for target_name, target_metric in targets_metric.items():
if target_name in fused_metrics:
fused_metrics[target_name] += target_metric
else:
fused_metrics[target_name] = target_metric
return fused_metrics
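_metric_fuse sums the per-target metrics of every module in one dependency set, position by position, so dependent layers are later ranked jointly. A tiny sketch of what the fusion produces for two dependent layers (module names and values are illustrative):

import torch

metrics = {
    'conv1': {'weight': torch.tensor([0.1, 0.9, 0.3])},
    'conv2': {'weight': torch.tensor([0.4, 0.2, 0.3])},
}
fused_metrics = {}
for targets_metric in metrics.values():
    for target_name, target_metric in targets_metric.items():
        # sum metrics that share the same target name across dependent modules
        fused_metrics[target_name] = fused_metrics.get(target_name, 0) + target_metric
# fused_metrics['weight'] -> tensor([0.5, 1.1, 0.6])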
def common_target_masks_generation(self, metrics: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
        # placeholder; more discussion is needed about dependency sparsity, Plan A or Plan B.
masks = {}
# generate public part for modules that have dependencies
for module_names in self.channel_dependency:
sub_metrics = {module_name: metrics[module_name] for module_name in module_names if module_name in metrics}
if not sub_metrics:
continue
fused_metric = self._metric_fuse(sub_metrics)
sparsity_rates = {module_name: self.pruner.get_modules_wrapper()[module_name].config['total_sparsity'] for module_name in sub_metrics.keys()}
min_sparsity_rate = min(sparsity_rates.values())
group_nums = [self.group_dependency.get(module_name, 1) for module_name in sub_metrics.keys()]
max_group_nums = int(np.lcm.reduce(group_nums))
pruned_numel_per_group = int(fused_metric.numel() // max_group_nums * min_sparsity_rate)
group_step = fused_metric.shape[0] // max_group_nums
# get the public part of the mask of the module with dependencies
sub_masks = []
for gid in range(max_group_nums):
_start = gid * group_step
_end = (gid + 1) * group_step
if pruned_numel_per_group > 0:
threshold = torch.topk(fused_metric[_start: _end].reshape(-1), pruned_numel_per_group, largest=False)[0].max()
sub_mask = torch.gt(fused_metric[_start:_end], threshold).type_as(fused_metric)
fused_metrics = self._metric_fuse(sub_metrics)
for target_name, fused_metric in fused_metrics.items():
sparsity_rates = {module_name: self.pruner.get_modules_wrapper()[module_name].config['total_sparsity'] \
for module_name in sub_metrics.keys()}
min_sparsity_rate = min(sparsity_rates.values())
group_nums = [self.group_dependency.get(module_name, 1) for module_name in sub_metrics.keys()]
max_group_nums = int(np.lcm.reduce(group_nums))
pruned_numel_per_group = int(fused_metric.numel() // max_group_nums * min_sparsity_rate)
group_step = fused_metric.shape[0] // max_group_nums
# get the public part of the mask of the module with dependencies
dependency_mask = torch.ones_like(fused_metric)
for gid in range(max_group_nums):
_start = gid * group_step
_end = (gid + 1) * group_step
if pruned_numel_per_group > 0:
threshold = torch.topk(fused_metric[_start: _end].reshape(-1), pruned_numel_per_group, largest=False)[0].max()
dependency_mask[_start: _end] = torch.gt(fused_metric[_start:_end], threshold).type_as(fused_metric)
# change the metric value corresponding to the public mask part to the minimum value
for module_name, targets_metric in sub_metrics.items():
if target_name in targets_metric:
                        # The following is Plan A: generate the dependency mask first, and then fill in the sparsity;
                        # the final mask is group-unbalanced. Subtracting 1 ensures the dependency metric is the minimum and will be masked first.
# min_value = targets_metric[target_name].min() - 1
# metrics[module_name][target_name] = torch.where(dependency_mask!=0, targets_metric[target_name], min_value)
                        # The following is Plan B: just generate the dependency mask; the final mask is group-balanced.
masks.setdefault(module_name, {})
masks[module_name][target_name] = self._expand_mask(module_name, target_name, dependency_mask)
# generate masks for layers without dependencies
for module_name, targets_metric in metrics.items():
masks.setdefault(module_name, {})
wrapper = self.pruner.get_modules_wrapper()[module_name]
for target_name, target_metric in targets_metric.items():
if target_name in masks[module_name]:
continue
sparsity_rate = wrapper.config['total_sparsity']
prune_num = int(sparsity_rate * target_metric.numel())
if prune_num != 0:
threshold = torch.topk(target_metric.reshape(-1), prune_num, largest=False)[0].max()
shrinked_mask = torch.gt(target_metric, threshold).type_as(target_metric)
else:
sub_mask = torch.ones_like(fused_metric[_start:_end])
sub_masks.append(sub_mask)
dependency_mask = torch.cat(sub_masks, dim=0)
# change the metric value corresponding to the public mask part to the minimum value
for module_name, target_metric in sub_metrics.items():
min_value = target_metric.min()
metrics[module_name] = torch.where(dependency_mask!=0, target_metric, min_value)
return super().common_target_masks_generation(metrics)
# target_metric should have the same size as shrinked_mask
shrinked_mask = torch.ones_like(target_metric)
masks[module_name][target_name] = self._expand_mask(module_name, target_name, shrinked_mask)
return masks
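Inside the dependency loop above, the fused metric is cut group by group so that each of the max_group_nums channel groups loses the same number of channels; this keeps grouped convolutions balanced. A minimal sketch of that per-group thresholding with two groups and 50% sparsity (all values are illustrative):

import torch

fused_metric = torch.tensor([0.9, 0.1, 0.5, 0.3,   # group 0 (channels 0-3)
                             0.8, 0.2, 0.7, 0.4])  # group 1 (channels 4-7)
max_group_nums = 2
min_sparsity_rate = 0.5
pruned_numel_per_group = int(fused_metric.numel() // max_group_nums * min_sparsity_rate)
group_step = fused_metric.shape[0] // max_group_nums
dependency_mask = torch.ones_like(fused_metric)
for gid in range(max_group_nums):
    _start, _end = gid * group_step, (gid + 1) * group_step
    if pruned_numel_per_group > 0:
        threshold = torch.topk(fused_metric[_start:_end], pruned_numel_per_group, largest=False)[0].max()
        dependency_mask[_start:_end] = torch.gt(fused_metric[_start:_end], threshold).type_as(fused_metric)
# dependency_mask -> tensor([1., 0., 1., 0., 1., 0., 1., 0.]): two channels pruned per group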
......@@ -51,7 +51,7 @@ class FunctionBasedTaskGenerator(TaskGenerator):
self.total_iteration = total_iteration
self.skip_first_iteration = skip_first_iteration
super().__init__(origin_model, origin_config_list=origin_config_list, origin_masks=origin_masks,
log_dir=log_dir, keep_intermediate_result=keep_intermediate_result)
log_dir=log_dir, keep_intermediate_result=keep_intermediate_result, best_result_mode='latest')
def reset(self, model: Module, config_list: List[Dict] = [], masks: Dict[str, Dict[str, Tensor]] = {}):
self.current_iteration = 1 if self.skip_first_iteration else 0
......@@ -78,10 +78,14 @@ class FunctionBasedTaskGenerator(TaskGenerator):
# get current2origin_sparsity and compact2origin_sparsity
origin_model = torch.load(self._origin_model_path)
current2origin_sparsity, compact2origin_sparsity, _ = compute_sparsity(origin_model, compact_model, compact_model_masks, self.target_sparsity)
_logger.debug('\nTask %s total real sparsity compared with original model is:\n%s', str(task_result.task_id), json_tricks.dumps(current2origin_sparsity, indent=4))
current2origin_sparsity, compact2origin_sparsity, _ = compute_sparsity(origin_model, compact_model, compact_model_masks,
self.target_sparsity)
debug_msg = f'\nTask {task_result.task_id} total real sparsity compared with original model is:\n' + \
f'{json_tricks.dumps(current2origin_sparsity, indent=4)}'
_logger.debug(debug_msg)
if task_result.task_id != 'origin':
self._tasks[task_result.task_id].state['current2origin_sparsity'] = current2origin_sparsity
task = self._tasks[task_result.task_id]
task.state['current2origin_sparsity'] = current2origin_sparsity
# if reach the total_iteration, no more task will be generated
if self.current_iteration > self.total_iteration:
......@@ -116,7 +120,8 @@ class AGPTaskGenerator(FunctionBasedTaskGenerator):
for target, mo in zip(target_sparsity, compact2origin_sparsity):
ori_sparsity = (1 - (1 - iteration / self.total_iteration) ** 3) * target['total_sparsity']
sparsity = max(0.0, (ori_sparsity - mo['total_sparsity']) / (1 - mo['total_sparsity']))
assert 0 <= sparsity <= 1, 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
err_msg = 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
assert 0 <= sparsity <= 1, err_msg
config_list.append(deepcopy(target))
config_list[-1]['total_sparsity'] = sparsity
return config_list
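The AGP generator ramps the sparsity target cubically over the iterations, ori_sparsity = (1 - (1 - t / T)^3) * target, and then rescales it against the sparsity the compact model already has. A small numeric sketch of that schedule, assuming a final target of 0.8 over 4 iterations and a compact model that starts dense (all values are for illustration only):

total_iteration = 4
target_total_sparsity = 0.8
model_sparsity = 0.0   # sparsity the compact model already has w.r.t. the origin model
for iteration in range(1, total_iteration + 1):
    ori_sparsity = (1 - (1 - iteration / total_iteration) ** 3) * target_total_sparsity
    # rescale onto the weights that are still unpruned in the current compact model
    sparsity = max(0.0, (ori_sparsity - model_sparsity) / (1 - model_sparsity))
    print(iteration, round(ori_sparsity, 4), round(sparsity, 4))
# iteration 1 -> 0.4625, 2 -> 0.7, 3 -> 0.7875, 4 -> 0.8 (sparsity equals ori_sparsity here)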
......@@ -128,7 +133,8 @@ class LinearTaskGenerator(FunctionBasedTaskGenerator):
for target, mo in zip(target_sparsity, compact2origin_sparsity):
ori_sparsity = iteration / self.total_iteration * target['total_sparsity']
sparsity = max(0.0, (ori_sparsity - mo['total_sparsity']) / (1 - mo['total_sparsity']))
assert 0 <= sparsity <= 1, 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
err_msg = 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
assert 0 <= sparsity <= 1, err_msg
config_list.append(deepcopy(target))
config_list[-1]['total_sparsity'] = sparsity
return config_list
......@@ -149,16 +155,18 @@ class LotteryTicketTaskGenerator(FunctionBasedTaskGenerator):
# The following is the formula in paper.
# ori_sparsity = (target['total_sparsity'] * 100) ** (iteration / self.total_iteration) / 100
sparsity = max(0.0, (ori_sparsity - mo['total_sparsity']) / (1 - mo['total_sparsity']))
assert 0 <= sparsity <= 1, 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
err_msg = 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
assert 0 <= sparsity <= 1, err_msg
config_list.append(deepcopy(target))
config_list[-1]['total_sparsity'] = sparsity
return config_list
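The commented-out line keeps the schedule from the lottery ticket paper, where sparsity grows geometrically towards the target instead of cubically as in AGP. A small sketch of that formula, assuming a target of 0.8 over 4 iterations; the diff leaves it commented out, so this only illustrates the formula, not the active code path:

total_iteration = 4
target_total_sparsity = 0.8
for iteration in range(1, total_iteration + 1):
    ori_sparsity = (target_total_sparsity * 100) ** (iteration / total_iteration) / 100
    print(iteration, round(ori_sparsity, 4))
# 1 -> 0.0299, 2 -> 0.0894, 3 -> 0.2675, 4 -> 0.8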
class SimulatedAnnealingTaskGenerator(TaskGenerator):
def __init__(self, origin_model: Optional[Module], origin_config_list: Optional[List[Dict]], origin_masks: Dict[str, Dict[str, Tensor]] = {},
start_temperature: float = 100, stop_temperature: float = 20, cool_down_rate: float = 0.9,
perturbation_magnitude: float = 0.35, log_dir: Union[str, Path] = '.', keep_intermediate_result: bool = False):
def __init__(self, origin_model: Optional[Module], origin_config_list: Optional[List[Dict]],
origin_masks: Dict[str, Dict[str, Tensor]] = {}, start_temperature: float = 100, stop_temperature: float = 20,
cool_down_rate: float = 0.9, perturbation_magnitude: float = 0.35, log_dir: Union[str, Path] = '.',
keep_intermediate_result: bool = False):
"""
Parameters
----------
......@@ -188,7 +196,7 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
self.perturbation_magnitude = perturbation_magnitude
super().__init__(origin_model, origin_masks=origin_masks, origin_config_list=origin_config_list,
log_dir=log_dir, keep_intermediate_result=keep_intermediate_result)
log_dir=log_dir, keep_intermediate_result=keep_intermediate_result, best_result_mode='maximize')
def reset(self, model: Module, config_list: List[Dict] = [], masks: Dict[str, Dict[str, Tensor]] = {}):
self.current_temperature = self.start_temperature
......@@ -196,7 +204,10 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
# TODO: replace with validation here
for config in config_list:
if 'sparsity' in config or 'sparsity_per_layer' in config:
_logger.warning('Only `total_sparsity` can be differentially allocated sparse ratio to each layer, `sparsity` or `sparsity_per_layer` will allocate fixed sparse ratio to layers. Make sure you know what this will lead to, otherwise please use `total_sparsity`.')
warn_msg = 'Only `total_sparsity` can be differentially allocated sparse ratio to each layer, ' + \
'`sparsity` or `sparsity_per_layer` will allocate fixed sparse ratio to layers. ' + \
'Make sure you know what this will lead to, otherwise please use `total_sparsity`.'
_logger.warning(warn_msg)
self.weights_numel, self.masked_rate = get_model_weights_numel(model, config_list, masks)
self.target_sparsity_list = config_list_canonical(model, config_list)
......@@ -259,11 +270,11 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
num_weights = sorted([self.weights_numel[op_name] for op_name in op_names])
sparsity = sorted(random_sparsity)
# calculate the scale
total_weights = np.sum(num_weights)
total_weights_pruned = np.sum([int(num_weight * sparsity[idx]) for idx, num_weight in enumerate(num_weights)])
if total_weights_pruned == 0:
return None
......