Unverified Commit ed455174 authored by J-shang's avatar J-shang Committed by GitHub
Browse files

[Compression] evaluator - step2 (#4992)

parent a689e619
......@@ -5,3 +5,9 @@ Quickstart
PyTorch </tutorials/hpo_quickstart_pytorch/main>
TensorFlow </tutorials/hpo_quickstart_tensorflow/main>
.. toctree::
:hidden:
/tutorials/hpo_quickstart_pytorch/index
/tutorials/hpo_quickstart_tensorflow/index
Evaluator
=========
.. _compression-torch-evaluator:
TorchEvaluator
--------------
.. autoclass:: nni.compression.pytorch.TorchEvaluator
.. _compression-lightning-evaluator:
LightningEvaluator
------------------
.. autoclass:: nni.compression.pytorch.LightningEvaluator
......@@ -8,5 +8,6 @@ Compression API Reference
Quantizer <quantizer>
Pruning Speedup <pruning_speedup>
Quantization Speedup <quantization_speedup>
Evaluator <evaluator>
Compression Utilities <utils>
Framework Related <framework>
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .utils import LightningEvaluator, TorchEvaluator
......@@ -119,7 +119,8 @@ class Compressor:
Detect all modules should be compressed, and save the result in `self._modules_to_compress`.
The model will be instrumented and user should never edit it after calling this method.
"""
assert self.bound_model is not None, 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
err_msg = 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.bound_model is not None, err_msg
if self._modules_to_compress is None:
self._modules_to_compress = []
......@@ -146,7 +147,8 @@ class Compressor:
Optional[Dict]
The retrieved configuration for this layer, if None, this layer should not be compressed.
"""
assert self.config_list is not None, 'No config_list set in this compressor, please use Compressor.reset(model, config_list) to set it.'
err_msg = 'No config_list set in this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.config_list is not None, err_msg
ret = None
for config in self.config_list:
......@@ -240,8 +242,10 @@ class Compressor:
Dict[int, List[str]]
A dict. The key is the config idx in config_list, the value is the module name list. i.e., {1: ['layer.0', 'layer.2']}.
"""
assert self.bound_model is not None, 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.config_list is not None, 'No config_list set in this compressor, please use Compressor.reset(model, config_list) to set it.'
err_msg = 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.bound_model is not None, err_msg
err_msg = 'No config_list set in this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.config_list is not None, err_msg
self._unwrap_model()
module_groups = {}
......@@ -323,6 +327,8 @@ class Compressor:
torch.nn.Module
model with specified modules compressed.
"""
assert self.bound_model is not None, 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.config_list is not None, 'No config_list set in this compressor, please use Compressor.reset(model, config_list) to set it.'
err_msg = 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.bound_model is not None, err_msg
err_msg = 'No config_list set in this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.config_list is not None, err_msg
return self.bound_model
......@@ -43,8 +43,8 @@ class PrunerModuleWrapper(Module):
pruning_target_mask_name = '{}_mask'.format(pruning_target_name)
pruning_target = getattr(self.module, pruning_target_name, None)
if hasattr(self.module, pruning_target_name) and pruning_target is not None:
setattr(self, pruning_target_name, Parameter(torch.empty(pruning_target.shape)))
self.register_buffer(pruning_target_mask_name, torch.ones(pruning_target.shape))
setattr(self, pruning_target_name, Parameter(torch.empty_like(pruning_target)))
self.register_buffer(pruning_target_mask_name, torch.ones_like(pruning_target))
else:
self.register_buffer(pruning_target_mask_name, None)
......@@ -67,11 +67,11 @@ class PrunerModuleWrapper(Module):
The best place to call this function is in `Pruner._unwrap_model()`.
"""
delattr(self.module, 'weight')
self.module.weight = Parameter(torch.empty(self.weight.size()))
self.module.weight = Parameter(torch.empty_like(self.weight))
self.module.weight.data = torch.mul(self.weight, self.weight_mask)
if hasattr(self.module, 'bias') and self.module.bias is not None:
delattr(self.module, 'bias')
self.module.bias = Parameter(torch.empty(self.bias.size()))
self.module.bias = Parameter(torch.empty_like(self.bias))
self.module.bias.data = torch.mul(self.bias, self.bias_mask)
def forward(self, *inputs):
......@@ -130,7 +130,8 @@ class Pruner(Compressor):
Wrap all modules that needed to be compressed.
Different from the parent function, call `wrapper._weight2buffer()` after replace the origin module to wrapper.
"""
assert self.bound_model is not None, 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
err_msg = 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.bound_model is not None, err_msg
if not self.is_wrapped:
for _, wrapper in reversed(list(self.get_modules_wrapper().items())):
......@@ -143,7 +144,8 @@ class Pruner(Compressor):
Unwrap all modules that needed to be compressed.
Different from the parent function, call `wrapper._weight2parameter()` after replace the wrapper to origin module.
"""
assert self.bound_model is not None, 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
err_msg = 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.bound_model is not None, err_msg
if self.is_wrapped:
for wrapper in self.get_modules_wrapper().values():
......@@ -165,8 +167,10 @@ class Pruner(Compressor):
self._unwrap_model()
parameter_name_map = {}
for name, param in self.bound_model.named_parameters():
# If the parameter name in under wrapped module is `xxx.weight` or `xxx.bias`, the name will not change after wrap.
# If the parameter name in under wrapped module is others, the name `xxx.param` will change to `xxx.module.param` after wrap.
# If the parameter name in under wrapped module is `xxx.weight` or `xxx.bias`,
# the name will not change after wrap.
# If the parameter name in under wrapped module is others,
# the name `xxx.param` will change to `xxx.module.param` after wrap.
parameter_name_map[name] = wrapped_param_names[id(param)] if id(param) in wrapped_param_names else name
self._wrap_model()
return parameter_name_map
......@@ -183,14 +187,12 @@ class Pruner(Compressor):
The masks dict with format {'op_name': {'weight': mask, 'bias': mask}}.
"""
wrappers = self.get_modules_wrapper()
for name, layer_mask in masks.items():
assert name in wrappers, '{} is not in wrappers of this pruner, can not apply the mask.'.format(name)
if layer_mask.get('weight') is not None:
assert hasattr(wrappers[name], 'weight_mask'), 'There is no attribute weight_mask in wrapper.'
setattr(wrappers[name], 'weight_mask', layer_mask.get('weight'))
if layer_mask.get('bias') is not None:
assert hasattr(wrappers[name], 'bias_mask'), 'There is no attribute bias_mask in wrapper.'
setattr(wrappers[name], 'bias_mask', layer_mask.get('bias'))
for module_name, target_masks in masks.items():
assert module_name in wrappers, '{} is not in wrappers of this pruner, can not apply the mask.'.format(module_name)
for target_name, target_mask in target_masks.items():
assert hasattr(wrappers[module_name], f'{target_name}_mask'), f'There is no attribute {target_name}_mask in wrapper.'
target: Tensor = getattr(self.get_modules_wrapper()[module_name], target_name)
setattr(wrappers[module_name], f'{target_name}_mask', target_mask.to(target.device))
def compress(self) -> Tuple[Module, Dict[str, Dict[str, Tensor]]]:
"""
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from copy import deepcopy
from pathlib import Path
from typing import Dict, List, Callable, Optional, cast
from typing import Dict, List, Callable, Optional, cast, overload
import json_tricks
import torch
......@@ -11,12 +13,13 @@ from torch import Tensor
from torch.nn import Module
from nni.algorithms.compression.v2.pytorch.base import Task, TaskResult
from nni.algorithms.compression.v2.pytorch.utils import compute_sparsity, config_list_canonical
from nni.compression.pytorch.utils import count_flops_params
from .iterative_pruner import IterativePruner, PRUNER_DICT
from .tools import TaskGenerator
from .tools.rl_env import DDPG, AMCEnv
from ..utils import LightningEvaluator, TorchEvaluator, compute_sparsity, config_list_canonical
from ..utils.docstring import _EVALUATOR_DOCSTRING
class AMCTaskGenerator(TaskGenerator):
......@@ -41,8 +44,8 @@ class AMCTaskGenerator(TaskGenerator):
ddpg_params
The ddpg agent parameters.
target : str
'flops' or 'params'. Note that the sparsity in other pruners always means the parameters sparse, but in AMC, you can choose flops sparse.
This parameter is used to explain what the sparsity setting in config_list refers to.
'flops' or 'params'. Note that the sparsity in other pruners always means the parameters sparse,
but in AMC, you can choose flops sparse. This parameter is used to explain what the sparsity setting in config_list refers to.
"""
def __init__(self, total_episode: int, dummy_input: Tensor, origin_model: Module, origin_config_list: List[Dict],
......@@ -56,7 +59,7 @@ class AMCTaskGenerator(TaskGenerator):
self.config_list_copy = deepcopy(origin_config_list)
super().__init__(origin_model=origin_model, origin_masks=origin_masks, origin_config_list=origin_config_list,
log_dir=log_dir, keep_intermediate_result=keep_intermediate_result)
log_dir=log_dir, keep_intermediate_result=keep_intermediate_result, best_result_mode='maximize')
def init_pending_tasks(self) -> List[Task]:
origin_model = torch.load(self._origin_model_path)
......@@ -82,6 +85,8 @@ class AMCTaskGenerator(TaskGenerator):
return self.generate_tasks(task_result)
def generate_tasks(self, task_result: TaskResult) -> List[Task]:
self.temp_config_list = self.temp_config_list if hasattr(self, 'temp_config_list') else []
# append experience & update agent policy
if self.action is not None:
action, reward, observation, done = self.env.step(self.action, task_result.compact_model)
......@@ -106,7 +111,8 @@ class AMCTaskGenerator(TaskGenerator):
origin_model = torch.load(self._origin_model_path)
compact_model = task_result.compact_model
compact_model_masks = task_result.compact_model_masks
current2origin_sparsity, _, _ = compute_sparsity(origin_model, compact_model, compact_model_masks, self.temp_config_list)
current2origin_sparsity, _, _ = compute_sparsity(origin_model, compact_model, compact_model_masks,
self.temp_config_list) # type: ignore
self._tasks[task_result.task_id].state['current2origin_sparsity'] = current2origin_sparsity
current2origin_sparsity, _, _ = compute_sparsity(origin_model, compact_model, compact_model_masks, self.config_list_copy)
self._tasks[task_result.task_id].state['current_total_sparsity'] = current2origin_sparsity
......@@ -162,7 +168,7 @@ class AMCTaskGenerator(TaskGenerator):
class AMCPruner(IterativePruner):
r"""
__doc__ = r"""
AMC pruner leverages reinforcement learning to provide the model compression policy.
According to the author, this learning-based compression policy outperforms conventional rule-based compression policy by having a higher compression ratio,
better preserving the accuracy and freeing human labor.
......@@ -186,10 +192,11 @@ class AMCPruner(IterativePruner):
- op_names : Operation name to be pruned.
- op_partial_names: Operation partial names to be pruned, will be autocompleted by NNI.
- exclude : Set True then the layers setting by op_types and op_names will be excluded from pruning.
dummy_input : torch.Tensor
`dummy_input` is required for speedup and tracing the model in RL environment.
evaluator : Callable[[Module], float]
Evaluate the pruned model and give a score.
evaluator
``evaluator`` is used to replace the previous ``finetuner``, ``dummy_input`` and old ``evaluator`` API.
{evaluator_docstring}
The old API (``finetuner``, ``dummy_input`` and old ``evaluator``) is still supported and will be deprecated in v3.0.
If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__.
pruning_algorithm : str
Supported pruning algorithm ['l1', 'l2', 'fpgm', 'apoz', 'mean_activation', 'taylorfo'].
This iterative pruner will use the chosen corresponding pruner to prune the model in each iteration.
......@@ -197,8 +204,6 @@ class AMCPruner(IterativePruner):
The log directory use to saving the result, you can find the best result under this folder.
keep_intermediate_result : bool
If keeping the intermediate result, including intermediate model and masks during each iteration.
finetuner : Optional[Callable[[Module], None]]
The finetuner handled all finetune logic, use a pytorch module as input, will be called in each iteration.
ddpg_params : Dict
Configuration dict to configure the DDPG agent, any key unset will be set to default implicitly.
- hidden1: hidden num of first fully connect layer. Default: 300
......@@ -223,23 +228,42 @@ class AMCPruner(IterativePruner):
'flops' or 'params'. Note that the sparsity in other pruners always means the parameters sparse, but in AMC, you can choose flops sparse.
This parameter is used to explain what the sparsity setting in config_list refers to.
Examples
--------
>>> from nni.compression.pytorch.pruning import AMCPruner
>>> config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.5, 'max_sparsity_per_layer': 0.8}]
>>> dummy_input = torch.rand(...).to(device)
>>> evaluator = ...
>>> finetuner = ...
>>> pruner = AMCPruner(400, model, config_list, dummy_input, evaluator, finetuner=finetuner)
>>> pruner.compress()
Notes
-----
The full script can be found :githublink:`here <examples/model_compress/pruning/amc_pruning_torch.py>`.
"""
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
@overload
def __init__(self, total_episode: int, model: Module, config_list: List[Dict], evaluator: LightningEvaluator | TorchEvaluator,
pruning_algorithm: str = 'l1', log_dir: str = '.', keep_intermediate_result: bool = False,
ddpg_params: dict = {}, pruning_params: dict = {}, target: str = 'flops'):
...
@overload
def __init__(self, total_episode: int, model: Module, config_list: List[Dict], dummy_input: Tensor,
evaluator: Callable[[Module], float], pruning_algorithm: str = 'l1', log_dir: str = '.',
keep_intermediate_result: bool = False, finetuner: Optional[Callable[[Module], None]] = None,
ddpg_params: dict = {}, pruning_params: dict = {}, target: str = 'flops'):
...
def __init__(self, total_episode: int, model: Module, config_list: List[Dict], *args, **kwargs):
new_api = ['evaluator', 'pruning_algorithm', 'log_dir', 'keep_intermediate_result', 'ddpg_params', 'pruning_params', 'target']
new_init_kwargs = {'pruning_algorithm': 'l1', 'log_dir': '.', 'keep_intermediate_result': False,
'ddpg_params': {}, 'pruning_params': {}, 'target': 'flops'}
old_api = ['dummy_input', 'evaluator', 'pruning_algorithm', 'log_dir', 'keep_intermediate_result', 'finetuner', 'ddpg_params',
'pruning_params', 'target']
old_init_kwargs = {'pruning_algorithm': 'l1', 'log_dir': '.', 'keep_intermediate_result': False, 'finetuner': None,
'ddpg_params': {}, 'pruning_params': {}, 'target': 'flops'}
init_kwargs = self._init_evaluator(model, new_api, new_init_kwargs, old_api, old_init_kwargs, args, kwargs)
pruning_algorithm = init_kwargs['pruning_algorithm']
log_dir = init_kwargs['log_dir']
keep_intermediate_result = init_kwargs['keep_intermediate_result']
ddpg_params = init_kwargs['ddpg_params']
pruning_params = init_kwargs['pruning_params']
target = init_kwargs['target']
dummy_input = self.dummy_input if not self.using_evaluator else self.evaluator.get_dummy_input()
assert pruning_algorithm in ['l1', 'l2', 'fpgm', 'apoz', 'mean_activation', 'taylorfo'], \
"Only support pruning_algorithm in ['l1', 'l2', 'fpgm', 'apoz', 'mean_activation', 'taylorfo']"
task_generator = AMCTaskGenerator(total_episode=total_episode,
......@@ -251,5 +275,9 @@ class AMCPruner(IterativePruner):
ddpg_params=ddpg_params,
target=target)
pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
super().__init__(pruner, task_generator, finetuner=finetuner, speedup=True, dummy_input=dummy_input,
evaluator=evaluator, reset_weight=False)
if self.using_evaluator:
super().__init__(pruner, task_generator, evaluator=self.evaluator, speedup=True, reset_weight=False)
else:
super().__init__(pruner, task_generator, finetuner=self.finetuner, speedup=True, dummy_input=self.dummy_input,
evaluator=self._evaluator, reset_weight=False) # type: ignore
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import logging
from pathlib import Path
from typing import Dict, List, Callable, Optional
from typing import Dict, List, Callable, Optional, overload
from torch import Tensor
from torch.nn import Module
from nni.algorithms.compression.v2.pytorch.utils import OptimizerConstructHelper
from .basic_pruner import ADMMPruner
from .iterative_pruner import IterativePruner, SimulatedAnnealingPruner
from .tools import LotteryTicketTaskGenerator
from ..utils import LightningEvaluator, TorchEvaluator, OptimizerConstructHelper
from ..utils.docstring import _EVALUATOR_DOCSTRING
_logger = logging.getLogger(__name__)
......@@ -21,10 +23,7 @@ class AutoCompressTaskGenerator(LotteryTicketTaskGenerator):
def __init__(self, total_iteration: int, origin_model: Module, origin_config_list: List[Dict],
origin_masks: Dict[str, Dict[str, Tensor]] = {}, sa_params: Dict = {}, log_dir: str = '.',
keep_intermediate_result: bool = False):
self.iterative_pruner = SimulatedAnnealingPruner(model=None,
config_list=None,
log_dir=Path(log_dir, 'SA'),
**sa_params)
self._sa_params = sa_params
super().__init__(total_iteration=total_iteration,
origin_model=origin_model,
origin_config_list=origin_config_list,
......@@ -36,12 +35,20 @@ class AutoCompressTaskGenerator(LotteryTicketTaskGenerator):
# TODO: replace with validation here
for config in config_list:
if 'sparsity' in config or 'sparsity_per_layer' in config:
_logger.warning('Only `total_sparsity` can be differentially allocated sparse ratio to each layer, `sparsity` or `sparsity_per_layer` will allocate fixed sparse ratio to layers. Make sure you know what this will lead to, otherwise please use `total_sparsity`.')
warn_msg = 'Only `total_sparsity` can be differentially allocated sparse ratio to each layer, ' + \
'`sparsity` or `sparsity_per_layer` will allocate fixed sparse ratio to layers. ' + \
'Make sure you know what this will lead to, otherwise please use `total_sparsity`.'
_logger.warning(warn_msg)
return super().reset(model, config_list, masks)
def _iterative_pruner_reset(self, model: Module, config_list: List[Dict] = [], masks: Dict[str, Dict[str, Tensor]] = {}):
self.iterative_pruner.task_generator._log_dir = Path(self._log_dir_root, 'SA')
self.iterative_pruner.reset(model, config_list=config_list, masks=masks)
if not hasattr(self, 'iterative_pruner'):
self.iterative_pruner = SimulatedAnnealingPruner(model=model,
config_list=config_list,
log_dir=Path(self._log_dir_root, 'SA'),
**self._sa_params)
else:
self.iterative_pruner.reset(model, config_list=config_list, masks=masks)
def allocate_sparsity(self, new_config_list: List[Dict], model: Module, masks: Dict[str, Dict[str, Tensor]]):
self._iterative_pruner_reset(model, new_config_list, masks)
......@@ -53,8 +60,9 @@ class AutoCompressTaskGenerator(LotteryTicketTaskGenerator):
class AutoCompressPruner(IterativePruner):
r"""
__doc__ = r"""
For total iteration number :math:`N`, AutoCompressPruner prune the model that survive the previous iteration for a fixed sparsity ratio (e.g., :math:`1-{(1-0.8)}^{(1/N)}`) to achieve the overall sparsity (e.g., :math:`0.8`):
""" + r"""
.. code-block:: bash
......@@ -65,35 +73,27 @@ class AutoCompressPruner(IterativePruner):
Parameters
----------
model : Module
model
The origin unwrapped pytorch model to be pruned.
config_list : List[Dict]
config_list
The origin config list provided by the user.
total_iteration : int
total_iteration
The total iteration number.
evaluator : Callable[[Module], float]
Evaluate the pruned model and give a score.
admm_params : Dict
admm_params
The parameters passed to the ADMMPruner.
- trainer : Callable[[Module, Optimizer, Callable].
A callable function used to train model or just inference. Take model, optimizer, criterion as input.
The model will be trained or inferenced `training_epochs` epochs.
- traced_optimizer : nni.common.serializer.Traceable(torch.optim.Optimizer)
The traced optimizer instance which the optimizer class is wrapped by nni.trace.
E.g. ``traced_optimizer = nni.trace(torch.nn.Adam)(model.parameters())``.
- criterion : Callable[[Tensor, Tensor], Tensor].
The criterion function used in trainer. Take model output and target value as input, and return the loss.
- evaluator : LightningEvaluator or TorchEvaluator.
The same with the evaluator of AutoCompressPruner input parameter.
- iterations : int.
The total iteration number in admm pruning algorithm.
- training_epochs : int.
The epoch number for training model in each iteration.
sa_params : Dict
sa_params
The parameters passed to the SimulatedAnnealingPruner.
- evaluator : Callable[[Module], float]. Required.
Evaluate the pruned model and give a score.
- evaluator : LightningEvaluator or TorchEvaluator.
The same with the evaluator of AutoCompressPruner input parameter.
- start_temperature : float. Default: `100`.
Start temperature of the simulated annealing process.
- stop_temperature : float. Default: `20`.
......@@ -104,54 +104,50 @@ class AutoCompressPruner(IterativePruner):
Initial perturbation magnitude to the sparsities. The magnitude decreases with current temperature.
- pruning_algorithm : str. Default: `'level'`.
Supported pruning algorithm ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz', 'mean_activation', 'taylorfo', 'admm'].
- pruning_params : Dict. Default: `{}`.
- pruning_params : Dict. Default: dict().
If the chosen pruning_algorithm has extra parameters, put them as a dict to pass in.
log_dir : str
log_dir
The log directory used to save the result, you can find the best result under this folder.
keep_intermediate_result : bool
keep_intermediate_result
If keeping the intermediate result, including intermediate model and masks during each iteration.
finetuner : Optional[Callable[[Module], None]]
The finetuner handles all finetune logic, takes a pytorch module as input.
It will be called at the end of each iteration, usually for neutralizing the accuracy loss brought by the pruning in this iteration.
speedup : bool
evaluator
``evaluator`` is used to replace the previous ``finetuner``, ``dummy_input`` and old ``evaluator`` API.
{evaluator_docstring}
The old API (``finetuner``, ``dummy_input`` and old ``evaluator``) is still supported and will be deprecated in v3.0.
If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__.
speedup
If set True, speedup the model at the end of each iteration to make the pruned model compact.
dummy_input : Optional[torch.Tensor]
If `speedup` is True, `dummy_input` is required for tracing the model in speedup.
Examples
--------
>>> import nni
>>> from nni.compression.pytorch.pruning import AutoCompressPruner
>>> model = ...
>>> config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
>>> # make sure you have used nni.trace to wrap the optimizer class before initialize
>>> traced_optimizer = nni.trace(torch.optim.Adam)(model.parameters())
>>> trainer = ...
>>> criterion = ...
>>> evaluator = ...
>>> finetuner = ...
>>> admm_params = {
>>> 'trainer': trainer,
>>> 'traced_optimizer': traced_optimizer,
>>> 'criterion': criterion,
>>> 'iterations': 10,
>>> 'training_epochs': 1
>>> }
>>> sa_params = {
>>> 'evaluator': evaluator
>>> }
>>> pruner = AutoCompressPruner(model, config_list, 10, admm_params, sa_params, finetuner=finetuner)
>>> pruner.compress()
>>> _, model, masks, _, _ = pruner.get_best_result()
Notes
-----
The full script can be found :githublink:`here <examples/model_compress/pruning/auto_compress_pruner.py>`.
"""
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
@overload
def __init__(self, model: Module, config_list: List[Dict], total_iteration: int, admm_params: Dict,
sa_params: Dict, log_dir: str = '.', keep_intermediate_result: bool = False,
evaluator: LightningEvaluator | TorchEvaluator | None = None, speedup: bool = False):
...
@overload
def __init__(self, model: Module, config_list: List[Dict], total_iteration: int, admm_params: Dict,
sa_params: Dict, log_dir: str = '.', keep_intermediate_result: bool = False,
finetuner: Optional[Callable[[Module], None]] = None, speedup: bool = False,
dummy_input: Optional[Tensor] = None, evaluator: Optional[Callable[[Module], float]] = None):
...
def __init__(self, model: Module, config_list: List[Dict], total_iteration: int, admm_params: Dict,
sa_params: Dict, log_dir: str = '.', keep_intermediate_result: bool = False,
*args, **kwargs):
new_api = ['evaluator', 'speedup']
new_init_kwargs = {'evaluator': None, 'speedup': False}
old_api = ['finetuner', 'speedup', 'dummy_input', 'evaluator']
old_init_kwargs = {'finetuner': None, 'evaluator': None, 'dummy_input': None, 'speedup': False}
init_kwargs = self._init_evaluator(model, new_api, new_init_kwargs, old_api, old_init_kwargs, args, kwargs)
speedup = init_kwargs['speedup']
task_generator = AutoCompressTaskGenerator(total_iteration=total_iteration,
origin_model=model,
origin_config_list=config_list,
......@@ -175,6 +171,10 @@ class AutoCompressPruner(IterativePruner):
else:
admm_params['granularity'] = 'fine-grained'
pruner = ADMMPruner(None, None, **admm_params)
super().__init__(pruner, task_generator, finetuner=finetuner, speedup=speedup, dummy_input=dummy_input,
evaluator=evaluator, reset_weight=False)
pruner = ADMMPruner(None, None, **admm_params) # type: ignore
if self.using_evaluator:
super().__init__(pruner, task_generator, evaluator=self.evaluator, speedup=speedup, reset_weight=False)
else:
super().__init__(pruner, task_generator, finetuner=self.finetuner, speedup=speedup, dummy_input=self.dummy_input,
evaluator=self._evaluator, reset_weight=False) # type: ignore
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from copy import deepcopy
from typing import Dict, List, Tuple, Callable, Optional, Union
import logging
from typing import Any, Dict, List, Tuple, Callable, Optional, Union, overload
import torch
from torch import Tensor
......@@ -12,9 +15,63 @@ from nni.algorithms.compression.v2.pytorch.base import Pruner, BasePruningSchedu
from nni.compression.pytorch.speedup import ModelSpeedup
from .tools import TaskGenerator
from ..utils import Evaluator, LightningEvaluator, TorchEvaluator
_logger = logging.getLogger(__name__)
_LEGACY_FINETUNER = Callable[[Module], None]
_LEGACY_EVALUATOR = Callable[[Module], float]
class PruningScheduler(BasePruningScheduler):
# TODO: remove in nni v3.0.
class EvaluatorBasedPruningScheduler(BasePruningScheduler):
evaluator: LightningEvaluator | TorchEvaluator
using_evaluator: bool
finetuner: _LEGACY_FINETUNER
_evaluator: _LEGACY_EVALUATOR
dummy_input: Any
def _init_evaluator(self, model: Module, new_api: List[str], new_init_kwargs: Dict, old_api: List[str],
old_init_kwargs: Dict, args: Tuple, kwargs: Dict) -> Dict:
# for fake __init__ overload, parsing args and kwargs,
# initializing evaluator or [finetuner, evaluator, dummy_input], return the remaining arguments.
if (len(args) > 0 and isinstance(args[0], Evaluator)) or \
(len(args) == 0 and isinstance(kwargs.get('evaluator', None), Evaluator)):
init_kwargs = self._parse_args(new_api, args, kwargs, new_init_kwargs)
self.evaluator: LightningEvaluator | TorchEvaluator = init_kwargs.pop('evaluator')
if not self.evaluator._initialization_complete:
self.evaluator._init_optimizer_helpers(model) # type: ignore
self.using_evaluator = True
else:
init_kwargs = self._parse_args(old_api, args, kwargs, old_init_kwargs)
self.finetuner: _LEGACY_FINETUNER = init_kwargs.pop('finetuner')
self._evaluator: _LEGACY_EVALUATOR = init_kwargs.pop('evaluator')
self.dummy_input = init_kwargs.pop('dummy_input')
self.using_evaluator = False
warn_msg = f'The old API ...{",".join(old_api)} will be deprecated after NNI v3.0,' +\
f'please using the new one ...{",".join(new_api)}'
_logger.warning(warn_msg)
return init_kwargs
def _parse_args(self, arg_names: List, args: Tuple, kwargs: Dict, def_kwargs: Dict) -> Dict:
merged_kwargs = {arg_names[idx]: arg for idx, arg in enumerate(args)}
for key, value in kwargs.items():
if key in merged_kwargs:
raise TypeError(f"{self.__class__.__name__}.__init__() got multiple values for argument '{key}'")
merged_kwargs[key] = value
for key, value in def_kwargs.items():
if key not in merged_kwargs:
merged_kwargs[key] = value
diff = set(arg_names).difference(merged_kwargs.keys())
if diff:
raise TypeError(f"{self.__class__.__name__}.__init__() missing {len(diff)} required positional argument: {diff}")
diff = set(merged_kwargs.keys()).difference(arg_names)
if diff:
raise TypeError(f"{self.__class__.__name__}.__init__() got {len(diff)} unexpected keyword argument: {diff}")
return merged_kwargs
class PruningScheduler(EvaluatorBasedPruningScheduler):
"""
Parameters
----------
......@@ -25,7 +82,8 @@ class PruningScheduler(BasePruningScheduler):
Used to generate task for each iteration.
finetuner
The finetuner handled all finetune logic, use a pytorch module as input.
It will be called at the end of each iteration if reset_weight is False, will be called at the beginning of each iteration otherwise.
It will be called at the end of each iteration if reset_weight is False,
will be called at the beginning of each iteration otherwise.
speedup
If set True, speedup the model at the end of each iteration to make the pruned model compact.
dummy_input
......@@ -36,16 +94,30 @@ class PruningScheduler(BasePruningScheduler):
reset_weight
If set True, the model weight will reset to the origin model weight at the end of each iteration step.
"""
def __init__(self, pruner: Pruner, task_generator: TaskGenerator, finetuner: Optional[Callable[[Module], None]] = None,
speedup: bool = False, dummy_input: Optional[Tensor] = None, evaluator: Optional[Callable[[Module], float]] = None,
@overload
def __init__(self, pruner: Pruner, task_generator: TaskGenerator, evaluator: LightningEvaluator | TorchEvaluator,
speedup: bool = False, reset_weight: bool = False):
...
@overload
def __init__(self, pruner: Pruner, task_generator: TaskGenerator, finetuner: _LEGACY_FINETUNER | None = None,
speedup: bool = False, dummy_input: Optional[Tensor] = None, evaluator: _LEGACY_EVALUATOR | None = None,
reset_weight: bool = False):
...
def __init__(self, pruner: Pruner, task_generator: TaskGenerator, *args, **kwargs) -> None:
# TODO: remove in nni v3.0. Fake overload.
new_api = ['evaluator', 'speedup', 'reset_weight']
new_init_kwargs = {'evaluator': None, 'speedup': False, 'reset_weight': False}
old_api = ['finetuner', 'speedup', 'dummy_input', 'evaluator', 'reset_weight']
old_init_kwargs = {'finetuner': None, 'evaluator': None, 'dummy_input': None, 'speedup': False, 'reset_weight': False}
init_kwargs = self._init_evaluator(None, new_api, new_init_kwargs, old_api, old_init_kwargs, args, kwargs) # type: ignore
self.pruner = pruner
self.task_generator = task_generator
self.finetuner = finetuner
self.speedup = speedup
self.dummy_input = dummy_input
self.evaluator = evaluator
self.reset_weight = reset_weight
self.speedup = init_kwargs['speedup']
self.reset_weight = init_kwargs['reset_weight']
def reset(self, model: Module, config_list: List[Dict], masks: Dict[str, Dict[str, Tensor]] = {}):
self.task_generator.reset(model, config_list, masks)
......@@ -61,6 +133,7 @@ class PruningScheduler(BasePruningScheduler):
generate masks -> speedup -> finetune -> evaluate
"""
model, masks, config_list = task.load_data()
self.pruner.reset(model, config_list)
self.pruner.load_masks(masks)
......@@ -74,28 +147,58 @@ class PruningScheduler(BasePruningScheduler):
# speedup
if self.speedup and task.speedup:
ModelSpeedup(compact_model, self.dummy_input, pruner_generated_masks).speedup_model()
compact_model_masks = {}
if self.using_evaluator:
ModelSpeedup(compact_model, self.evaluator.get_dummy_input(), pruner_generated_masks).speedup_model()
compact_model_masks = {}
else:
ModelSpeedup(compact_model, self.dummy_input, pruner_generated_masks).speedup_model()
compact_model_masks = {}
# finetune
if self.finetuner is not None and task.finetune:
if self.speedup:
self.finetuner(compact_model)
else:
self.pruner._wrap_model()
self.finetuner(compact_model)
self.pruner._unwrap_model()
if self.using_evaluator:
if task.finetune:
self.evaluator.bind_model(compact_model) # type: ignore
if self.speedup:
self.evaluator.finetune()
else:
self.pruner._wrap_model()
self.evaluator.finetune()
self.pruner._unwrap_model()
self.evaluator.unbind_model()
else:
if self.finetuner is not None and task.finetune:
if self.speedup:
self.finetuner(compact_model)
else:
self.pruner._wrap_model()
self.finetuner(compact_model)
self.pruner._unwrap_model()
# evaluate
if self.evaluator is not None and task.evaluate:
if self.speedup:
score = self.evaluator(compact_model)
if self.using_evaluator:
if task.evaluate:
self.evaluator.bind_model(compact_model) # type: ignore
# TODO: support saving customized score
if self.speedup:
score = self.evaluator.evaluate()
else:
self.pruner._wrap_model()
score = self.evaluator.evaluate()
self.pruner._unwrap_model()
score = score[0] if isinstance(score, tuple) else score
self.evaluator.unbind_model()
else:
self.pruner._wrap_model()
score = self.evaluator(compact_model)
self.pruner._unwrap_model()
score = None
else:
score = None
if self._evaluator is not None and task.evaluate:
if self.speedup:
score = self._evaluator(compact_model) # type: ignore
else:
self.pruner._wrap_model()
score = self._evaluator(compact_model) # type: ignore
self.pruner._unwrap_model()
else:
score = None
# clear model references
self.pruner.clear_model_references()
......@@ -107,13 +210,20 @@ class PruningScheduler(BasePruningScheduler):
finetune -> generate masks -> reset weight -> speedup -> evaluate
"""
model, masks, config_list = task.load_data()
checkpoint = deepcopy(model.state_dict())
self.pruner.reset(model, config_list)
self.pruner.load_masks(masks)
# finetune
if self.finetuner is not None and task.finetune:
self.finetuner(model)
if self.using_evaluator:
if task.finetune:
self.evaluator.bind_model(model) # type: ignore
self.evaluator.finetune()
self.evaluator.unbind_model()
else:
if self.finetuner is not None and task.finetune:
self.finetuner(model)
# pruning model
compact_model, pruner_generated_masks = self.pruner.compress()
......@@ -128,19 +238,38 @@ class PruningScheduler(BasePruningScheduler):
# speedup
if self.speedup and task.speedup:
ModelSpeedup(compact_model, self.dummy_input, pruner_generated_masks).speedup_model()
compact_model_masks = {}
if self.using_evaluator:
ModelSpeedup(compact_model, self.evaluator.get_dummy_input(), pruner_generated_masks).speedup_model()
compact_model_masks = {}
else:
ModelSpeedup(compact_model, self.dummy_input, pruner_generated_masks).speedup_model()
compact_model_masks = {}
# evaluate
if self.evaluator is not None and task.evaluate:
if self.speedup:
score = self.evaluator(compact_model)
if self.using_evaluator:
if task.evaluate:
self.evaluator.bind_model(compact_model) # type: ignore
# TODO: support saving customized score
if self.speedup:
score = self.evaluator.evaluate()
else:
self.pruner._wrap_model()
score = self.evaluator.evaluate()
self.pruner._unwrap_model()
score = score[0] if isinstance(score, tuple) else score
self.evaluator.unbind_model()
else:
self.pruner._wrap_model()
score = self.evaluator(compact_model)
self.pruner._unwrap_model()
score = None
else:
score = None
if self._evaluator is not None and task.evaluate:
if self.speedup:
score = self._evaluator(compact_model) # type: ignore
else:
self.pruner._wrap_model()
score = self._evaluator(compact_model) # type: ignore
self.pruner._unwrap_model()
else:
score = None
# clear model references
self.pruner.clear_model_references()
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from copy import deepcopy
import logging
from typing import Dict, List, Tuple, Callable
from typing import Dict, List, Tuple, Callable, overload
import torch
from torch import autograd, Tensor
......@@ -12,17 +14,23 @@ from torch.nn.parameter import Parameter
from torch.optim import Optimizer, Adam
from nni.algorithms.compression.v2.pytorch.base import PrunerModuleWrapper, LayerInfo
from nni.algorithms.compression.v2.pytorch.pruning.basic_pruner import BasicPruner, NORMAL_SCHEMA, EXCLUDE_SCHEMA, INTERNAL_SCHEMA
from nni.algorithms.compression.v2.pytorch.utils import CompressorSchema, OptimizerConstructHelper
from nni.common.serializer import Traceable
from nni.algorithms.compression.v2.pytorch.pruning.basic_pruner import EvaluatorBasedPruner, NORMAL_SCHEMA, EXCLUDE_SCHEMA, INTERNAL_SCHEMA
from nni.algorithms.compression.v2.pytorch.utils import CompressorSchema
from .tools.base import TrainerBasedDataCollector
from .tools.base import EvaluatorBasedDataCollector, TrainerBasedDataCollector
from .tools import (
StraightMetricsCalculator,
NormalSparsityAllocator
NormalSparsityAllocator,
StraightMetricsCalculator
)
from ..utils import (
LightningEvaluator,
TorchEvaluator
)
from ..utils.docstring import _EVALUATOR_DOCSTRING
_logger = logging.getLogger(__name__)
......@@ -47,8 +55,7 @@ class PrunerScoredModuleWrapper(PrunerModuleWrapper):
def forward(self, *inputs):
# apply mask to weight, bias
# NOTE: I don't know why training getting slower and slower if only `self.weight_mask` without `detach()`
self.module.weight = torch.mul(self.weight, _StraightThrough.apply(self.weight_score, self.weight_mask.detach())) # type: ignore
self.module.weight = torch.mul(self.weight, _StraightThrough.apply(self.weight_score, self.weight_mask)) # type: ignore
if hasattr(self.module, 'bias') and self.module.bias is not None:
self.module.bias = torch.mul(self.bias, self.bias_mask) # type: ignore
return self.module(*inputs)
......@@ -77,13 +84,30 @@ class WeightScoreTrainerBasedDataCollector(TrainerBasedDataCollector):
self.trainer(self.compressor.bound_model, self.optimizer, self.criterion)
data = {}
target_name = 'weight'
for _, wrapper in self.compressor.get_modules_wrapper().items():
data[wrapper.name] = wrapper.weight_score.data # type: ignore
data[wrapper.name] = {target_name: wrapper.weight_score.data} # type: ignore
return data
class MovementPruner(BasicPruner):
r"""
class EvaluatorBasedScoreDataCollector(EvaluatorBasedDataCollector):
"""
Collect all weight_score in wrappers as data used to calculate metrics.
"""
def collect(self) -> Dict[str, Tensor]:
assert self.compressor.bound_model is not None
self.evaluator.train(max_steps=self.max_steps, max_epochs=self.max_epochs)
data = {}
target_name = 'weight'
for module_name, wrapper in self.compressor.get_modules_wrapper().items():
target_score: Tensor = getattr(wrapper, f'{target_name}_score')
data[module_name] = {target_name: target_score.data.clone()}
return data
class MovementPruner(EvaluatorBasedPruner):
__doc__ = r"""
Movement pruner is an implementation of movement pruning.
This is a "fine-pruning" algorithm, which means the masks may change during each fine-tuning step.
Each weight element will be scored by the opposite of the sum of the product of weight and its gradient during each step.
......@@ -110,30 +134,12 @@ class MovementPruner(BasicPruner):
- op_names : Operation names to be pruned.
- op_partial_names: Operation partial names to be pruned, will be autocompleted by NNI.
- exclude : Set True then the layers setting by op_types and op_names will be excluded from pruning.
trainer : Callable[[Module, Optimizer, Callable]
A callable function used to train model or just inference. Take model, optimizer, criterion as input.
The model will be trained or inferenced `training_epochs` epochs.
Example::
def trainer(model: Module, optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor]):
training = model.training
model.train(mode=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
# If you don't want to update the model, you can skip `optimizer.step()`, and set train mode False.
optimizer.step()
model.train(mode=training)
traced_optimizer : nni.common.serializer.Traceable(torch.optim.Optimizer)
The traced optimizer instance which the optimizer class is wrapped by nni.trace.
E.g. ``traced_optimizer = nni.trace(torch.nn.Adam)(model.parameters())``.
criterion : Callable[[Tensor, Tensor], Tensor]
The criterion function used in trainer. Take model output and target value as input, and return the loss.
evaluator
``evaluator`` is used to replace the previous ``trainer``, ``traced_optimizer`` and ``criterion`` API.
{evaluator_docstring}
The old API (``trainer``, ``traced_optimizer`` and ``criterion``) is still supported and will be deprecated in v3.0.
If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__.
training_epochs : int
The total epoch number for training the model.
Make sure the total `optimizer.step()` in `training_epochs` is bigger than `cool_down_beginning_step`.
......@@ -145,33 +151,31 @@ class MovementPruner(BasicPruner):
The sparsity after each `optimizer.step()` is:
total_sparsity * (1 - (1 - (current_step - warm_up_step) / (cool_down_beginning_step - warm_up_step)) ** 3).
Examples
--------
>>> import nni
>>> from nni.compression.pytorch.pruning import MovementPruner
>>> model = ...
>>> # make sure you have used nni.trace to wrap the optimizer class before initialize
>>> traced_optimizer = nni.trace(torch.optim.Adam)(model.parameters())
>>> trainer = ...
>>> criterion = ...
>>> config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
>>> pruner = MovementPruner(model, config_list, trainer, traced_optimizer, criterion, 10, 3000, 27000)
>>> masked_model, masks = pruner.compress()
Notes
-----
For detailed example please refer to :githublink:`examples/model_compress/pruning/movement_pruning_glue.py <examples/model_compress/pruning/movement_pruning_glue.py>`
"""
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
@overload
def __init__(self, model: Module, config_list: List[Dict], evaluator: LightningEvaluator | TorchEvaluator, training_epochs: int,
warm_up_step: int, cool_down_beginning_step: int):
...
@overload
def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
traced_optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor], training_epochs: int, warm_up_step: int,
cool_down_beginning_step: int):
self.trainer = trainer
if isinstance(traced_optimizer, OptimizerConstructHelper):
self.optimizer_helper = traced_optimizer
else:
self.optimizer_helper = OptimizerConstructHelper.from_trace(model, traced_optimizer)
self.criterion = criterion
self.training_epochs = training_epochs
self.warm_up_step = warm_up_step
self.cool_down_beginning_step = cool_down_beginning_step
...
def __init__(self, model: Module, config_list: List[Dict], *args, **kwargs):
# TODO: remove in nni v3.0. Fake overload.
new_api = ['evaluator', 'training_epochs', 'warm_up_step', 'cool_down_beginning_step']
old_api = ['trainer', 'traced_optimizer', 'criterion', 'training_epochs', 'warm_up_step', 'cool_down_beginning_step']
init_kwargs = self._init_evaluator(model, new_api, old_api, {}, args, kwargs)
self.training_epochs: int = init_kwargs['training_epochs']
self.warm_up_step: int = init_kwargs['warm_up_step']
self.cool_down_beginning_step: int = init_kwargs['cool_down_beginning_step']
assert self.warm_up_step < self.cool_down_beginning_step, '`warm_up_step` should smaller than `cool_down_beginning_step`'
super().__init__(model, config_list)
......@@ -184,14 +188,16 @@ class MovementPruner(BasicPruner):
if self.warm_up_step < current_step <= self.cool_down_beginning_step:
wrapper_dict = self.get_modules_wrapper()
for config in self.config_list:
current_sparsity = config['total_sparsity'] * (1 - (1 - (current_step - self.warm_up_step) / (self.cool_down_beginning_step - self.warm_up_step)) ** 3)
scale = 1 - (1 - (current_step - self.warm_up_step) / (self.cool_down_beginning_step - self.warm_up_step)) ** 3
current_sparsity = config['total_sparsity'] * scale
for op_name in config['op_names']:
wrapper_dict[op_name].config['total_sparsity'] = current_sparsity
wrapper = wrapper_dict[op_name]
wrapper.config['total_sparsity'] = current_sparsity
def reset_tools(self):
if self.metrics_calculator is None:
if not hasattr(self, 'metrics_calculator'):
self.metrics_calculator = StraightMetricsCalculator()
if self.sparsity_allocator is None:
if not hasattr(self, 'sparsity_allocator'):
self.sparsity_allocator = NormalSparsityAllocator(self, continuous_mask=False)
# use Adam to update the weight_score
......@@ -208,16 +214,30 @@ class MovementPruner(BasicPruner):
if self.step_counter > self.warm_up_step:
self.cubic_schedule(self.step_counter)
data = {}
target_name = 'weight'
for wrapper in self.get_modules_wrapper().values():
data[wrapper.name] = wrapper.weight_score.data
data[wrapper.name] = {target_name: wrapper.weight_score.data}
metrics = self.metrics_calculator.calculate_metrics(data) # type: ignore
masks = self.sparsity_allocator.generate_sparsity(metrics) # type: ignore
self.load_masks(masks)
if self.data_collector is None:
self.data_collector = WeightScoreTrainerBasedDataCollector(self, self.trainer, self.optimizer_helper, self.criterion, self.training_epochs, opt_after_tasks=[_optimizer_patch])
if self.using_evaluator:
# TODO: move to other place in nni v3.0
self.evaluator.unbind_model()
self.evaluator.bind_model(self.bound_model, self.get_origin2wrapped_parameter_name_map()) # type: ignore
if not hasattr(self, 'data_collector'):
self.data_collector = EvaluatorBasedScoreDataCollector(self, self.evaluator,
after_opt_step_tasks=[_optimizer_patch],
max_epochs=self.training_epochs)
else:
self.data_collector.reset(after_opt_step_tasks=[_optimizer_patch])
else:
self.data_collector.reset()
if not hasattr(self, 'data_collector'):
self.data_collector = WeightScoreTrainerBasedDataCollector(self, self.trainer, self.optimizer_helper,
self.criterion, self.training_epochs,
opt_after_tasks=[_optimizer_patch])
else:
self.data_collector.reset()
def _wrap_modules(self, layer: LayerInfo, config: Dict):
"""
......@@ -243,7 +263,6 @@ class MovementPruner(BasicPruner):
for wrapper in self.get_modules_wrapper().values():
wrapper.config['total_sparsity'] = 0
result = super().compress()
# del weight_score
for wrapper in self.get_modules_wrapper().values():
wrapper.weight_score = None
if self.using_evaluator:
self.evaluator.unbind_model()
return result
......@@ -8,6 +8,12 @@ from .base import (
SparsityAllocator,
TaskGenerator
)
from .data_collector import (
TargetDataCollector,
EvaluatorBasedTargetDataCollector,
EvaluatorBasedHookDataCollector
)
# TODO: remove in nni v3.0.
from .data_collector import (
WeightDataCollector,
WeightTrainerBasedDataCollector,
......@@ -16,7 +22,7 @@ from .data_collector import (
from .metrics_calculator import (
StraightMetricsCalculator,
NormMetricsCalculator,
MultiDataNormMetricsCalculator,
HookDataNormMetricsCalculator,
DistMetricsCalculator,
APoZRankMetricsCalculator,
MeanRankMetricsCalculator
......
......@@ -6,7 +6,7 @@ from datetime import datetime
import logging
from pathlib import Path
import types
from typing import List, Dict, Tuple, Optional, Callable, Union
from typing import List, Dict, Literal, Tuple, Optional, Callable, Union
import json_tricks
import torch
......@@ -15,7 +15,7 @@ from torch.nn import Module
from torch.optim import Optimizer
from ...base import Pruner, LayerInfo, Task, TaskResult
from ...utils import OptimizerConstructHelper, Scaling
from ...utils import Evaluator, Hook, OptimizerConstructHelper, Scaling
_logger = logging.getLogger(__name__)
......@@ -45,7 +45,7 @@ class DataCollector:
def __init__(self, compressor: Pruner):
self.compressor = compressor
def reset(self):
def reset(self, *args, **kwargs):
"""
Reset the `DataCollector`.
"""
......@@ -63,9 +63,12 @@ class DataCollector:
raise NotImplementedError()
# TODO: remove in nni v3.0.
COLLECTOR_TYPE = Union[Callable[[List, Tensor], Callable[[Tensor], None]], Callable[[List], Callable[[Module, Tensor, Tensor], None]]]
class HookCollectorInfo:
def __init__(self, targets: Union[Dict[str, Tensor], List[LayerInfo]], hook_type: str,
collector: Union[Callable[[List, Tensor], Callable[[Tensor], None]], Callable[[List], Callable[[Module, Tensor, Tensor], None]]]):
collector: COLLECTOR_TYPE):
"""
This class used to aggregate the information of what kind of hook is placed on which layers.
......@@ -76,23 +79,24 @@ class HookCollectorInfo:
hook_type
'forward' or 'backward'.
collector
A hook function generator, the input is a buffer (empty list) or a buffer (empty list) and tensor, the output is a hook function.
The buffer is used to store the data wanted to hook.
A hook function generator, the input is a buffer (empty list) or a buffer (empty list) and tensor,
the output is a hook function. The buffer is used to store the data wanted to hook.
"""
self.targets = targets
self.hook_type = hook_type
self.collector = collector
# TODO: remove in nni v3.0.
class TrainerBasedDataCollector(DataCollector):
"""
This class includes some trainer based util functions, i.e., patch optimizer or criterion, add hooks.
"""
def __init__(self, compressor: Pruner, trainer: Callable[[Module, Optimizer, Callable], None], optimizer_helper: OptimizerConstructHelper,
criterion: Callable[[Tensor, Tensor], Tensor], training_epochs: int,
opt_before_tasks: List = [], opt_after_tasks: List = [],
collector_infos: List[HookCollectorInfo] = [], criterion_patch: Optional[Callable[[Callable], Callable]] = None):
def __init__(self, compressor: Pruner, trainer: Callable[[Module, Optimizer, Callable], None],
optimizer_helper: OptimizerConstructHelper, criterion: Callable[[Tensor, Tensor], Tensor], training_epochs: int,
opt_before_tasks: List = [], opt_after_tasks: List = [], collector_infos: List[HookCollectorInfo] = [],
criterion_patch: Optional[Callable[[Callable], Callable]] = None):
"""
Parameters
----------
......@@ -252,6 +256,47 @@ class TrainerBasedDataCollector(DataCollector):
self._remove_hook(hook_id)
class EvaluatorBasedDataCollector(DataCollector):
"""
This data collector is the base class for the data collectors that want to use ``Evaluator`` to train or inference.
Three main usages are supported in this data collector:
1. Doing something before ``optimzer.step()`` and after ``optimzer.step()``. ``before_opt_step_tasks`` is a list of task functions
that will execute before ``optimzer.step()``. ``after_opt_step_tasks`` is a list of task functions that will execute after
``optimzer.step()``. All the task functions in the list should not have input arguments, function return value is allowed,
but ``Evaluator`` will not catch it.
2. Patch or modify the training loss. ``loss_patch`` is a function with input is the original loss and the output is the modified loss.
3. Add hooks on ``torch.nn.Module`` or ``Parameter`` or ``Buffer``. Three kinds of hook are supported, ``TensorHook``, ``ForwardHook``
and ``BackwardHook``. For initializing a ``Hook``, a hook function factory is needed, the factory function's input is an empty list,
and the output is a hook function defined by Pytorch.
Please refer `register_hook <https://pytorch.org/docs/stable/generated/torch.Tensor.register_hook.html>`_,
`register_forward_hook <https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_forward_hook>`_,
`register_backward_hook <https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_backward_hook>`_.
"""
def __init__(self, compressor: Pruner, evaluator: Evaluator, before_opt_step_tasks: List[Callable] | None = None,
after_opt_step_tasks: List[Callable] | None = None, loss_patch: Callable[[Tensor], Tensor] | None = None,
hooks: Dict[str, Dict[str, Hook]] | None = None, max_steps: int | None = None, max_epochs: int | None = None):
super().__init__(compressor)
self.evaluator = evaluator
self.max_steps = max_steps
self.max_epochs = max_epochs
self.reset(before_opt_step_tasks, after_opt_step_tasks, loss_patch, hooks)
def reset(self, before_opt_step_tasks: List[Callable] | None = None, after_opt_step_tasks: List[Callable] | None = None,
loss_patch: Callable[[Tensor], Tensor] | None = None, hooks: Dict[str, Dict[str, Hook]] | None = None):
if before_opt_step_tasks or after_opt_step_tasks:
before_opt_step_tasks = before_opt_step_tasks if before_opt_step_tasks else []
after_opt_step_tasks = after_opt_step_tasks if after_opt_step_tasks else []
self.evaluator.patch_optimizer_step(before_opt_step_tasks, after_opt_step_tasks)
if loss_patch:
self.evaluator.patch_loss(loss_patch)
if hooks:
self._hooks = hooks
hook_list = [hook for _ in hooks.values() for hook in _.values()]
self.evaluator.register_hooks(hook_list)
class MetricsCalculator:
"""
An abstract class for calculate a kind of metrics of the given data.
......@@ -260,7 +305,8 @@ class MetricsCalculator:
----------
scalers
Scaler is used to scale the metrics' size. It scaling metric to the same size as the shrinked mask in the sparsity allocator.
If you want to use different scalers for different pruning targets in different modules, please use a dict `{module_name: {target_name: scaler}}`.
If you want to use different scalers for different pruning targets in different modules,
please use a dict `{module_name: {target_name: scaler}}`.
If allocator meets an unspecified module name, it will try to use `scalers['_default'][target_name]` to scale its mask.
If allocator meets an unspecified target name, it will try to use `scalers[module_name]['_default']` to scale its mask.
Passing in a scaler instead of a `dict` of scalers will be treated as passed in `{'_default': {'_default': scalers}}`.
......@@ -268,7 +314,8 @@ class MetricsCalculator:
"""
def __init__(self, scalers: Dict[str, Dict[str, Scaling]] | Scaling | None = None):
self.scalers: Dict[str, Dict[str, Scaling]] | None = scalers if isinstance(scalers, (dict, type(None))) else {'_default': {'_default': scalers}} # type: ignore
self.scalers: Dict[str, Dict[str, Scaling]] | None = scalers \
if isinstance(scalers, (dict, type(None))) else {'_default': {'_default': scalers}} # type: ignore
def _get_scaler(self, module_name: str, target_name: str) -> Scaling:
scaler = _get_scaler(self.scalers, module_name, target_name)
......@@ -301,7 +348,8 @@ class SparsityAllocator:
scalers
Scaler is used to scale the masks' size. It shrinks the mask of the same size as the pruning target to the same size as the metric,
or expands the mask of the same size as the metric to the same size as the pruning target.
If you want to use different scalers for different pruning targets in different modules, please use a dict `{module_name: {target_name: scaler}}`.
If you want to use different scalers for different pruning targets in different modules,
please use a dict `{module_name: {target_name: scaler}}`.
If allocator meets an unspecified module name, it will try to use `scalers['_default'][target_name]` to scale its mask.
If allocator meets an unspecified target name, it will try to use `scalers[module_name]['_default']` to scale its mask.
Passing in a scaler instead of a `dict` of scalers will be treated as passed in `{'_default': {'_default': scalers}}`.
......@@ -313,7 +361,8 @@ class SparsityAllocator:
def __init__(self, pruner: Pruner, scalers: Dict[str, Dict[str, Scaling]] | Scaling | None = None, continuous_mask: bool = True):
self.pruner = pruner
self.scalers: Dict[str, Dict[str, Scaling]] | None = scalers if isinstance(scalers, (dict, type(None))) else {'_default': {'_default': scalers}} # type: ignore
self.scalers: Dict[str, Dict[str, Scaling]] | None = scalers \
if isinstance(scalers, (dict, type(None))) else {'_default': {'_default': scalers}} # type: ignore
self.continuous_mask = continuous_mask
def _get_scaler(self, module_name: str, target_name: str) -> Scaling | None:
......@@ -335,25 +384,39 @@ class SparsityAllocator:
mask = (scaler.shrink(mask) != 0).type_as(mask)
return mask
def _continuous_mask(self, new_masks: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
def _mask_metric(self, metrics: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
# Set the already masked part in the metric to the minimum value.
target_name = 'weight'
for module_name, targets_metric in metrics.items():
wrapper = self.pruner.get_modules_wrapper()[module_name]
old_target_mask: Tensor = getattr(wrapper, f'{target_name}_mask')
shrinked_target_mask = self._shrink_mask(module_name, target_name, old_target_mask)
# make sure the masked position has the minimum metric
targets_metric[target_name] = targets_metric[target_name].to(shrinked_target_mask.device)
min_value = targets_metric[target_name].min() - 1
targets_metric[target_name] = torch.where(shrinked_target_mask != 0, targets_metric[target_name], min_value)
return metrics
def _continuous_mask(self, new_masks: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
# Set the already masked part to zero in the new_masks.
target_name = 'weight'
for module_name, target_mask in new_masks.items():
wrapper = self.pruner.get_modules_wrapper()[module_name]
old_target_mask = getattr(wrapper, f'{target_name}_mask', None)
old_target_mask: Tensor | None = getattr(wrapper, f'{target_name}_mask', None)
if old_target_mask is not None:
new_masks[module_name][target_name] = torch.min(target_mask[target_name], old_target_mask)
new_masks[module_name][target_name] = torch.min(target_mask[target_name],
old_target_mask.to(target_mask[target_name].device))
return new_masks
def common_target_masks_generation(self, metrics: Dict[str, Tensor]) -> Dict[str, Dict[str, Tensor]]:
def common_target_masks_generation(self, metrics: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
"""
Generate masks for metrics-dependent targets.
Parameters
----------
metrics
The format is {module_name: weight_metric}.
The metric of `weight` usually has the same size with shrinked mask.
The format is {module_name: {target_name: target_metric}}.
The metric of usually has the same size with shrinked mask.
Return
------
......@@ -384,7 +447,7 @@ class SparsityAllocator:
reduce_dims = [reduce_dim for reduce_dim in range(1, len(weight_mask.shape))]
# count unmasked number of values on dim 0 (output channel) of weight
unmasked_num_on_dim0 = weight_mask.sum(reduce_dims) if reduce_dims else weight_mask
module_masks['bias'] = (unmasked_num_on_dim0 != 0).type_as(old_bias_mask)
module_masks['bias'] = (unmasked_num_on_dim0 != 0).type_as(weight_mask)
return masks
def generate_sparsity(self, metrics: Dict) -> Dict[str, Dict[str, Tensor]]:
......@@ -401,6 +464,8 @@ class SparsityAllocator:
Dict[str, Dict[str, Tensor]]
The masks format is {module_name: {target_name: mask}}.
"""
if self.continuous_mask:
metrics = self._mask_metric(metrics)
masks = self.common_target_masks_generation(metrics)
masks = self.special_target_masks_generation(masks)
if self.continuous_mask:
......@@ -425,11 +490,22 @@ class TaskGenerator:
The log directory use to saving the task generator log.
keep_intermediate_result
If keeping the intermediate result, including intermediate model and masks during each iteration.
best_result_mode
The way to decide which one is the best result. Three modes are supported.
If the task results don't contain scores (task_result.score is None), it will fall back to ``latest``.
1. latest: The newest received result is the best result.
2. maximize: The one with largest task result score is the best result.
3. minimize: The one with smallest task result score is the best result.
"""
def __init__(self, origin_model: Optional[Module], origin_masks: Optional[Dict[str, Dict[str, Tensor]]] = {},
origin_config_list: Optional[List[Dict]] = [], log_dir: Union[str, Path] = '.', keep_intermediate_result: bool = False):
origin_config_list: Optional[List[Dict]] = [], log_dir: Union[str, Path] = '.', keep_intermediate_result: bool = False,
best_result_mode: Literal['latest', 'maximize', 'minimize'] = 'maximize'):
self._log_dir = log_dir
self._keep_intermediate_result = keep_intermediate_result
assert best_result_mode in ['latest', 'maximize', 'minimize'], f'Unsupported best_result_mode value: {best_result_mode}'
self._best_result_mode = best_result_mode
if origin_model is not None and origin_config_list is not None and origin_masks is not None:
self.reset(origin_model, origin_config_list, origin_masks)
......@@ -472,13 +548,24 @@ class TaskGenerator:
json_tricks.dump(config_list, f, indent=4)
def update_best_result(self, task_result: TaskResult):
score = task_result.score
task_id = task_result.task_id
task = self._tasks[task_id]
task.score = score
if self._best_score is None or (score is not None and score > self._best_score):
self._best_score = score
self._best_task_id = task_id
save_as_best_result = False
task = self._tasks[task_result.task_id]
task.score = task_result.score
if self._best_result_mode == 'latest':
self._best_task_id, save_as_best_result = task_result.task_id, True
if self._best_result_mode == 'maximize':
if self._best_score is None or (task.score is not None and task.score > self._best_score):
self._best_score = task.score
self._best_task_id, save_as_best_result = task_result.task_id, True
if self._best_result_mode == 'minimize':
if self._best_score is None or (task.score is not None and task.score < self._best_score):
self._best_score = task.score
self._best_task_id, save_as_best_result = task_result.task_id, True
if save_as_best_result:
with Path(task.config_list_path).open('r') as fr:
best_config_list = json_tricks.load(fr)
self._save_data('best_result', task_result.compact_model, task_result.compact_model_masks, best_config_list)
......
......@@ -6,13 +6,16 @@ from typing import Dict, List
from torch import Tensor
from .base import DataCollector, TrainerBasedDataCollector
from .base import DataCollector, EvaluatorBasedDataCollector
from .base import TrainerBasedDataCollector
_logger = logging.getLogger(__name__)
__all__ = ['WeightDataCollector', 'WeightTrainerBasedDataCollector', 'SingleHookTrainerBasedDataCollector']
__all__ = ['TargetDataCollector', 'EvaluatorBasedTargetDataCollector', 'EvaluatorBasedHookDataCollector',
'WeightDataCollector', 'WeightTrainerBasedDataCollector', 'SingleHookTrainerBasedDataCollector'] # TODO: remove in nni v3.0.
# TODO: remove in nni v3.0.
class WeightDataCollector(DataCollector):
"""
Collect all wrapper weights.
......@@ -21,40 +24,102 @@ class WeightDataCollector(DataCollector):
def reset(self):
pass
def collect(self) -> Dict[str, Tensor]:
def collect(self) -> Dict[str, Dict[str, Tensor]]:
data = {}
for _, wrapper in self.compressor.get_modules_wrapper().items():
data[wrapper.name] = wrapper.weight.data
target_name = 'weight'
for module_name, wrapper in self.compressor.get_modules_wrapper().items():
target: Tensor = getattr(wrapper, target_name)
data[module_name] = {target_name: target.data.clone()}
return data
# TODO: remove in nni v3.0.
class WeightTrainerBasedDataCollector(TrainerBasedDataCollector):
"""
Collect all wrapper weights after training or inference.
"""
def collect(self) -> Dict[str, Tensor]:
def collect(self) -> Dict[str, Dict[str, Tensor]]:
assert self.compressor.bound_model is not None
for _ in range(self.training_epochs):
self.trainer(self.compressor.bound_model, self.optimizer, self.criterion)
data = {}
for _, wrapper in self.compressor.get_modules_wrapper().items():
data[wrapper.name] = wrapper.weight.data
target_name = 'weight'
for module_name, wrapper in self.compressor.get_modules_wrapper().items():
target: Tensor = getattr(wrapper, target_name)
data[module_name] = {target_name: target.data.clone()}
return data
# TODO: remove in nni v3.0.
class SingleHookTrainerBasedDataCollector(TrainerBasedDataCollector):
"""
Add hooks and collect data during training or inference.
Single means each wrapper only has one hook to collect data.
"""
def collect(self) -> Dict[str, List[Tensor]]:
def collect(self) -> Dict[str, Dict[str, List[Tensor]]]:
assert self.compressor.bound_model is not None
for _ in range(self.training_epochs):
self.trainer(self.compressor.bound_model, self.optimizer, self.criterion)
data = {}
[data.update(buffer_dict) for _, buffer_dict in self._hook_buffer.items()]
target_name = 'weight'
for _, buffer_dict in self._hook_buffer.items():
for module_name, target_data in buffer_dict.items():
data[module_name] = {target_name: target_data}
return data
class TargetDataCollector(DataCollector):
"""
Collect all wrapper targets.
"""
def reset(self):
# No need to reset anything in this data collector.
pass
def collect(self) -> Dict[str, Dict[str, Tensor]]:
data = {}
target_name = 'weight'
for module_name, wrapper in self.compressor.get_modules_wrapper().items():
target: Tensor = getattr(wrapper, target_name)
data[module_name] = {target_name: target.data.clone()}
return data
class EvaluatorBasedTargetDataCollector(EvaluatorBasedDataCollector):
"""
Collect all wrapper pruning target after training or inference.
"""
def collect(self) -> Dict[str, Dict[str, Tensor]]:
assert self.compressor.bound_model is not None
self.evaluator.train(max_steps=self.max_steps, max_epochs=self.max_epochs)
data = {}
target_name = 'weight'
for module_name, wrapper in self.compressor.get_modules_wrapper().items():
target: Tensor = getattr(wrapper, target_name)
data[module_name] = {target_name: target.data.clone()}
return data
class EvaluatorBasedHookDataCollector(EvaluatorBasedDataCollector):
"""
Add hooks and collect data during training or inference.
NOTE: Only support one target has one hook right now.
"""
def collect(self) -> Dict[str, Dict[str, List]]:
assert self.compressor.bound_model is not None
self.evaluator.train(max_steps=self.max_steps, max_epochs=self.max_epochs)
data = {}
for module_name, hooks in self._hooks.items():
data[module_name] = {}
for target_name, hook in hooks.items():
data[module_name][target_name] = hook.buffer
return data
......@@ -11,7 +11,7 @@ from torch import Tensor
from .base import MetricsCalculator
from ...utils import Scaling
__all__ = ['NormMetricsCalculator', 'MultiDataNormMetricsCalculator', 'DistMetricsCalculator',
__all__ = ['NormMetricsCalculator', 'HookDataNormMetricsCalculator', 'DistMetricsCalculator',
'APoZRankMetricsCalculator', 'MeanRankMetricsCalculator', 'StraightMetricsCalculator']
......@@ -19,11 +19,12 @@ class StraightMetricsCalculator(MetricsCalculator):
"""
This metrics calculator directly returns a copy of data as metrics.
"""
def calculate_metrics(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
def calculate_metrics(self, data: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
metrics = {}
for name, tensor in data.items():
# use inplace detach `detach_` here to avoid creating a new tensor
metrics[name] = tensor.clone().detach_()
for module_name, targets_data in data.items():
metrics[module_name] = {}
for target_name, target_data in targets_data.items():
metrics[module_name][target_name] = target_data.clone().detach()
return metrics
......@@ -44,27 +45,32 @@ class NormMetricsCalculator(MetricsCalculator):
super().__init__(scalers=scalers)
self.p = p if p is not None else 'fro'
def calculate_metrics(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
def calculate_metrics(self, data: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
def reduce_func(t: Tensor) -> Tensor:
return t.norm(p=self.p, dim=-1) # type: ignore
metrics = {}
target_name = 'weight'
for module_name, target_data in data.items():
scaler = self._get_scaler(module_name, target_name)
metrics[module_name] = scaler.shrink(target_data, reduce_func)
for module_name, targets_data in data.items():
metrics[module_name] = {}
for target_name, target_data in targets_data.items():
scaler = self._get_scaler(module_name, target_name)
metrics[module_name][target_name] = scaler.shrink(target_data, reduce_func)
return metrics
class MultiDataNormMetricsCalculator(NormMetricsCalculator):
class HookDataNormMetricsCalculator(NormMetricsCalculator):
"""
The data value format is a two-element list [batch_number, cumulative_data].
The hook data value format is a two-element list [batch_number, cumulative_data].
Directly use the cumulative_data as new_data to calculate norm metric.
TaylorFO pruner uses this to calculate metric.
"""
def calculate_metrics(self, data: Dict[str, List[Tensor]]) -> Dict[str, Tensor]:
new_data = {name: buffer[1] for name, buffer in data.items()}
def calculate_metrics(self, data: Dict[str, Dict[str, List[Tensor]]]) -> Dict[str, Dict[str, Tensor]]:
new_data = {}
for module_name, targets_data in data.items():
new_data[module_name] = {}
for target_name, (_, target_data) in targets_data.items():
new_data[module_name][target_name] = target_data
return super().calculate_metrics(new_data)
......@@ -85,7 +91,7 @@ class DistMetricsCalculator(MetricsCalculator):
super().__init__(scalers=scalers)
self.p = p if p is not None else 'fro'
def calculate_metrics(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
def calculate_metrics(self, data: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
def reduce_func(t: Tensor) -> Tensor:
reshape_data = t.reshape(-1, t.shape[-1])
metric = torch.zeros(reshape_data.shape[0], device=reshape_data.device)
......@@ -94,10 +100,11 @@ class DistMetricsCalculator(MetricsCalculator):
return metric.reshape(t.shape[:-1])
metrics = {}
target_name = 'weight'
for module_name, target_data in data.items():
scaler = self._get_scaler(module_name, target_name)
metrics[module_name] = scaler.shrink(target_data, reduce_func)
for module_name, targets_data in data.items():
metrics[module_name] = {}
for target_name, target_data in targets_data.items():
scaler = self._get_scaler(module_name, target_name)
metrics[module_name][target_name] = scaler.shrink(target_data, reduce_func)
return metrics
......@@ -108,16 +115,18 @@ class APoZRankMetricsCalculator(MetricsCalculator):
Note that the metric we return is (1 - apoz), because we assume a higher metric value has higher importance.
APoZRank pruner uses this to calculate metric.
"""
def calculate_metrics(self, data: Dict[str, List]) -> Dict[str, Tensor]:
def calculate_metrics(self, data: Dict[str, Dict[str, List[Tensor]]]) -> Dict[str, Dict[str, Tensor]]:
def reduce_func(t: Tensor) -> Tensor:
return 1 - t.mean(dim=-1)
metrics = {}
target_name = 'weight'
for module_name, target_data in data.items():
target_data = target_data[1] / target_data[0]
scaler = self._get_scaler(module_name, target_name)
metrics[module_name] = scaler.shrink(target_data, reduce_func)
for module_name, targets_data in data.items():
metrics[module_name] = {}
for target_name, target_data in targets_data.items():
target_data = target_data[1] / target_data[0]
scaler = self._get_scaler(module_name, target_name)
metrics[module_name][target_name] = scaler.shrink(target_data, reduce_func)
return metrics
......@@ -127,14 +136,15 @@ class MeanRankMetricsCalculator(MetricsCalculator):
This metric simply calculate the average on `self.dim`, then divide by the batch_number.
MeanRank pruner uses this to calculate metric.
"""
def calculate_metrics(self, data: Dict[str, List]) -> Dict[str, Tensor]:
def calculate_metrics(self, data: Dict[str, Dict[str, List[Tensor]]]) -> Dict[str, Dict[str, Tensor]]:
def reduce_func(t: Tensor) -> Tensor:
return t.mean(dim=-1)
metrics = {}
target_name = 'weight'
for module_name, target_data in data.items():
target_data = target_data[1] / target_data[0]
scaler = self._get_scaler(module_name, target_name)
metrics[module_name] = scaler.shrink(target_data, reduce_func)
for module_name, targets_data in data.items():
metrics[module_name] = {}
for target_name, target_data in targets_data.items():
target_data = target_data[1] / target_data[0]
scaler = self._get_scaler(module_name, target_name)
metrics[module_name][target_name] = scaler.shrink(target_data, reduce_func)
return metrics
......@@ -17,7 +17,8 @@ _logger = logging.getLogger(__name__)
class AMCEnv:
def __init__(self, model: Module, config_list: List[Dict], dummy_input: Tensor, total_sparsity: float, max_sparsity_per_layer: Dict[str, float], target: str = 'flops'):
def __init__(self, model: Module, config_list: List[Dict], dummy_input: Tensor, total_sparsity: float,
max_sparsity_per_layer: Dict[str, float], target: str = 'flops'):
pruning_op_names = []
[pruning_op_names.extend(config['op_names']) for config in config_list_canonical(model, config_list)]
self.pruning_ops = OrderedDict()
......@@ -26,7 +27,10 @@ class AMCEnv:
if name in pruning_op_names:
op_type = type(layer).__name__
stride = np.power(np.prod(layer.stride), 1 / len(layer.stride)) if hasattr(layer, 'stride') else 0 # type: ignore
kernel_size = np.power(np.prod(layer.kernel_size), 1 / len(layer.kernel_size)) if hasattr(layer, 'kernel_size') else 1 # type: ignore
if hasattr(layer, 'kernel_size'):
kernel_size = np.power(np.prod(layer.kernel_size), 1 / len(layer.kernel_size)) # type: ignore
else:
kernel_size = 1
self.pruning_ops[name] = (i, op_type, stride, kernel_size)
self.pruning_types.append(op_type)
self.pruning_types = list(set(self.pruning_types))
......@@ -60,15 +64,18 @@ class AMCEnv:
total_current_target = sum([current_statistics[name][self.target] for name in self.pruning_op_names])
previous_pruning_target = self.under_pruning_target - total_current_target
max_rest_pruning_target = sum([current_statistics[name][self.target] * self.max_sparsity_per_layer[name] for name in self.pruning_op_names[index + 1:]])
max_rest_pruning_target = sum([current_statistics[name][self.target] * self.max_sparsity_per_layer[name] \
for name in self.pruning_op_names[index + 1:]])
min_current_pruning_target = self.excepted_pruning_target - previous_pruning_target - max_rest_pruning_target
max_current_pruning_target_1 = self.origin_statistics[op_name][self.target] * self.max_sparsity_per_layer[op_name] - (self.origin_statistics[op_name][self.target] - current_statistics[op_name][self.target])
max_current_pruning_target_1 = self.origin_statistics[op_name][self.target] * self.max_sparsity_per_layer[op_name] - \
(self.origin_statistics[op_name][self.target] - current_statistics[op_name][self.target])
max_current_pruning_target_2 = self.excepted_pruning_target - previous_pruning_target
max_current_pruning_target = min(max_current_pruning_target_1, max_current_pruning_target_2)
min_action = min_current_pruning_target / current_statistics[op_name][self.target]
max_action = max_current_pruning_target / current_statistics[op_name][self.target]
if min_action > self.max_sparsity_per_layer[op_name]:
_logger.warning('[%s] min action > max sparsity per layer: %f > %f', op_name, min_action, self.max_sparsity_per_layer[op_name])
warn_msg = f'[{op_name}] min action > max sparsity per layer: {min_action} > {self.max_sparsity_per_layer[op_name]}'
_logger.warning(warn_msg)
action = max(0., min(max_action, max(min_action, action)))
self.current_op_name = op_name
......
......@@ -51,7 +51,7 @@ class FunctionBasedTaskGenerator(TaskGenerator):
self.total_iteration = total_iteration
self.skip_first_iteration = skip_first_iteration
super().__init__(origin_model, origin_config_list=origin_config_list, origin_masks=origin_masks,
log_dir=log_dir, keep_intermediate_result=keep_intermediate_result)
log_dir=log_dir, keep_intermediate_result=keep_intermediate_result, best_result_mode='latest')
def reset(self, model: Module, config_list: List[Dict] = [], masks: Dict[str, Dict[str, Tensor]] = {}):
self.current_iteration = 1 if self.skip_first_iteration else 0
......@@ -78,10 +78,14 @@ class FunctionBasedTaskGenerator(TaskGenerator):
# get current2origin_sparsity and compact2origin_sparsity
origin_model = torch.load(self._origin_model_path)
current2origin_sparsity, compact2origin_sparsity, _ = compute_sparsity(origin_model, compact_model, compact_model_masks, self.target_sparsity)
_logger.debug('\nTask %s total real sparsity compared with original model is:\n%s', str(task_result.task_id), json_tricks.dumps(current2origin_sparsity, indent=4))
current2origin_sparsity, compact2origin_sparsity, _ = compute_sparsity(origin_model, compact_model, compact_model_masks,
self.target_sparsity)
debug_msg = f'\nTask {task_result.task_id} total real sparsity compared with original model is:\n' + \
f'{json_tricks.dumps(current2origin_sparsity, indent=4)}'
_logger.debug(debug_msg)
if task_result.task_id != 'origin':
self._tasks[task_result.task_id].state['current2origin_sparsity'] = current2origin_sparsity
task = self._tasks[task_result.task_id]
task.state['current2origin_sparsity'] = current2origin_sparsity
# if reach the total_iteration, no more task will be generated
if self.current_iteration > self.total_iteration:
......@@ -116,7 +120,8 @@ class AGPTaskGenerator(FunctionBasedTaskGenerator):
for target, mo in zip(target_sparsity, compact2origin_sparsity):
ori_sparsity = (1 - (1 - iteration / self.total_iteration) ** 3) * target['total_sparsity']
sparsity = max(0.0, (ori_sparsity - mo['total_sparsity']) / (1 - mo['total_sparsity']))
assert 0 <= sparsity <= 1, 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
err_msg = 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
assert 0 <= sparsity <= 1, err_msg
config_list.append(deepcopy(target))
config_list[-1]['total_sparsity'] = sparsity
return config_list
......@@ -128,7 +133,8 @@ class LinearTaskGenerator(FunctionBasedTaskGenerator):
for target, mo in zip(target_sparsity, compact2origin_sparsity):
ori_sparsity = iteration / self.total_iteration * target['total_sparsity']
sparsity = max(0.0, (ori_sparsity - mo['total_sparsity']) / (1 - mo['total_sparsity']))
assert 0 <= sparsity <= 1, 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
err_msg = 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
assert 0 <= sparsity <= 1, err_msg
config_list.append(deepcopy(target))
config_list[-1]['total_sparsity'] = sparsity
return config_list
......@@ -149,16 +155,18 @@ class LotteryTicketTaskGenerator(FunctionBasedTaskGenerator):
# The following is the formula in paper.
# ori_sparsity = (target['total_sparsity'] * 100) ** (iteration / self.total_iteration) / 100
sparsity = max(0.0, (ori_sparsity - mo['total_sparsity']) / (1 - mo['total_sparsity']))
assert 0 <= sparsity <= 1, 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
err_msg = 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
assert 0 <= sparsity <= 1, err_msg
config_list.append(deepcopy(target))
config_list[-1]['total_sparsity'] = sparsity
return config_list
class SimulatedAnnealingTaskGenerator(TaskGenerator):
def __init__(self, origin_model: Optional[Module], origin_config_list: Optional[List[Dict]], origin_masks: Dict[str, Dict[str, Tensor]] = {},
start_temperature: float = 100, stop_temperature: float = 20, cool_down_rate: float = 0.9,
perturbation_magnitude: float = 0.35, log_dir: Union[str, Path] = '.', keep_intermediate_result: bool = False):
def __init__(self, origin_model: Optional[Module], origin_config_list: Optional[List[Dict]],
origin_masks: Dict[str, Dict[str, Tensor]] = {}, start_temperature: float = 100, stop_temperature: float = 20,
cool_down_rate: float = 0.9, perturbation_magnitude: float = 0.35, log_dir: Union[str, Path] = '.',
keep_intermediate_result: bool = False):
"""
Parameters
----------
......@@ -188,7 +196,7 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
self.perturbation_magnitude = perturbation_magnitude
super().__init__(origin_model, origin_masks=origin_masks, origin_config_list=origin_config_list,
log_dir=log_dir, keep_intermediate_result=keep_intermediate_result)
log_dir=log_dir, keep_intermediate_result=keep_intermediate_result, best_result_mode='maximize')
def reset(self, model: Module, config_list: List[Dict] = [], masks: Dict[str, Dict[str, Tensor]] = {}):
self.current_temperature = self.start_temperature
......@@ -196,7 +204,10 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
# TODO: replace with validation here
for config in config_list:
if 'sparsity' in config or 'sparsity_per_layer' in config:
_logger.warning('Only `total_sparsity` can be differentially allocated sparse ratio to each layer, `sparsity` or `sparsity_per_layer` will allocate fixed sparse ratio to layers. Make sure you know what this will lead to, otherwise please use `total_sparsity`.')
warn_msg = 'Only `total_sparsity` can be differentially allocated sparse ratio to each layer, ' + \
'`sparsity` or `sparsity_per_layer` will allocate fixed sparse ratio to layers. ' + \
'Make sure you know what this will lead to, otherwise please use `total_sparsity`.'
_logger.warning(warn_msg)
self.weights_numel, self.masked_rate = get_model_weights_numel(model, config_list, masks)
self.target_sparsity_list = config_list_canonical(model, config_list)
......@@ -259,11 +270,11 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
num_weights = sorted([self.weights_numel[op_name] for op_name in op_names])
sparsity = sorted(random_sparsity)
# calculate the scale
total_weights = np.sum(num_weights)
total_weights_pruned = np.sum([int(num_weight * sparsity[idx]) for idx, num_weight in enumerate(num_weights)])
if total_weights_pruned == 0:
return None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment