Unverified commit ed455174 authored by J-shang, committed by GitHub

[Compression] evaluator - step2 (#4992)

parent a689e619
...@@ -5,3 +5,9 @@ Quickstart
PyTorch </tutorials/hpo_quickstart_pytorch/main>
TensorFlow </tutorials/hpo_quickstart_tensorflow/main>
.. toctree::
:hidden:
/tutorials/hpo_quickstart_pytorch/index
/tutorials/hpo_quickstart_tensorflow/index
Evaluator
=========
.. _compression-torch-evaluator:
TorchEvaluator
--------------
.. autoclass:: nni.compression.pytorch.TorchEvaluator
.. _compression-lightning-evaluator:
LightningEvaluator
------------------
.. autoclass:: nni.compression.pytorch.LightningEvaluator
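A minimal sketch of wiring up a ``TorchEvaluator`` for the pruners touched by this commit (assuming the constructor accepts a training function, a traced optimizer, a criterion and, optionally, ``dummy_input`` and ``evaluating_func`` -- check the class reference above for the authoritative signature; the model and data loader here are placeholders):

    import nni
    import torch
    import torch.nn.functional as F
    from nni.compression.pytorch import TorchEvaluator

    model = ...          # the torch.nn.Module to be compressed
    train_loader = ...   # your training dataloader

    def training_func(model, optimizers, criterion, lr_schedulers=None, max_steps=None, max_epochs=None):
        # the compressor decides how long to train by passing max_steps / max_epochs
        model.train()
        steps = 0
        for _ in range(max_epochs or 1):
            for data, target in train_loader:
                optimizers.zero_grad()
                criterion(model(data), target).backward()
                optimizers.step()
                steps += 1
                if max_steps is not None and steps >= max_steps:
                    return

    def evaluating_func(model):
        # return a single scalar score, e.g. validation accuracy
        ...

    # the optimizer class must be wrapped by nni.trace so the evaluator can rebuild it on a new model
    traced_optimizer = nni.trace(torch.optim.Adam)(model.parameters(), lr=1e-3)
    evaluator = TorchEvaluator(training_func, traced_optimizer, F.cross_entropy,
                               dummy_input=torch.rand(8, 3, 224, 224),
                               evaluating_func=evaluating_func)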
...@@ -8,5 +8,6 @@ Compression API Reference
Quantizer <quantizer>
Pruning Speedup <pruning_speedup>
Quantization Speedup <quantization_speedup>
Evaluator <evaluator>
Compression Utilities <utils>
Framework Related <framework>
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .utils import LightningEvaluator, TorchEvaluator
...@@ -119,7 +119,8 @@ class Compressor:
Detect all modules should be compressed, and save the result in `self._modules_to_compress`.
The model will be instrumented and user should never edit it after calling this method.
"""
err_msg = 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.bound_model is not None, err_msg
if self._modules_to_compress is None:
self._modules_to_compress = []
...@@ -146,7 +147,8 @@ class Compressor:
Optional[Dict]
The retrieved configuration for this layer, if None, this layer should not be compressed.
"""
err_msg = 'No config_list set in this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.config_list is not None, err_msg
ret = None
for config in self.config_list:
...@@ -240,8 +242,10 @@ class Compressor:
Dict[int, List[str]]
A dict. The key is the config idx in config_list, the value is the module name list. i.e., {1: ['layer.0', 'layer.2']}.
"""
err_msg = 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.bound_model is not None, err_msg
err_msg = 'No config_list set in this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.config_list is not None, err_msg
self._unwrap_model()
module_groups = {}
...@@ -323,6 +327,8 @@ class Compressor:
torch.nn.Module
model with specified modules compressed.
"""
err_msg = 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.bound_model is not None, err_msg
err_msg = 'No config_list set in this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.config_list is not None, err_msg
return self.bound_model
...@@ -43,8 +43,8 @@ class PrunerModuleWrapper(Module):
pruning_target_mask_name = '{}_mask'.format(pruning_target_name)
pruning_target = getattr(self.module, pruning_target_name, None)
if hasattr(self.module, pruning_target_name) and pruning_target is not None:
setattr(self, pruning_target_name, Parameter(torch.empty(pruning_target.shape)))
self.register_buffer(pruning_target_mask_name, torch.ones(pruning_target.shape))
setattr(self, pruning_target_name, Parameter(torch.empty_like(pruning_target)))
self.register_buffer(pruning_target_mask_name, torch.ones_like(pruning_target))
else:
self.register_buffer(pruning_target_mask_name, None)
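The switch from ``torch.empty``/``torch.ones`` on a shape to the ``*_like`` variants presumably keeps the new parameter and mask on the same dtype and device as the original pruning target. A small standalone illustration (not part of the patch):

    import torch

    w = torch.zeros(2, 3, dtype=torch.float16)
    mask_old = torch.ones(w.shape)   # defaults to float32 on the CPU, regardless of w
    mask_new = torch.ones_like(w)    # inherits w's dtype (and device, if w lived on GPU)
    print(mask_old.dtype, mask_new.dtype)  # torch.float32 torch.float16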
...@@ -67,11 +67,11 @@ class PrunerModuleWrapper(Module):
The best place to call this function is in `Pruner._unwrap_model()`.
"""
delattr(self.module, 'weight')
self.module.weight = Parameter(torch.empty(self.weight.size()))
self.module.weight = Parameter(torch.empty_like(self.weight))
self.module.weight.data = torch.mul(self.weight, self.weight_mask)
if hasattr(self.module, 'bias') and self.module.bias is not None:
delattr(self.module, 'bias')
self.module.bias = Parameter(torch.empty(self.bias.size()))
self.module.bias = Parameter(torch.empty_like(self.bias))
self.module.bias.data = torch.mul(self.bias, self.bias_mask)
def forward(self, *inputs):
...@@ -130,7 +130,8 @@ class Pruner(Compressor):
Wrap all modules that needed to be compressed.
Different from the parent function, call `wrapper._weight2buffer()` after replace the origin module to wrapper.
"""
err_msg = 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.bound_model is not None, err_msg
if not self.is_wrapped:
for _, wrapper in reversed(list(self.get_modules_wrapper().items())):
...@@ -143,7 +144,8 @@ class Pruner(Compressor):
Unwrap all modules that needed to be compressed.
Different from the parent function, call `wrapper._weight2parameter()` after replace the wrapper to origin module.
"""
err_msg = 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
assert self.bound_model is not None, err_msg
if self.is_wrapped:
for wrapper in self.get_modules_wrapper().values():
...@@ -165,8 +167,10 @@ class Pruner(Compressor):
self._unwrap_model()
parameter_name_map = {}
for name, param in self.bound_model.named_parameters():
# If the parameter name in under wrapped module is `xxx.weight` or `xxx.bias`,
# the name will not change after wrap.
# If the parameter name in under wrapped module is others,
# the name `xxx.param` will change to `xxx.module.param` after wrap.
parameter_name_map[name] = wrapped_param_names[id(param)] if id(param) in wrapped_param_names else name
self._wrap_model()
return parameter_name_map
...@@ -183,14 +187,12 @@ class Pruner(Compressor):
The masks dict with format {'op_name': {'weight': mask, 'bias': mask}}.
"""
wrappers = self.get_modules_wrapper()
for name, layer_mask in masks.items():
assert name in wrappers, '{} is not in wrappers of this pruner, can not apply the mask.'.format(name)
if layer_mask.get('weight') is not None:
assert hasattr(wrappers[name], 'weight_mask'), 'There is no attribute weight_mask in wrapper.'
setattr(wrappers[name], 'weight_mask', layer_mask.get('weight'))
if layer_mask.get('bias') is not None:
assert hasattr(wrappers[name], 'bias_mask'), 'There is no attribute bias_mask in wrapper.'
setattr(wrappers[name], 'bias_mask', layer_mask.get('bias'))
for module_name, target_masks in masks.items():
assert module_name in wrappers, '{} is not in wrappers of this pruner, can not apply the mask.'.format(module_name)
for target_name, target_mask in target_masks.items():
assert hasattr(wrappers[module_name], f'{target_name}_mask'), f'There is no attribute {target_name}_mask in wrapper.'
target: Tensor = getattr(self.get_modules_wrapper()[module_name], target_name)
setattr(wrappers[module_name], f'{target_name}_mask', target_mask.to(target.device))
def compress(self) -> Tuple[Module, Dict[str, Dict[str, Tensor]]]:
"""
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from copy import deepcopy
from pathlib import Path
from typing import Dict, List, Callable, Optional, cast, overload
import json_tricks
import torch
...@@ -11,12 +13,13 @@ from torch import Tensor
from torch.nn import Module
from nni.algorithms.compression.v2.pytorch.base import Task, TaskResult
from nni.algorithms.compression.v2.pytorch.utils import compute_sparsity, config_list_canonical
from nni.compression.pytorch.utils import count_flops_params
from .iterative_pruner import IterativePruner, PRUNER_DICT
from .tools import TaskGenerator
from .tools.rl_env import DDPG, AMCEnv
from ..utils import LightningEvaluator, TorchEvaluator, compute_sparsity, config_list_canonical
from ..utils.docstring import _EVALUATOR_DOCSTRING
class AMCTaskGenerator(TaskGenerator):
...@@ -41,8 +44,8 @@ class AMCTaskGenerator(TaskGenerator):
ddpg_params
The ddpg agent parameters.
target : str
'flops' or 'params'. Note that the sparsity in other pruners always means the parameters sparse,
but in AMC, you can choose flops sparse. This parameter is used to explain what the sparsity setting in config_list refers to.
"""
def __init__(self, total_episode: int, dummy_input: Tensor, origin_model: Module, origin_config_list: List[Dict],
...@@ -56,7 +59,7 @@ class AMCTaskGenerator(TaskGenerator):
self.config_list_copy = deepcopy(origin_config_list)
super().__init__(origin_model=origin_model, origin_masks=origin_masks, origin_config_list=origin_config_list,
log_dir=log_dir, keep_intermediate_result=keep_intermediate_result, best_result_mode='maximize')
def init_pending_tasks(self) -> List[Task]:
origin_model = torch.load(self._origin_model_path)
...@@ -82,6 +85,8 @@ class AMCTaskGenerator(TaskGenerator):
return self.generate_tasks(task_result)
def generate_tasks(self, task_result: TaskResult) -> List[Task]:
self.temp_config_list = self.temp_config_list if hasattr(self, 'temp_config_list') else []
# append experience & update agent policy
if self.action is not None:
action, reward, observation, done = self.env.step(self.action, task_result.compact_model)
...@@ -106,7 +111,8 @@ class AMCTaskGenerator(TaskGenerator):
origin_model = torch.load(self._origin_model_path)
compact_model = task_result.compact_model
compact_model_masks = task_result.compact_model_masks
current2origin_sparsity, _, _ = compute_sparsity(origin_model, compact_model, compact_model_masks,
self.temp_config_list)  # type: ignore
self._tasks[task_result.task_id].state['current2origin_sparsity'] = current2origin_sparsity
current2origin_sparsity, _, _ = compute_sparsity(origin_model, compact_model, compact_model_masks, self.config_list_copy)
self._tasks[task_result.task_id].state['current_total_sparsity'] = current2origin_sparsity
...@@ -162,7 +168,7 @@ class AMCTaskGenerator(TaskGenerator):
class AMCPruner(IterativePruner):
__doc__ = r"""
AMC pruner leverages reinforcement learning to provide the model compression policy.
According to the author, this learning-based compression policy outperforms conventional rule-based compression policy by having a higher compression ratio,
better preserving the accuracy and freeing human labor.
...@@ -186,10 +192,11 @@ class AMCPruner(IterativePruner):
- op_names : Operation name to be pruned.
- op_partial_names: Operation partial names to be pruned, will be autocompleted by NNI.
- exclude : Set True then the layers setting by op_types and op_names will be excluded from pruning.
dummy_input : torch.Tensor
`dummy_input` is required for speedup and tracing the model in RL environment.
evaluator : Callable[[Module], float]
Evaluate the pruned model and give a score.
evaluator
``evaluator`` is used to replace the previous ``finetuner``, ``dummy_input`` and old ``evaluator`` API.
{evaluator_docstring}
The old API (``finetuner``, ``dummy_input`` and old ``evaluator``) is still supported and will be deprecated in v3.0.
If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__.
pruning_algorithm : str
Supported pruning algorithm ['l1', 'l2', 'fpgm', 'apoz', 'mean_activation', 'taylorfo'].
This iterative pruner will use the chosen corresponding pruner to prune the model in each iteration.
...@@ -197,8 +204,6 @@ class AMCPruner(IterativePruner):
The log directory use to saving the result, you can find the best result under this folder.
keep_intermediate_result : bool
If keeping the intermediate result, including intermediate model and masks during each iteration.
finetuner : Optional[Callable[[Module], None]]
The finetuner handled all finetune logic, use a pytorch module as input, will be called in each iteration.
ddpg_params : Dict
Configuration dict to configure the DDPG agent, any key unset will be set to default implicitly.
- hidden1: hidden num of first fully connect layer. Default: 300
...@@ -223,23 +228,42 @@ class AMCPruner(IterativePruner):
'flops' or 'params'. Note that the sparsity in other pruners always means the parameters sparse, but in AMC, you can choose flops sparse.
This parameter is used to explain what the sparsity setting in config_list refers to.
Examples
--------
>>> from nni.compression.pytorch.pruning import AMCPruner
>>> config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.5, 'max_sparsity_per_layer': 0.8}]
>>> dummy_input = torch.rand(...).to(device)
>>> evaluator = ...
>>> finetuner = ...
>>> pruner = AMCPruner(400, model, config_list, dummy_input, evaluator, finetuner=finetuner)
>>> pruner.compress()
Notes
-----
The full script can be found :githublink:`here <examples/model_compress/pruning/amc_pruning_torch.py>`.
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
@overload
def __init__(self, total_episode: int, model: Module, config_list: List[Dict], evaluator: LightningEvaluator | TorchEvaluator,
pruning_algorithm: str = 'l1', log_dir: str = '.', keep_intermediate_result: bool = False,
ddpg_params: dict = {}, pruning_params: dict = {}, target: str = 'flops'):
...
@overload
def __init__(self, total_episode: int, model: Module, config_list: List[Dict], dummy_input: Tensor,
evaluator: Callable[[Module], float], pruning_algorithm: str = 'l1', log_dir: str = '.',
keep_intermediate_result: bool = False, finetuner: Optional[Callable[[Module], None]] = None,
ddpg_params: dict = {}, pruning_params: dict = {}, target: str = 'flops'):
...
def __init__(self, total_episode: int, model: Module, config_list: List[Dict], *args, **kwargs):
new_api = ['evaluator', 'pruning_algorithm', 'log_dir', 'keep_intermediate_result', 'ddpg_params', 'pruning_params', 'target']
new_init_kwargs = {'pruning_algorithm': 'l1', 'log_dir': '.', 'keep_intermediate_result': False,
'ddpg_params': {}, 'pruning_params': {}, 'target': 'flops'}
old_api = ['dummy_input', 'evaluator', 'pruning_algorithm', 'log_dir', 'keep_intermediate_result', 'finetuner', 'ddpg_params',
'pruning_params', 'target']
old_init_kwargs = {'pruning_algorithm': 'l1', 'log_dir': '.', 'keep_intermediate_result': False, 'finetuner': None,
'ddpg_params': {}, 'pruning_params': {}, 'target': 'flops'}
init_kwargs = self._init_evaluator(model, new_api, new_init_kwargs, old_api, old_init_kwargs, args, kwargs)
pruning_algorithm = init_kwargs['pruning_algorithm']
log_dir = init_kwargs['log_dir']
keep_intermediate_result = init_kwargs['keep_intermediate_result']
ddpg_params = init_kwargs['ddpg_params']
pruning_params = init_kwargs['pruning_params']
target = init_kwargs['target']
dummy_input = self.dummy_input if not self.using_evaluator else self.evaluator.get_dummy_input()
assert pruning_algorithm in ['l1', 'l2', 'fpgm', 'apoz', 'mean_activation', 'taylorfo'], \
"Only support pruning_algorithm in ['l1', 'l2', 'fpgm', 'apoz', 'mean_activation', 'taylorfo']"
task_generator = AMCTaskGenerator(total_episode=total_episode,
...@@ -251,5 +275,9 @@ class AMCPruner(IterativePruner):
ddpg_params=ddpg_params,
target=target)
pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
if self.using_evaluator:
super().__init__(pruner, task_generator, evaluator=self.evaluator, speedup=True, reset_weight=False)
else:
super().__init__(pruner, task_generator, finetuner=self.finetuner, speedup=True, dummy_input=self.dummy_input,
evaluator=self._evaluator, reset_weight=False)  # type: ignore
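With this change, a call in the style of the removed docstring example would pass the evaluator instead of ``dummy_input``/``finetuner``. A sketch, assuming ``evaluator`` is a ``TorchEvaluator`` or ``LightningEvaluator`` that carries a dummy input (needed because the pruner calls ``evaluator.get_dummy_input()`` internally) and ``model`` is the network to prune:

>>> from nni.compression.pytorch.pruning import AMCPruner
>>> config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.5, 'max_sparsity_per_layer': 0.8}]
>>> evaluator = ...  # see the TorchEvaluator sketch earlier in this commit
>>> pruner = AMCPruner(400, model, config_list, evaluator)
>>> pruner.compress()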
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import logging
from pathlib import Path
from typing import Dict, List, Callable, Optional, overload
from torch import Tensor
from torch.nn import Module
from nni.algorithms.compression.v2.pytorch.utils import OptimizerConstructHelper
from .basic_pruner import ADMMPruner
from .iterative_pruner import IterativePruner, SimulatedAnnealingPruner
from .tools import LotteryTicketTaskGenerator
from ..utils import LightningEvaluator, TorchEvaluator, OptimizerConstructHelper
from ..utils.docstring import _EVALUATOR_DOCSTRING
_logger = logging.getLogger(__name__)
...@@ -21,10 +23,7 @@ class AutoCompressTaskGenerator(LotteryTicketTaskGenerator):
def __init__(self, total_iteration: int, origin_model: Module, origin_config_list: List[Dict],
origin_masks: Dict[str, Dict[str, Tensor]] = {}, sa_params: Dict = {}, log_dir: str = '.',
keep_intermediate_result: bool = False):
self.iterative_pruner = SimulatedAnnealingPruner(model=None,
config_list=None,
log_dir=Path(log_dir, 'SA'),
**sa_params)
self._sa_params = sa_params
super().__init__(total_iteration=total_iteration,
origin_model=origin_model,
origin_config_list=origin_config_list,
...@@ -36,12 +35,20 @@ class AutoCompressTaskGenerator(LotteryTicketTaskGenerator):
# TODO: replace with validation here
for config in config_list:
if 'sparsity' in config or 'sparsity_per_layer' in config:
warn_msg = 'Only `total_sparsity` can be differentially allocated sparse ratio to each layer, ' + \
'`sparsity` or `sparsity_per_layer` will allocate fixed sparse ratio to layers. ' + \
'Make sure you know what this will lead to, otherwise please use `total_sparsity`.'
_logger.warning(warn_msg)
return super().reset(model, config_list, masks)
def _iterative_pruner_reset(self, model: Module, config_list: List[Dict] = [], masks: Dict[str, Dict[str, Tensor]] = {}):
self.iterative_pruner.task_generator._log_dir = Path(self._log_dir_root, 'SA')
self.iterative_pruner.reset(model, config_list=config_list, masks=masks)
if not hasattr(self, 'iterative_pruner'):
self.iterative_pruner = SimulatedAnnealingPruner(model=model,
config_list=config_list,
log_dir=Path(self._log_dir_root, 'SA'),
**self._sa_params)
else:
self.iterative_pruner.reset(model, config_list=config_list, masks=masks)
def allocate_sparsity(self, new_config_list: List[Dict], model: Module, masks: Dict[str, Dict[str, Tensor]]):
self._iterative_pruner_reset(model, new_config_list, masks)
...@@ -53,8 +60,9 @@ class AutoCompressTaskGenerator(LotteryTicketTaskGenerator):
class AutoCompressPruner(IterativePruner):
__doc__ = r"""
For total iteration number :math:`N`, AutoCompressPruner prune the model that survive the previous iteration for a fixed sparsity ratio (e.g., :math:`1-{(1-0.8)}^{(1/N)}`) to achieve the overall sparsity (e.g., :math:`0.8`):
""" + r"""
.. code-block:: bash
...@@ -65,35 +73,27 @@ class AutoCompressPruner(IterativePruner):
Parameters
----------
model : Module
The origin unwrapped pytorch model to be pruned.
config_list : List[Dict]
The origin config list provided by the user.
total_iteration : int
The total iteration number.
evaluator : Callable[[Module], float]
Evaluate the pruned model and give a score.
admm_params : Dict
The parameters passed to the ADMMPruner.
- trainer : Callable[[Module, Optimizer, Callable].
A callable function used to train model or just inference. Take model, optimizer, criterion as input.
The model will be trained or inferenced `training_epochs` epochs.
- traced_optimizer : nni.common.serializer.Traceable(torch.optim.Optimizer)
The traced optimizer instance which the optimizer class is wrapped by nni.trace.
E.g. ``traced_optimizer = nni.trace(torch.nn.Adam)(model.parameters())``.
- criterion : Callable[[Tensor, Tensor], Tensor].
The criterion function used in trainer. Take model output and target value as input, and return the loss.
- evaluator : LightningEvaluator or TorchEvaluator.
The same with the evaluator of AutoCompressPruner input parameter.
- iterations : int.
The total iteration number in admm pruning algorithm.
- training_epochs : int.
The epoch number for training model in each iteration.
sa_params : Dict
The parameters passed to the SimulatedAnnealingPruner.
- evaluator : Callable[[Module], float]. Required.
Evaluate the pruned model and give a score.
- evaluator : LightningEvaluator or TorchEvaluator.
The same with the evaluator of AutoCompressPruner input parameter.
- start_temperature : float. Default: `100`.
Start temperature of the simulated annealing process.
- stop_temperature : float. Default: `20`.
...@@ -104,54 +104,50 @@ class AutoCompressPruner(IterativePruner):
Initial perturbation magnitude to the sparsities. The magnitude decreases with current temperature.
- pruning_algorithm : str. Default: `'level'`.
Supported pruning algorithm ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz', 'mean_activation', 'taylorfo', 'admm'].
- pruning_params : Dict. Default: dict().
If the chosen pruning_algorithm has extra parameters, put them as a dict to pass in.
log_dir : str
The log directory used to save the result, you can find the best result under this folder.
keep_intermediate_result : bool
If keeping the intermediate result, including intermediate model and masks during each iteration.
finetuner : Optional[Callable[[Module], None]]
The finetuner handles all finetune logic, takes a pytorch module as input.
It will be called at the end of each iteration, usually for neutralizing the accuracy loss brought by the pruning in this iteration.
evaluator
``evaluator`` is used to replace the previous ``finetuner``, ``dummy_input`` and old ``evaluator`` API.
{evaluator_docstring}
The old API (``finetuner``, ``dummy_input`` and old ``evaluator``) is still supported and will be deprecated in v3.0.
If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__.
speedup : bool
If set True, speedup the model at the end of each iteration to make the pruned model compact.
dummy_input : Optional[torch.Tensor]
If `speedup` is True, `dummy_input` is required for tracing the model in speedup.
Examples
--------
>>> import nni
>>> from nni.compression.pytorch.pruning import AutoCompressPruner
>>> model = ...
>>> config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
>>> # make sure you have used nni.trace to wrap the optimizer class before initialize
>>> traced_optimizer = nni.trace(torch.optim.Adam)(model.parameters())
>>> trainer = ...
>>> criterion = ...
>>> evaluator = ...
>>> finetuner = ...
>>> admm_params = {
>>> 'trainer': trainer,
>>> 'traced_optimizer': traced_optimizer,
>>> 'criterion': criterion,
>>> 'iterations': 10,
>>> 'training_epochs': 1
>>> }
>>> sa_params = {
>>> 'evaluator': evaluator
>>> }
>>> pruner = AutoCompressPruner(model, config_list, 10, admm_params, sa_params, finetuner=finetuner)
>>> pruner.compress()
>>> _, model, masks, _, _ = pruner.get_best_result()
Notes
-----
The full script can be found :githublink:`here <examples/model_compress/pruning/auto_compress_pruner.py>`.
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
@overload
def __init__(self, model: Module, config_list: List[Dict], total_iteration: int, admm_params: Dict,
sa_params: Dict, log_dir: str = '.', keep_intermediate_result: bool = False,
evaluator: LightningEvaluator | TorchEvaluator | None = None, speedup: bool = False):
...
@overload
def __init__(self, model: Module, config_list: List[Dict], total_iteration: int, admm_params: Dict,
sa_params: Dict, log_dir: str = '.', keep_intermediate_result: bool = False,
finetuner: Optional[Callable[[Module], None]] = None, speedup: bool = False,
dummy_input: Optional[Tensor] = None, evaluator: Optional[Callable[[Module], float]] = None):
...
def __init__(self, model: Module, config_list: List[Dict], total_iteration: int, admm_params: Dict,
sa_params: Dict, log_dir: str = '.', keep_intermediate_result: bool = False,
*args, **kwargs):
new_api = ['evaluator', 'speedup']
new_init_kwargs = {'evaluator': None, 'speedup': False}
old_api = ['finetuner', 'speedup', 'dummy_input', 'evaluator']
old_init_kwargs = {'finetuner': None, 'evaluator': None, 'dummy_input': None, 'speedup': False}
init_kwargs = self._init_evaluator(model, new_api, new_init_kwargs, old_api, old_init_kwargs, args, kwargs)
speedup = init_kwargs['speedup']
task_generator = AutoCompressTaskGenerator(total_iteration=total_iteration,
origin_model=model,
origin_config_list=config_list,
...@@ -175,6 +171,10 @@ class AutoCompressPruner(IterativePruner):
else:
admm_params['granularity'] = 'fine-grained'
pruner = ADMMPruner(None, None, **admm_params)  # type: ignore
if self.using_evaluator:
super().__init__(pruner, task_generator, evaluator=self.evaluator, speedup=speedup, reset_weight=False)
else:
super().__init__(pruner, task_generator, finetuner=self.finetuner, speedup=speedup, dummy_input=self.dummy_input,
evaluator=self._evaluator, reset_weight=False)  # type: ignore
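Analogously to the removed docstring example, a new-style invocation might look like this (a sketch: ``evaluator`` is a ``TorchEvaluator``/``LightningEvaluator``, ``model`` is the network to prune, and per the updated docstring the ADMM and SimulatedAnnealing sub-pruners now also take an ``evaluator`` entry in their parameter dicts):

>>> from nni.compression.pytorch.pruning import AutoCompressPruner
>>> config_list = [{'total_sparsity': 0.8, 'op_types': ['Conv2d']}]
>>> admm_params = {'evaluator': evaluator, 'iterations': 10, 'training_epochs': 1}
>>> sa_params = {'evaluator': evaluator}
>>> pruner = AutoCompressPruner(model, config_list, 10, admm_params, sa_params, evaluator=evaluator, speedup=True)
>>> pruner.compress()
>>> _, model, masks, _, _ = pruner.get_best_result()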
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from copy import deepcopy
import logging
from typing import Any, Dict, List, Tuple, Callable, Optional, Union, overload
import torch
from torch import Tensor
...@@ -12,9 +15,63 @@ from nni.algorithms.compression.v2.pytorch.base import Pruner, BasePruningSchedu
from nni.compression.pytorch.speedup import ModelSpeedup
from .tools import TaskGenerator
from ..utils import Evaluator, LightningEvaluator, TorchEvaluator
_logger = logging.getLogger(__name__)
_LEGACY_FINETUNER = Callable[[Module], None]
_LEGACY_EVALUATOR = Callable[[Module], float]
class PruningScheduler(BasePruningScheduler):
# TODO: remove in nni v3.0.
class EvaluatorBasedPruningScheduler(BasePruningScheduler):
evaluator: LightningEvaluator | TorchEvaluator
using_evaluator: bool
finetuner: _LEGACY_FINETUNER
_evaluator: _LEGACY_EVALUATOR
dummy_input: Any
def _init_evaluator(self, model: Module, new_api: List[str], new_init_kwargs: Dict, old_api: List[str],
old_init_kwargs: Dict, args: Tuple, kwargs: Dict) -> Dict:
# for fake __init__ overload, parsing args and kwargs,
# initializing evaluator or [finetuner, evaluator, dummy_input], return the remaining arguments.
if (len(args) > 0 and isinstance(args[0], Evaluator)) or \
(len(args) == 0 and isinstance(kwargs.get('evaluator', None), Evaluator)):
init_kwargs = self._parse_args(new_api, args, kwargs, new_init_kwargs)
self.evaluator: LightningEvaluator | TorchEvaluator = init_kwargs.pop('evaluator')
if not self.evaluator._initialization_complete:
self.evaluator._init_optimizer_helpers(model) # type: ignore
self.using_evaluator = True
else:
init_kwargs = self._parse_args(old_api, args, kwargs, old_init_kwargs)
self.finetuner: _LEGACY_FINETUNER = init_kwargs.pop('finetuner')
self._evaluator: _LEGACY_EVALUATOR = init_kwargs.pop('evaluator')
self.dummy_input = init_kwargs.pop('dummy_input')
self.using_evaluator = False
warn_msg = f'The old API ...{",".join(old_api)} will be deprecated after NNI v3.0,' +\
f'please using the new one ...{",".join(new_api)}'
_logger.warning(warn_msg)
return init_kwargs
def _parse_args(self, arg_names: List, args: Tuple, kwargs: Dict, def_kwargs: Dict) -> Dict:
merged_kwargs = {arg_names[idx]: arg for idx, arg in enumerate(args)}
for key, value in kwargs.items():
if key in merged_kwargs:
raise TypeError(f"{self.__class__.__name__}.__init__() got multiple values for argument '{key}'")
merged_kwargs[key] = value
for key, value in def_kwargs.items():
if key not in merged_kwargs:
merged_kwargs[key] = value
diff = set(arg_names).difference(merged_kwargs.keys())
if diff:
raise TypeError(f"{self.__class__.__name__}.__init__() missing {len(diff)} required positional argument: {diff}")
diff = set(merged_kwargs.keys()).difference(arg_names)
if diff:
raise TypeError(f"{self.__class__.__name__}.__init__() got {len(diff)} unexpected keyword argument: {diff}")
return merged_kwargs
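To make the fake-overload dispatch concrete, here is a small self-contained sketch (the function and variable names are illustrative, not NNI API) of how this kind of argument parsing maps positional and keyword arguments onto a named argument list and fills in defaults, mirroring `_parse_args` above:

    def parse_args(arg_names, args, kwargs, defaults):
        # pair positional args with their names, then merge keyword args and defaults
        merged = {arg_names[i]: a for i, a in enumerate(args)}
        for key, value in kwargs.items():
            if key in merged:
                raise TypeError(f"got multiple values for argument '{key}'")
            merged[key] = value
        for key, value in defaults.items():
            merged.setdefault(key, value)
        missing = set(arg_names) - merged.keys()
        if missing:
            raise TypeError(f"missing required arguments: {missing}")
        return merged

    # new-API style call: the first positional argument is the evaluator object
    print(parse_args(['evaluator', 'speedup', 'reset_weight'],
                     ('my_evaluator',), {'speedup': True},
                     {'evaluator': None, 'speedup': False, 'reset_weight': False}))
    # -> {'evaluator': 'my_evaluator', 'speedup': True, 'reset_weight': False}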
class PruningScheduler(EvaluatorBasedPruningScheduler):
""" """
Parameters Parameters
---------- ----------
...@@ -25,7 +82,8 @@ class PruningScheduler(BasePruningScheduler): ...@@ -25,7 +82,8 @@ class PruningScheduler(BasePruningScheduler):
Used to generate task for each iteration. Used to generate task for each iteration.
finetuner finetuner
The finetuner handled all finetune logic, use a pytorch module as input. The finetuner handled all finetune logic, use a pytorch module as input.
It will be called at the end of each iteration if reset_weight is False, will be called at the beginning of each iteration otherwise. It will be called at the end of each iteration if reset_weight is False,
will be called at the beginning of each iteration otherwise.
speedup speedup
If set True, speedup the model at the end of each iteration to make the pruned model compact. If set True, speedup the model at the end of each iteration to make the pruned model compact.
dummy_input dummy_input
...@@ -36,16 +94,30 @@ class PruningScheduler(BasePruningScheduler): ...@@ -36,16 +94,30 @@ class PruningScheduler(BasePruningScheduler):
reset_weight reset_weight
If set True, the model weight will reset to the origin model weight at the end of each iteration step. If set True, the model weight will reset to the origin model weight at the end of each iteration step.
""" """
def __init__(self, pruner: Pruner, task_generator: TaskGenerator, finetuner: Optional[Callable[[Module], None]] = None,
speedup: bool = False, dummy_input: Optional[Tensor] = None, evaluator: Optional[Callable[[Module], float]] = None,
reset_weight: bool = False):
self.pruner = pruner
self.task_generator = task_generator
self.finetuner = finetuner
self.speedup = speedup
self.dummy_input = dummy_input
self.evaluator = evaluator
self.reset_weight = reset_weight
@overload
def __init__(self, pruner: Pruner, task_generator: TaskGenerator, evaluator: LightningEvaluator | TorchEvaluator,
speedup: bool = False, reset_weight: bool = False):
...
@overload
def __init__(self, pruner: Pruner, task_generator: TaskGenerator, finetuner: _LEGACY_FINETUNER | None = None,
speedup: bool = False, dummy_input: Optional[Tensor] = None, evaluator: _LEGACY_EVALUATOR | None = None,
reset_weight: bool = False):
...
def __init__(self, pruner: Pruner, task_generator: TaskGenerator, *args, **kwargs) -> None:
# TODO: remove in nni v3.0. Fake overload.
new_api = ['evaluator', 'speedup', 'reset_weight']
new_init_kwargs = {'evaluator': None, 'speedup': False, 'reset_weight': False}
old_api = ['finetuner', 'speedup', 'dummy_input', 'evaluator', 'reset_weight']
old_init_kwargs = {'finetuner': None, 'evaluator': None, 'dummy_input': None, 'speedup': False, 'reset_weight': False}
init_kwargs = self._init_evaluator(None, new_api, new_init_kwargs, old_api, old_init_kwargs, args, kwargs)  # type: ignore
self.pruner = pruner
self.task_generator = task_generator
self.speedup = init_kwargs['speedup']
self.reset_weight = init_kwargs['reset_weight']
def reset(self, model: Module, config_list: List[Dict], masks: Dict[str, Dict[str, Tensor]] = {}):
self.task_generator.reset(model, config_list, masks)
...@@ -61,6 +133,7 @@ class PruningScheduler(BasePruningScheduler):
generate masks -> speedup -> finetune -> evaluate
"""
model, masks, config_list = task.load_data()
self.pruner.reset(model, config_list)
self.pruner.load_masks(masks)
...@@ -74,28 +147,58 @@ class PruningScheduler(BasePruningScheduler):
# speedup
if self.speedup and task.speedup:
if self.using_evaluator:
ModelSpeedup(compact_model, self.evaluator.get_dummy_input(), pruner_generated_masks).speedup_model()
compact_model_masks = {}
else:
ModelSpeedup(compact_model, self.dummy_input, pruner_generated_masks).speedup_model()
compact_model_masks = {}
# finetune
if self.using_evaluator:
if task.finetune:
self.evaluator.bind_model(compact_model)  # type: ignore
if self.speedup:
self.evaluator.finetune()
else:
self.pruner._wrap_model()
self.evaluator.finetune()
self.pruner._unwrap_model()
self.evaluator.unbind_model()
else:
if self.finetuner is not None and task.finetune:
if self.speedup:
self.finetuner(compact_model)
else:
self.pruner._wrap_model()
self.finetuner(compact_model)
self.pruner._unwrap_model()
# evaluate
if self.using_evaluator:
if task.evaluate:
self.evaluator.bind_model(compact_model)  # type: ignore
# TODO: support saving customized score
if self.speedup:
score = self.evaluator.evaluate()
else:
self.pruner._wrap_model()
score = self.evaluator.evaluate()
self.pruner._unwrap_model()
score = score[0] if isinstance(score, tuple) else score
self.evaluator.unbind_model()
else:
score = None
else:
if self._evaluator is not None and task.evaluate:
if self.speedup:
score = self._evaluator(compact_model)  # type: ignore
else:
self.pruner._wrap_model()
score = self._evaluator(compact_model)  # type: ignore
self.pruner._unwrap_model()
else:
score = None
# clear model references
self.pruner.clear_model_references()
...@@ -107,13 +210,20 @@ class PruningScheduler(BasePruningScheduler):
finetune -> generate masks -> reset weight -> speedup -> evaluate
"""
model, masks, config_list = task.load_data()
checkpoint = deepcopy(model.state_dict())
self.pruner.reset(model, config_list)
self.pruner.load_masks(masks)
# finetune
if self.using_evaluator:
if task.finetune:
self.evaluator.bind_model(model)  # type: ignore
self.evaluator.finetune()
self.evaluator.unbind_model()
else:
if self.finetuner is not None and task.finetune:
self.finetuner(model)
# pruning model
compact_model, pruner_generated_masks = self.pruner.compress()
...@@ -128,19 +238,38 @@ class PruningScheduler(BasePruningScheduler):
# speedup
if self.speedup and task.speedup:
if self.using_evaluator:
ModelSpeedup(compact_model, self.evaluator.get_dummy_input(), pruner_generated_masks).speedup_model()
compact_model_masks = {}
else:
ModelSpeedup(compact_model, self.dummy_input, pruner_generated_masks).speedup_model()
compact_model_masks = {}
# evaluate
if self.using_evaluator:
if task.evaluate:
self.evaluator.bind_model(compact_model)  # type: ignore
# TODO: support saving customized score
if self.speedup:
score = self.evaluator.evaluate()
else:
self.pruner._wrap_model()
score = self.evaluator.evaluate()
self.pruner._unwrap_model()
score = score[0] if isinstance(score, tuple) else score
self.evaluator.unbind_model()
else:
score = None
else:
if self._evaluator is not None and task.evaluate:
if self.speedup:
score = self._evaluator(compact_model)  # type: ignore
else:
self.pruner._wrap_model()
score = self._evaluator(compact_model)  # type: ignore
self.pruner._unwrap_model()
else:
score = None
# clear model references
self.pruner.clear_model_references()
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from copy import deepcopy
import logging
from typing import Dict, List, Tuple, Callable, overload
import torch
from torch import autograd, Tensor
...@@ -12,17 +14,23 @@ from torch.nn.parameter import Parameter
from torch.optim import Optimizer, Adam
from nni.algorithms.compression.v2.pytorch.base import PrunerModuleWrapper, LayerInfo
from nni.algorithms.compression.v2.pytorch.pruning.basic_pruner import BasicPruner, NORMAL_SCHEMA, EXCLUDE_SCHEMA, INTERNAL_SCHEMA
from nni.algorithms.compression.v2.pytorch.pruning.basic_pruner import EvaluatorBasedPruner, NORMAL_SCHEMA, EXCLUDE_SCHEMA, INTERNAL_SCHEMA
from nni.algorithms.compression.v2.pytorch.utils import CompressorSchema, OptimizerConstructHelper
from nni.algorithms.compression.v2.pytorch.utils import CompressorSchema
from nni.common.serializer import Traceable
from .tools.base import EvaluatorBasedDataCollector, TrainerBasedDataCollector
from .tools import (
NormalSparsityAllocator,
StraightMetricsCalculator
)
from ..utils import (
LightningEvaluator,
TorchEvaluator
)
from ..utils.docstring import _EVALUATOR_DOCSTRING
_logger = logging.getLogger(__name__)
...@@ -47,8 +55,7 @@ class PrunerScoredModuleWrapper(PrunerModuleWrapper):
def forward(self, *inputs):
# apply mask to weight, bias
# NOTE: I don't know why training getting slower and slower if only `self.weight_mask` without `detach()`
self.module.weight = torch.mul(self.weight, _StraightThrough.apply(self.weight_score, self.weight_mask.detach()))  # type: ignore
self.module.weight = torch.mul(self.weight, _StraightThrough.apply(self.weight_score, self.weight_mask))  # type: ignore
if hasattr(self.module, 'bias') and self.module.bias is not None:
self.module.bias = torch.mul(self.bias, self.bias_mask)  # type: ignore
return self.module(*inputs)
...@@ -77,13 +84,30 @@ class WeightScoreTrainerBasedDataCollector(TrainerBasedDataCollector):
self.trainer(self.compressor.bound_model, self.optimizer, self.criterion)
data = {}
target_name = 'weight'
for _, wrapper in self.compressor.get_modules_wrapper().items():
data[wrapper.name] = wrapper.weight_score.data  # type: ignore
data[wrapper.name] = {target_name: wrapper.weight_score.data}  # type: ignore
return data
class EvaluatorBasedScoreDataCollector(EvaluatorBasedDataCollector):
"""
Collect all weight_score in wrappers as data used to calculate metrics.
"""
def collect(self) -> Dict[str, Tensor]:
assert self.compressor.bound_model is not None
self.evaluator.train(max_steps=self.max_steps, max_epochs=self.max_epochs)
data = {}
target_name = 'weight'
for module_name, wrapper in self.compressor.get_modules_wrapper().items():
target_score: Tensor = getattr(wrapper, f'{target_name}_score')
data[module_name] = {target_name: target_score.data.clone()}
return data
class MovementPruner(EvaluatorBasedPruner):
__doc__ = r"""
Movement pruner is an implementation of movement pruning.
This is a "fine-pruning" algorithm, which means the masks may change during each fine-tuning step.
Each weight element will be scored by the opposite of the sum of the product of weight and its gradient during each step.
...@@ -110,30 +134,12 @@ class MovementPruner(BasicPruner):
- op_names : Operation names to be pruned.
- op_partial_names: Operation partial names to be pruned, will be autocompleted by NNI.
- exclude : Set True then the layers setting by op_types and op_names will be excluded from pruning.
evaluator
``evaluator`` is used to replace the previous ``trainer``, ``traced_optimizer`` and ``criterion`` API.
{evaluator_docstring}
The old API (``trainer``, ``traced_optimizer`` and ``criterion``) is still supported and will be deprecated in v3.0.
If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__.
trainer : Callable[[Module, Optimizer, Callable]
A callable function used to train model or just inference. Take model, optimizer, criterion as input.
The model will be trained or inferenced `training_epochs` epochs.
Example::
def trainer(model: Module, optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor]):
training = model.training
model.train(mode=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
# If you don't want to update the model, you can skip `optimizer.step()`, and set train mode False.
optimizer.step()
model.train(mode=training)
traced_optimizer : nni.common.serializer.Traceable(torch.optim.Optimizer)
The traced optimizer instance which the optimizer class is wrapped by nni.trace.
E.g. ``traced_optimizer = nni.trace(torch.nn.Adam)(model.parameters())``.
criterion : Callable[[Tensor, Tensor], Tensor]
The criterion function used in trainer. Take model output and target value as input, and return the loss.
training_epochs : int training_epochs : int
The total epoch number for training the model. The total epoch number for training the model.
Make sure the total number of `optimizer.step()` calls during `training_epochs` is larger than `cool_down_beginning_step`. Make sure the total number of `optimizer.step()` calls during `training_epochs` is larger than `cool_down_beginning_step`.
...@@ -145,33 +151,31 @@ class MovementPruner(BasicPruner): ...@@ -145,33 +151,31 @@ class MovementPruner(BasicPruner):
The sparsity after each `optimizer.step()` is: The sparsity after each `optimizer.step()` is:
total_sparsity * (1 - (1 - (current_step - warm_up_step) / (cool_down_beginning_step - warm_up_step)) ** 3). total_sparsity * (1 - (1 - (current_step - warm_up_step) / (cool_down_beginning_step - warm_up_step)) ** 3).
Examples Notes
-------- -----
>>> import nni
>>> from nni.compression.pytorch.pruning import MovementPruner
>>> model = ...
>>> # make sure you have used nni.trace to wrap the optimizer class before initialize
>>> traced_optimizer = nni.trace(torch.optim.Adam)(model.parameters())
>>> trainer = ...
>>> criterion = ...
>>> config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
>>> pruner = MovementPruner(model, config_list, trainer, traced_optimizer, criterion, 10, 3000, 27000)
>>> masked_model, masks = pruner.compress()
For detailed example please refer to :githublink:`examples/model_compress/pruning/movement_pruning_glue.py <examples/model_compress/pruning/movement_pruning_glue.py>` For detailed example please refer to :githublink:`examples/model_compress/pruning/movement_pruning_glue.py <examples/model_compress/pruning/movement_pruning_glue.py>`
""" """.format(evaluator_docstring=_EVALUATOR_DOCSTRING)
@overload
def __init__(self, model: Module, config_list: List[Dict], evaluator: LightningEvaluator | TorchEvaluator, training_epochs: int,
warm_up_step: int, cool_down_beginning_step: int):
...
@overload
def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None], def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
traced_optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor], training_epochs: int, warm_up_step: int, traced_optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor], training_epochs: int, warm_up_step: int,
cool_down_beginning_step: int): cool_down_beginning_step: int):
self.trainer = trainer ...
if isinstance(traced_optimizer, OptimizerConstructHelper):
self.optimizer_helper = traced_optimizer def __init__(self, model: Module, config_list: List[Dict], *args, **kwargs):
else: # TODO: remove in nni v3.0. Fake overload.
self.optimizer_helper = OptimizerConstructHelper.from_trace(model, traced_optimizer) new_api = ['evaluator', 'training_epochs', 'warm_up_step', 'cool_down_beginning_step']
self.criterion = criterion old_api = ['trainer', 'traced_optimizer', 'criterion', 'training_epochs', 'warm_up_step', 'cool_down_beginning_step']
self.training_epochs = training_epochs init_kwargs = self._init_evaluator(model, new_api, old_api, {}, args, kwargs)
self.warm_up_step = warm_up_step
self.cool_down_beginning_step = cool_down_beginning_step self.training_epochs: int = init_kwargs['training_epochs']
self.warm_up_step: int = init_kwargs['warm_up_step']
self.cool_down_beginning_step: int = init_kwargs['cool_down_beginning_step']
assert self.warm_up_step < self.cool_down_beginning_step, '`warm_up_step` should be smaller than `cool_down_beginning_step`' assert self.warm_up_step < self.cool_down_beginning_step, '`warm_up_step` should be smaller than `cool_down_beginning_step`'
super().__init__(model, config_list) super().__init__(model, config_list)
...@@ -184,14 +188,16 @@ class MovementPruner(BasicPruner): ...@@ -184,14 +188,16 @@ class MovementPruner(BasicPruner):
if self.warm_up_step < current_step <= self.cool_down_beginning_step: if self.warm_up_step < current_step <= self.cool_down_beginning_step:
wrapper_dict = self.get_modules_wrapper() wrapper_dict = self.get_modules_wrapper()
for config in self.config_list: for config in self.config_list:
current_sparsity = config['total_sparsity'] * (1 - (1 - (current_step - self.warm_up_step) / (self.cool_down_beginning_step - self.warm_up_step)) ** 3) scale = 1 - (1 - (current_step - self.warm_up_step) / (self.cool_down_beginning_step - self.warm_up_step)) ** 3
current_sparsity = config['total_sparsity'] * scale
for op_name in config['op_names']: for op_name in config['op_names']:
wrapper_dict[op_name].config['total_sparsity'] = current_sparsity wrapper = wrapper_dict[op_name]
wrapper.config['total_sparsity'] = current_sparsity
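To make the cubic schedule implemented above concrete, a small standalone computation (the step counts and target sparsity are arbitrary):

def scheduled_sparsity(current_step, total_sparsity=0.8, warm_up_step=3000, cool_down_beginning_step=27000):
    # mirrors the formula used in the cubic schedule above
    progress = (current_step - warm_up_step) / (cool_down_beginning_step - warm_up_step)
    return total_sparsity * (1 - (1 - progress) ** 3)

for step in (3000, 9000, 15000, 27000):
    print(step, round(scheduled_sparsity(step), 4))
# 0.0 at the end of warm-up, rising quickly (0.4625 at step 9000, 0.7 at 15000) and flattening out at 0.8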
def reset_tools(self): def reset_tools(self):
if self.metrics_calculator is None: if not hasattr(self, 'metrics_calculator'):
self.metrics_calculator = StraightMetricsCalculator() self.metrics_calculator = StraightMetricsCalculator()
if self.sparsity_allocator is None: if not hasattr(self, 'sparsity_allocator'):
self.sparsity_allocator = NormalSparsityAllocator(self, continuous_mask=False) self.sparsity_allocator = NormalSparsityAllocator(self, continuous_mask=False)
# use Adam to update the weight_score # use Adam to update the weight_score
...@@ -208,16 +214,30 @@ class MovementPruner(BasicPruner): ...@@ -208,16 +214,30 @@ class MovementPruner(BasicPruner):
if self.step_counter > self.warm_up_step: if self.step_counter > self.warm_up_step:
self.cubic_schedule(self.step_counter) self.cubic_schedule(self.step_counter)
data = {} data = {}
target_name = 'weight'
for wrapper in self.get_modules_wrapper().values(): for wrapper in self.get_modules_wrapper().values():
data[wrapper.name] = wrapper.weight_score.data data[wrapper.name] = {target_name: wrapper.weight_score.data}
metrics = self.metrics_calculator.calculate_metrics(data) # type: ignore metrics = self.metrics_calculator.calculate_metrics(data) # type: ignore
masks = self.sparsity_allocator.generate_sparsity(metrics) # type: ignore masks = self.sparsity_allocator.generate_sparsity(metrics) # type: ignore
self.load_masks(masks) self.load_masks(masks)
if self.data_collector is None: if self.using_evaluator:
self.data_collector = WeightScoreTrainerBasedDataCollector(self, self.trainer, self.optimizer_helper, self.criterion, self.training_epochs, opt_after_tasks=[_optimizer_patch]) # TODO: move to other place in nni v3.0
self.evaluator.unbind_model()
self.evaluator.bind_model(self.bound_model, self.get_origin2wrapped_parameter_name_map()) # type: ignore
if not hasattr(self, 'data_collector'):
self.data_collector = EvaluatorBasedScoreDataCollector(self, self.evaluator,
after_opt_step_tasks=[_optimizer_patch],
max_epochs=self.training_epochs)
else:
self.data_collector.reset(after_opt_step_tasks=[_optimizer_patch])
else: else:
self.data_collector.reset() if not hasattr(self, 'data_collector'):
self.data_collector = WeightScoreTrainerBasedDataCollector(self, self.trainer, self.optimizer_helper,
self.criterion, self.training_epochs,
opt_after_tasks=[_optimizer_patch])
else:
self.data_collector.reset()
def _wrap_modules(self, layer: LayerInfo, config: Dict): def _wrap_modules(self, layer: LayerInfo, config: Dict):
""" """
...@@ -243,7 +263,6 @@ class MovementPruner(BasicPruner): ...@@ -243,7 +263,6 @@ class MovementPruner(BasicPruner):
for wrapper in self.get_modules_wrapper().values(): for wrapper in self.get_modules_wrapper().values():
wrapper.config['total_sparsity'] = 0 wrapper.config['total_sparsity'] = 0
result = super().compress() result = super().compress()
# del weight_score if self.using_evaluator:
for wrapper in self.get_modules_wrapper().values(): self.evaluator.unbind_model()
wrapper.weight_score = None
return result return result
...@@ -8,6 +8,12 @@ from .base import ( ...@@ -8,6 +8,12 @@ from .base import (
SparsityAllocator, SparsityAllocator,
TaskGenerator TaskGenerator
) )
from .data_collector import (
TargetDataCollector,
EvaluatorBasedTargetDataCollector,
EvaluatorBasedHookDataCollector
)
# TODO: remove in nni v3.0.
from .data_collector import ( from .data_collector import (
WeightDataCollector, WeightDataCollector,
WeightTrainerBasedDataCollector, WeightTrainerBasedDataCollector,
...@@ -16,7 +22,7 @@ from .data_collector import ( ...@@ -16,7 +22,7 @@ from .data_collector import (
from .metrics_calculator import ( from .metrics_calculator import (
StraightMetricsCalculator, StraightMetricsCalculator,
NormMetricsCalculator, NormMetricsCalculator,
MultiDataNormMetricsCalculator, HookDataNormMetricsCalculator,
DistMetricsCalculator, DistMetricsCalculator,
APoZRankMetricsCalculator, APoZRankMetricsCalculator,
MeanRankMetricsCalculator MeanRankMetricsCalculator
......
...@@ -6,7 +6,7 @@ from datetime import datetime ...@@ -6,7 +6,7 @@ from datetime import datetime
import logging import logging
from pathlib import Path from pathlib import Path
import types import types
from typing import List, Dict, Tuple, Optional, Callable, Union from typing import List, Dict, Literal, Tuple, Optional, Callable, Union
import json_tricks import json_tricks
import torch import torch
...@@ -15,7 +15,7 @@ from torch.nn import Module ...@@ -15,7 +15,7 @@ from torch.nn import Module
from torch.optim import Optimizer from torch.optim import Optimizer
from ...base import Pruner, LayerInfo, Task, TaskResult from ...base import Pruner, LayerInfo, Task, TaskResult
from ...utils import OptimizerConstructHelper, Scaling from ...utils import Evaluator, Hook, OptimizerConstructHelper, Scaling
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
...@@ -45,7 +45,7 @@ class DataCollector: ...@@ -45,7 +45,7 @@ class DataCollector:
def __init__(self, compressor: Pruner): def __init__(self, compressor: Pruner):
self.compressor = compressor self.compressor = compressor
def reset(self): def reset(self, *args, **kwargs):
""" """
Reset the `DataCollector`. Reset the `DataCollector`.
""" """
...@@ -63,9 +63,12 @@ class DataCollector: ...@@ -63,9 +63,12 @@ class DataCollector:
raise NotImplementedError() raise NotImplementedError()
# TODO: remove in nni v3.0.
COLLECTOR_TYPE = Union[Callable[[List, Tensor], Callable[[Tensor], None]], Callable[[List], Callable[[Module, Tensor, Tensor], None]]]
class HookCollectorInfo: class HookCollectorInfo:
def __init__(self, targets: Union[Dict[str, Tensor], List[LayerInfo]], hook_type: str, def __init__(self, targets: Union[Dict[str, Tensor], List[LayerInfo]], hook_type: str,
collector: Union[Callable[[List, Tensor], Callable[[Tensor], None]], Callable[[List], Callable[[Module, Tensor, Tensor], None]]]): collector: COLLECTOR_TYPE):
""" """
This class is used to aggregate the information about what kind of hook is placed on which layers. This class is used to aggregate the information about what kind of hook is placed on which layers.
...@@ -76,23 +79,24 @@ class HookCollectorInfo: ...@@ -76,23 +79,24 @@ class HookCollectorInfo:
hook_type hook_type
'forward' or 'backward'. 'forward' or 'backward'.
collector collector
A hook function generator, the input is a buffer (empty list) or a buffer (empty list) and tensor, the output is a hook function. A hook function generator, the input is a buffer (empty list) or a buffer (empty list) and tensor,
The buffer is used to store the data wanted to hook. the output is a hook function. The buffer is used to store the data wanted to hook.
""" """
self.targets = targets self.targets = targets
self.hook_type = hook_type self.hook_type = hook_type
self.collector = collector self.collector = collector
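A hedged sketch of the kind of hook-function factory that ``HookCollectorInfo`` expects, here for a ``'forward'`` hook whose buffer records a simple output statistic (the statistic and target layers are illustrative):

from typing import Callable, List

from torch import Tensor
from torch.nn import Module

def forward_activation_collector(buffer: List) -> Callable[[Module, Tensor, Tensor], None]:
    # factory: receives an empty list and returns a torch forward hook that appends to it
    def hook(module: Module, inp: Tensor, out: Tensor) -> None:
        buffer.append(out.detach().mean().item())
    return hook

# Illustrative wiring (layer_infos would be the LayerInfo targets to hook):
# info = HookCollectorInfo(targets=layer_infos, hook_type='forward',
#                          collector=forward_activation_collector)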
# TODO: remove in nni v3.0.
class TrainerBasedDataCollector(DataCollector): class TrainerBasedDataCollector(DataCollector):
""" """
This class includes some trainer based util functions, i.e., patch optimizer or criterion, add hooks. This class includes some trainer based util functions, i.e., patch optimizer or criterion, add hooks.
""" """
def __init__(self, compressor: Pruner, trainer: Callable[[Module, Optimizer, Callable], None], optimizer_helper: OptimizerConstructHelper, def __init__(self, compressor: Pruner, trainer: Callable[[Module, Optimizer, Callable], None],
criterion: Callable[[Tensor, Tensor], Tensor], training_epochs: int, optimizer_helper: OptimizerConstructHelper, criterion: Callable[[Tensor, Tensor], Tensor], training_epochs: int,
opt_before_tasks: List = [], opt_after_tasks: List = [], opt_before_tasks: List = [], opt_after_tasks: List = [], collector_infos: List[HookCollectorInfo] = [],
collector_infos: List[HookCollectorInfo] = [], criterion_patch: Optional[Callable[[Callable], Callable]] = None): criterion_patch: Optional[Callable[[Callable], Callable]] = None):
""" """
Parameters Parameters
---------- ----------
...@@ -252,6 +256,47 @@ class TrainerBasedDataCollector(DataCollector): ...@@ -252,6 +256,47 @@ class TrainerBasedDataCollector(DataCollector):
self._remove_hook(hook_id) self._remove_hook(hook_id)
class EvaluatorBasedDataCollector(DataCollector):
"""
This data collector is the base class for data collectors that use an ``Evaluator`` to train or run inference.
Three main usages are supported in this data collector (a usage sketch follows this class):
1. Doing something before and after ``optimizer.step()``. ``before_opt_step_tasks`` is a list of task functions
that will execute before ``optimizer.step()``; ``after_opt_step_tasks`` is a list of task functions that will execute after
``optimizer.step()``. The task functions in these lists should take no input arguments; a return value is allowed,
but ``Evaluator`` will not catch it.
2. Patching or modifying the training loss. ``loss_patch`` is a function whose input is the original loss and whose output is the modified loss.
3. Adding hooks on ``torch.nn.Module``, ``Parameter`` or ``Buffer``. Three kinds of hook are supported: ``TensorHook``, ``ForwardHook``
and ``BackwardHook``. To initialize a ``Hook``, a hook function factory is needed; the factory function's input is an empty list,
and the output is a hook function as defined by PyTorch.
Please refer `register_hook <https://pytorch.org/docs/stable/generated/torch.Tensor.register_hook.html>`_,
`register_forward_hook <https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_forward_hook>`_,
`register_backward_hook <https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_backward_hook>`_.
"""
def __init__(self, compressor: Pruner, evaluator: Evaluator, before_opt_step_tasks: List[Callable] | None = None,
after_opt_step_tasks: List[Callable] | None = None, loss_patch: Callable[[Tensor], Tensor] | None = None,
hooks: Dict[str, Dict[str, Hook]] | None = None, max_steps: int | None = None, max_epochs: int | None = None):
super().__init__(compressor)
self.evaluator = evaluator
self.max_steps = max_steps
self.max_epochs = max_epochs
self.reset(before_opt_step_tasks, after_opt_step_tasks, loss_patch, hooks)
def reset(self, before_opt_step_tasks: List[Callable] | None = None, after_opt_step_tasks: List[Callable] | None = None,
loss_patch: Callable[[Tensor], Tensor] | None = None, hooks: Dict[str, Dict[str, Hook]] | None = None):
if before_opt_step_tasks or after_opt_step_tasks:
before_opt_step_tasks = before_opt_step_tasks if before_opt_step_tasks else []
after_opt_step_tasks = after_opt_step_tasks if after_opt_step_tasks else []
self.evaluator.patch_optimizer_step(before_opt_step_tasks, after_opt_step_tasks)
if loss_patch:
self.evaluator.patch_loss(loss_patch)
if hooks:
self._hooks = hooks
hook_list = [hook for _ in hooks.values() for hook in _.values()]
self.evaluator.register_hooks(hook_list)
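As promised in the docstring, a rough usage sketch of the three extension points, using the ``EvaluatorBasedTargetDataCollector`` subclass defined later in this change; the ``pruner`` and ``evaluator`` objects are assumed to already exist, and the task and patch functions are illustrative:

def log_before_step():
    # example `before_opt_step_tasks` entry: takes no arguments
    print('about to call optimizer.step()')

def scale_loss(loss):
    # example `loss_patch`: any Tensor -> Tensor transform of the training loss
    return 2 * loss

# `pruner` is a Pruner and `evaluator` an nni Evaluator (construction elided here)
collector = EvaluatorBasedTargetDataCollector(
    pruner, evaluator,
    before_opt_step_tasks=[log_before_step],
    loss_patch=scale_loss,
    max_epochs=1,
)
data = collector.collect()   # {module_name: {target_name: tensor}}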
class MetricsCalculator: class MetricsCalculator:
""" """
An abstract class for calculate a kind of metrics of the given data. An abstract class for calculate a kind of metrics of the given data.
...@@ -260,7 +305,8 @@ class MetricsCalculator: ...@@ -260,7 +305,8 @@ class MetricsCalculator:
---------- ----------
scalers scalers
Scaler is used to scale the metrics' size. It scales the metric to the same size as the shrinked mask in the sparsity allocator. Scaler is used to scale the metrics' size. It scales the metric to the same size as the shrinked mask in the sparsity allocator.
If you want to use different scalers for different pruning targets in different modules, please use a dict `{module_name: {target_name: scaler}}`. If you want to use different scalers for different pruning targets in different modules,
please use a dict `{module_name: {target_name: scaler}}`.
If allocator meets an unspecified module name, it will try to use `scalers['_default'][target_name]` to scale its mask. If allocator meets an unspecified module name, it will try to use `scalers['_default'][target_name]` to scale its mask.
If allocator meets an unspecified target name, it will try to use `scalers[module_name]['_default']` to scale its mask. If allocator meets an unspecified target name, it will try to use `scalers[module_name]['_default']` to scale its mask.
Passing in a scaler instead of a `dict` of scalers will be treated as passed in `{'_default': {'_default': scalers}}`. Passing in a scaler instead of a `dict` of scalers will be treated as passed in `{'_default': {'_default': scalers}}`.
...@@ -268,7 +314,8 @@ class MetricsCalculator: ...@@ -268,7 +314,8 @@ class MetricsCalculator:
""" """
def __init__(self, scalers: Dict[str, Dict[str, Scaling]] | Scaling | None = None): def __init__(self, scalers: Dict[str, Dict[str, Scaling]] | Scaling | None = None):
self.scalers: Dict[str, Dict[str, Scaling]] | None = scalers if isinstance(scalers, (dict, type(None))) else {'_default': {'_default': scalers}} # type: ignore self.scalers: Dict[str, Dict[str, Scaling]] | None = scalers \
if isinstance(scalers, (dict, type(None))) else {'_default': {'_default': scalers}} # type: ignore
def _get_scaler(self, module_name: str, target_name: str) -> Scaling: def _get_scaler(self, module_name: str, target_name: str) -> Scaling:
scaler = _get_scaler(self.scalers, module_name, target_name) scaler = _get_scaler(self.scalers, module_name, target_name)
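The ``'_default'`` fallback described in the docstring above can be summarized with a tiny standalone helper (a simplified mirror of ``_get_scaler``, not the actual implementation; strings stand in for ``Scaling`` instances):

def lookup_scaler(scalers, module_name, target_name):
    # simplified fallback rule: try the module name first, then '_default'; same for the target name
    by_module = scalers.get(module_name, scalers.get('_default', {}))
    return by_module.get(target_name, by_module.get('_default'))

scalers = {'fc1': {'weight': 'scaler-for-fc1-weight'},
           '_default': {'_default': 'fallback-scaler'}}
print(lookup_scaler(scalers, 'fc1', 'weight'))   # scaler-for-fc1-weight
print(lookup_scaler(scalers, 'fc2', 'weight'))   # fallback-scaler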
...@@ -301,7 +348,8 @@ class SparsityAllocator: ...@@ -301,7 +348,8 @@ class SparsityAllocator:
scalers scalers
Scaler is used to scale the masks' size. It shrinks the mask of the same size as the pruning target to the same size as the metric, Scaler is used to scale the masks' size. It shrinks the mask of the same size as the pruning target to the same size as the metric,
or expands the mask of the same size as the metric to the same size as the pruning target. or expands the mask of the same size as the metric to the same size as the pruning target.
If you want to use different scalers for different pruning targets in different modules, please use a dict `{module_name: {target_name: scaler}}`. If you want to use different scalers for different pruning targets in different modules,
please use a dict `{module_name: {target_name: scaler}}`.
If allocator meets an unspecified module name, it will try to use `scalers['_default'][target_name]` to scale its mask. If allocator meets an unspecified module name, it will try to use `scalers['_default'][target_name]` to scale its mask.
If allocator meets an unspecified target name, it will try to use `scalers[module_name]['_default']` to scale its mask. If allocator meets an unspecified target name, it will try to use `scalers[module_name]['_default']` to scale its mask.
Passing in a scaler instead of a `dict` of scalers will be treated as passed in `{'_default': {'_default': scalers}}`. Passing in a scaler instead of a `dict` of scalers will be treated as passed in `{'_default': {'_default': scalers}}`.
...@@ -313,7 +361,8 @@ class SparsityAllocator: ...@@ -313,7 +361,8 @@ class SparsityAllocator:
def __init__(self, pruner: Pruner, scalers: Dict[str, Dict[str, Scaling]] | Scaling | None = None, continuous_mask: bool = True): def __init__(self, pruner: Pruner, scalers: Dict[str, Dict[str, Scaling]] | Scaling | None = None, continuous_mask: bool = True):
self.pruner = pruner self.pruner = pruner
self.scalers: Dict[str, Dict[str, Scaling]] | None = scalers if isinstance(scalers, (dict, type(None))) else {'_default': {'_default': scalers}} # type: ignore self.scalers: Dict[str, Dict[str, Scaling]] | None = scalers \
if isinstance(scalers, (dict, type(None))) else {'_default': {'_default': scalers}} # type: ignore
self.continuous_mask = continuous_mask self.continuous_mask = continuous_mask
def _get_scaler(self, module_name: str, target_name: str) -> Scaling | None: def _get_scaler(self, module_name: str, target_name: str) -> Scaling | None:
...@@ -335,25 +384,39 @@ class SparsityAllocator: ...@@ -335,25 +384,39 @@ class SparsityAllocator:
mask = (scaler.shrink(mask) != 0).type_as(mask) mask = (scaler.shrink(mask) != 0).type_as(mask)
return mask return mask
def _continuous_mask(self, new_masks: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]: def _mask_metric(self, metrics: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
# Set the already masked part in the metric to the minimum value. # Set the already masked part in the metric to the minimum value.
target_name = 'weight' target_name = 'weight'
for module_name, targets_metric in metrics.items():
wrapper = self.pruner.get_modules_wrapper()[module_name]
old_target_mask: Tensor = getattr(wrapper, f'{target_name}_mask')
shrinked_target_mask = self._shrink_mask(module_name, target_name, old_target_mask)
# make sure the masked position has the minimum metric
targets_metric[target_name] = targets_metric[target_name].to(shrinked_target_mask.device)
min_value = targets_metric[target_name].min() - 1
targets_metric[target_name] = torch.where(shrinked_target_mask != 0, targets_metric[target_name], min_value)
return metrics
def _continuous_mask(self, new_masks: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
# Set the already masked part to zero in the new_masks.
target_name = 'weight'
for module_name, target_mask in new_masks.items(): for module_name, target_mask in new_masks.items():
wrapper = self.pruner.get_modules_wrapper()[module_name] wrapper = self.pruner.get_modules_wrapper()[module_name]
old_target_mask = getattr(wrapper, f'{target_name}_mask', None) old_target_mask: Tensor | None = getattr(wrapper, f'{target_name}_mask', None)
if old_target_mask is not None: if old_target_mask is not None:
new_masks[module_name][target_name] = torch.min(target_mask[target_name], old_target_mask) new_masks[module_name][target_name] = torch.min(target_mask[target_name],
old_target_mask.to(target_mask[target_name].device))
return new_masks return new_masks
def common_target_masks_generation(self, metrics: Dict[str, Tensor]) -> Dict[str, Dict[str, Tensor]]: def common_target_masks_generation(self, metrics: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
""" """
Generate masks for metrics-dependent targets. Generate masks for metrics-dependent targets.
Parameters Parameters
---------- ----------
metrics metrics
The format is {module_name: weight_metric}. The format is {module_name: {target_name: target_metric}}.
The metric of `weight` usually has the same size with shrinked mask. The metric usually has the same size as the shrinked mask.
Return Return
------ ------
...@@ -384,7 +447,7 @@ class SparsityAllocator: ...@@ -384,7 +447,7 @@ class SparsityAllocator:
reduce_dims = [reduce_dim for reduce_dim in range(1, len(weight_mask.shape))] reduce_dims = [reduce_dim for reduce_dim in range(1, len(weight_mask.shape))]
# count unmasked number of values on dim 0 (output channel) of weight # count unmasked number of values on dim 0 (output channel) of weight
unmasked_num_on_dim0 = weight_mask.sum(reduce_dims) if reduce_dims else weight_mask unmasked_num_on_dim0 = weight_mask.sum(reduce_dims) if reduce_dims else weight_mask
module_masks['bias'] = (unmasked_num_on_dim0 != 0).type_as(old_bias_mask) module_masks['bias'] = (unmasked_num_on_dim0 != 0).type_as(weight_mask)
return masks return masks
def generate_sparsity(self, metrics: Dict) -> Dict[str, Dict[str, Tensor]]: def generate_sparsity(self, metrics: Dict) -> Dict[str, Dict[str, Tensor]]:
...@@ -401,6 +464,8 @@ class SparsityAllocator: ...@@ -401,6 +464,8 @@ class SparsityAllocator:
Dict[str, Dict[str, Tensor]] Dict[str, Dict[str, Tensor]]
The masks format is {module_name: {target_name: mask}}. The masks format is {module_name: {target_name: mask}}.
""" """
if self.continuous_mask:
metrics = self._mask_metric(metrics)
masks = self.common_target_masks_generation(metrics) masks = self.common_target_masks_generation(metrics)
masks = self.special_target_masks_generation(masks) masks = self.special_target_masks_generation(masks)
if self.continuous_mask: if self.continuous_mask:
...@@ -425,11 +490,22 @@ class TaskGenerator: ...@@ -425,11 +490,22 @@ class TaskGenerator:
The log directory use to saving the task generator log. The log directory use to saving the task generator log.
keep_intermediate_result keep_intermediate_result
If keeping the intermediate result, including intermediate model and masks during each iteration. If keeping the intermediate result, including intermediate model and masks during each iteration.
best_result_mode
The way to decide which one is the best result. Three modes are supported.
If the task results don't contain scores (task_result.score is None), it will fall back to ``latest``.
1. latest: The newest received result is the best result.
2. maximize: The one with largest task result score is the best result.
3. minimize: The one with smallest task result score is the best result.
""" """
def __init__(self, origin_model: Optional[Module], origin_masks: Optional[Dict[str, Dict[str, Tensor]]] = {}, def __init__(self, origin_model: Optional[Module], origin_masks: Optional[Dict[str, Dict[str, Tensor]]] = {},
origin_config_list: Optional[List[Dict]] = [], log_dir: Union[str, Path] = '.', keep_intermediate_result: bool = False): origin_config_list: Optional[List[Dict]] = [], log_dir: Union[str, Path] = '.', keep_intermediate_result: bool = False,
best_result_mode: Literal['latest', 'maximize', 'minimize'] = 'maximize'):
self._log_dir = log_dir self._log_dir = log_dir
self._keep_intermediate_result = keep_intermediate_result self._keep_intermediate_result = keep_intermediate_result
assert best_result_mode in ['latest', 'maximize', 'minimize'], f'Unsupported best_result_mode value: {best_result_mode}'
self._best_result_mode = best_result_mode
if origin_model is not None and origin_config_list is not None and origin_masks is not None: if origin_model is not None and origin_config_list is not None and origin_masks is not None:
self.reset(origin_model, origin_config_list, origin_masks) self.reset(origin_model, origin_config_list, origin_masks)
...@@ -472,13 +548,24 @@ class TaskGenerator: ...@@ -472,13 +548,24 @@ class TaskGenerator:
json_tricks.dump(config_list, f, indent=4) json_tricks.dump(config_list, f, indent=4)
def update_best_result(self, task_result: TaskResult): def update_best_result(self, task_result: TaskResult):
score = task_result.score save_as_best_result = False
task_id = task_result.task_id task = self._tasks[task_result.task_id]
task = self._tasks[task_id] task.score = task_result.score
task.score = score
if self._best_score is None or (score is not None and score > self._best_score): if self._best_result_mode == 'latest':
self._best_score = score self._best_task_id, save_as_best_result = task_result.task_id, True
self._best_task_id = task_id
if self._best_result_mode == 'maximize':
if self._best_score is None or (task.score is not None and task.score > self._best_score):
self._best_score = task.score
self._best_task_id, save_as_best_result = task_result.task_id, True
if self._best_result_mode == 'minimize':
if self._best_score is None or (task.score is not None and task.score < self._best_score):
self._best_score = task.score
self._best_task_id, save_as_best_result = task_result.task_id, True
if save_as_best_result:
with Path(task.config_list_path).open('r') as fr: with Path(task.config_list_path).open('r') as fr:
best_config_list = json_tricks.load(fr) best_config_list = json_tricks.load(fr)
self._save_data('best_result', task_result.compact_model, task_result.compact_model_masks, best_config_list) self._save_data('best_result', task_result.compact_model, task_result.compact_model_masks, best_config_list)
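A tiny standalone mirror of the selection rule implemented in ``update_best_result`` above, for the three ``best_result_mode`` values (task ids and scores are arbitrary, and ``None``-score handling is simplified):

def pick_best(results, mode='maximize'):
    # simplified sketch: results arrive one task at a time as (task_id, score) pairs
    best_id, best_score = None, None
    for task_id, score in results:
        if mode == 'latest':
            best_id = task_id
        elif score is not None and (best_score is None or
                                    (mode == 'maximize' and score > best_score) or
                                    (mode == 'minimize' and score < best_score)):
            best_id, best_score = task_id, score
    return best_id

results = [(0, 0.71), (1, 0.84), (2, 0.79)]
print(pick_best(results, 'maximize'))  # 1
print(pick_best(results, 'minimize'))  # 0
print(pick_best(results, 'latest'))    # 2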
......
...@@ -6,13 +6,16 @@ from typing import Dict, List ...@@ -6,13 +6,16 @@ from typing import Dict, List
from torch import Tensor from torch import Tensor
from .base import DataCollector, TrainerBasedDataCollector from .base import DataCollector, EvaluatorBasedDataCollector
from .base import TrainerBasedDataCollector
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
__all__ = ['WeightDataCollector', 'WeightTrainerBasedDataCollector', 'SingleHookTrainerBasedDataCollector'] __all__ = ['TargetDataCollector', 'EvaluatorBasedTargetDataCollector', 'EvaluatorBasedHookDataCollector',
'WeightDataCollector', 'WeightTrainerBasedDataCollector', 'SingleHookTrainerBasedDataCollector'] # TODO: remove in nni v3.0.
# TODO: remove in nni v3.0.
class WeightDataCollector(DataCollector): class WeightDataCollector(DataCollector):
""" """
Collect all wrapper weights. Collect all wrapper weights.
...@@ -21,40 +24,102 @@ class WeightDataCollector(DataCollector): ...@@ -21,40 +24,102 @@ class WeightDataCollector(DataCollector):
def reset(self): def reset(self):
pass pass
def collect(self) -> Dict[str, Tensor]: def collect(self) -> Dict[str, Dict[str, Tensor]]:
data = {} data = {}
for _, wrapper in self.compressor.get_modules_wrapper().items(): target_name = 'weight'
data[wrapper.name] = wrapper.weight.data for module_name, wrapper in self.compressor.get_modules_wrapper().items():
target: Tensor = getattr(wrapper, target_name)
data[module_name] = {target_name: target.data.clone()}
return data return data
# TODO: remove in nni v3.0.
class WeightTrainerBasedDataCollector(TrainerBasedDataCollector): class WeightTrainerBasedDataCollector(TrainerBasedDataCollector):
""" """
Collect all wrapper weights after training or inference. Collect all wrapper weights after training or inference.
""" """
def collect(self) -> Dict[str, Tensor]: def collect(self) -> Dict[str, Dict[str, Tensor]]:
assert self.compressor.bound_model is not None assert self.compressor.bound_model is not None
for _ in range(self.training_epochs): for _ in range(self.training_epochs):
self.trainer(self.compressor.bound_model, self.optimizer, self.criterion) self.trainer(self.compressor.bound_model, self.optimizer, self.criterion)
data = {} data = {}
for _, wrapper in self.compressor.get_modules_wrapper().items(): target_name = 'weight'
data[wrapper.name] = wrapper.weight.data for module_name, wrapper in self.compressor.get_modules_wrapper().items():
target: Tensor = getattr(wrapper, target_name)
data[module_name] = {target_name: target.data.clone()}
return data return data
# TODO: remove in nni v3.0.
class SingleHookTrainerBasedDataCollector(TrainerBasedDataCollector): class SingleHookTrainerBasedDataCollector(TrainerBasedDataCollector):
""" """
Add hooks and collect data during training or inference. Add hooks and collect data during training or inference.
Single means each wrapper only has one hook to collect data. Single means each wrapper only has one hook to collect data.
""" """
def collect(self) -> Dict[str, List[Tensor]]: def collect(self) -> Dict[str, Dict[str, List[Tensor]]]:
assert self.compressor.bound_model is not None assert self.compressor.bound_model is not None
for _ in range(self.training_epochs): for _ in range(self.training_epochs):
self.trainer(self.compressor.bound_model, self.optimizer, self.criterion) self.trainer(self.compressor.bound_model, self.optimizer, self.criterion)
data = {} data = {}
[data.update(buffer_dict) for _, buffer_dict in self._hook_buffer.items()] target_name = 'weight'
for _, buffer_dict in self._hook_buffer.items():
for module_name, target_data in buffer_dict.items():
data[module_name] = {target_name: target_data}
return data
class TargetDataCollector(DataCollector):
"""
Collect all wrapper targets.
"""
def reset(self):
# No need to reset anything in this data collector.
pass
def collect(self) -> Dict[str, Dict[str, Tensor]]:
data = {}
target_name = 'weight'
for module_name, wrapper in self.compressor.get_modules_wrapper().items():
target: Tensor = getattr(wrapper, target_name)
data[module_name] = {target_name: target.data.clone()}
return data
class EvaluatorBasedTargetDataCollector(EvaluatorBasedDataCollector):
"""
Collect all wrapper pruning targets after training or inference.
"""
def collect(self) -> Dict[str, Dict[str, Tensor]]:
assert self.compressor.bound_model is not None
self.evaluator.train(max_steps=self.max_steps, max_epochs=self.max_epochs)
data = {}
target_name = 'weight'
for module_name, wrapper in self.compressor.get_modules_wrapper().items():
target: Tensor = getattr(wrapper, target_name)
data[module_name] = {target_name: target.data.clone()}
return data
class EvaluatorBasedHookDataCollector(EvaluatorBasedDataCollector):
"""
Add hooks and collect data during training or inference.
NOTE: Currently each target supports only one hook.
"""
def collect(self) -> Dict[str, Dict[str, List]]:
assert self.compressor.bound_model is not None
self.evaluator.train(max_steps=self.max_steps, max_epochs=self.max_epochs)
data = {}
for module_name, hooks in self._hooks.items():
data[module_name] = {}
for target_name, hook in hooks.items():
data[module_name][target_name] = hook.buffer
return data return data
...@@ -11,7 +11,7 @@ from torch import Tensor ...@@ -11,7 +11,7 @@ from torch import Tensor
from .base import MetricsCalculator from .base import MetricsCalculator
from ...utils import Scaling from ...utils import Scaling
__all__ = ['NormMetricsCalculator', 'MultiDataNormMetricsCalculator', 'DistMetricsCalculator', __all__ = ['NormMetricsCalculator', 'HookDataNormMetricsCalculator', 'DistMetricsCalculator',
'APoZRankMetricsCalculator', 'MeanRankMetricsCalculator', 'StraightMetricsCalculator'] 'APoZRankMetricsCalculator', 'MeanRankMetricsCalculator', 'StraightMetricsCalculator']
...@@ -19,11 +19,12 @@ class StraightMetricsCalculator(MetricsCalculator): ...@@ -19,11 +19,12 @@ class StraightMetricsCalculator(MetricsCalculator):
""" """
This metrics calculator directly returns a copy of data as metrics. This metrics calculator directly returns a copy of data as metrics.
""" """
def calculate_metrics(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]: def calculate_metrics(self, data: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
metrics = {} metrics = {}
for name, tensor in data.items(): for module_name, targets_data in data.items():
# use inplace detach `detach_` here to avoid creating a new tensor metrics[module_name] = {}
metrics[name] = tensor.clone().detach_() for target_name, target_data in targets_data.items():
metrics[module_name][target_name] = target_data.clone().detach()
return metrics return metrics
...@@ -44,27 +45,32 @@ class NormMetricsCalculator(MetricsCalculator): ...@@ -44,27 +45,32 @@ class NormMetricsCalculator(MetricsCalculator):
super().__init__(scalers=scalers) super().__init__(scalers=scalers)
self.p = p if p is not None else 'fro' self.p = p if p is not None else 'fro'
def calculate_metrics(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]: def calculate_metrics(self, data: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
def reduce_func(t: Tensor) -> Tensor: def reduce_func(t: Tensor) -> Tensor:
return t.norm(p=self.p, dim=-1) # type: ignore return t.norm(p=self.p, dim=-1) # type: ignore
metrics = {} metrics = {}
target_name = 'weight' for module_name, targets_data in data.items():
for module_name, target_data in data.items(): metrics[module_name] = {}
scaler = self._get_scaler(module_name, target_name) for target_name, target_data in targets_data.items():
metrics[module_name] = scaler.shrink(target_data, reduce_func) scaler = self._get_scaler(module_name, target_name)
metrics[module_name][target_name] = scaler.shrink(target_data, reduce_func)
return metrics return metrics
class MultiDataNormMetricsCalculator(NormMetricsCalculator): class HookDataNormMetricsCalculator(NormMetricsCalculator):
""" """
The data value format is a two-element list [batch_number, cumulative_data]. The hook data value format is a two-element list [batch_number, cumulative_data].
Directly use the cumulative_data as new_data to calculate norm metric. Directly use the cumulative_data as new_data to calculate norm metric.
TaylorFO pruner uses this to calculate metric. TaylorFO pruner uses this to calculate metric.
""" """
def calculate_metrics(self, data: Dict[str, List[Tensor]]) -> Dict[str, Tensor]: def calculate_metrics(self, data: Dict[str, Dict[str, List[Tensor]]]) -> Dict[str, Dict[str, Tensor]]:
new_data = {name: buffer[1] for name, buffer in data.items()} new_data = {}
for module_name, targets_data in data.items():
new_data[module_name] = {}
for target_name, (_, target_data) in targets_data.items():
new_data[module_name][target_name] = target_data
return super().calculate_metrics(new_data) return super().calculate_metrics(new_data)
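A hedged sketch of how a hook could fill the two-element ``[batch_number, cumulative_data]`` buffer that ``HookDataNormMetricsCalculator`` consumes; the accumulation below only mimics a TaylorFO-style score and is not the pruner's actual hook:

import torch
from torch import Tensor

def taylor_style_collector(buffer, weight: Tensor):
    # buffer becomes [batch_number, cumulative_data] as batches are processed
    buffer.extend([0, torch.zeros_like(weight)])
    def grad_hook(grad: Tensor):
        buffer[0] += 1
        buffer[1] += (weight.detach() * grad.detach()) ** 2
    return grad_hook

w = torch.randn(4, 3, requires_grad=True)
hook_buffer = []
handle = w.register_hook(taylor_style_collector(hook_buffer, w))
(w.sum() ** 2).backward()
handle.remove()
print(hook_buffer[0], hook_buffer[1].shape)   # 1 torch.Size([4, 3])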
...@@ -85,7 +91,7 @@ class DistMetricsCalculator(MetricsCalculator): ...@@ -85,7 +91,7 @@ class DistMetricsCalculator(MetricsCalculator):
super().__init__(scalers=scalers) super().__init__(scalers=scalers)
self.p = p if p is not None else 'fro' self.p = p if p is not None else 'fro'
def calculate_metrics(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]: def calculate_metrics(self, data: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
def reduce_func(t: Tensor) -> Tensor: def reduce_func(t: Tensor) -> Tensor:
reshape_data = t.reshape(-1, t.shape[-1]) reshape_data = t.reshape(-1, t.shape[-1])
metric = torch.zeros(reshape_data.shape[0], device=reshape_data.device) metric = torch.zeros(reshape_data.shape[0], device=reshape_data.device)
...@@ -94,10 +100,11 @@ class DistMetricsCalculator(MetricsCalculator): ...@@ -94,10 +100,11 @@ class DistMetricsCalculator(MetricsCalculator):
return metric.reshape(t.shape[:-1]) return metric.reshape(t.shape[:-1])
metrics = {} metrics = {}
target_name = 'weight' for module_name, targets_data in data.items():
for module_name, target_data in data.items(): metrics[module_name] = {}
scaler = self._get_scaler(module_name, target_name) for target_name, target_data in targets_data.items():
metrics[module_name] = scaler.shrink(target_data, reduce_func) scaler = self._get_scaler(module_name, target_name)
metrics[module_name][target_name] = scaler.shrink(target_data, reduce_func)
return metrics return metrics
...@@ -108,16 +115,18 @@ class APoZRankMetricsCalculator(MetricsCalculator): ...@@ -108,16 +115,18 @@ class APoZRankMetricsCalculator(MetricsCalculator):
Note that the metric we return is (1 - apoz), because we assume a higher metric value has higher importance. Note that the metric we return is (1 - apoz), because we assume a higher metric value has higher importance.
APoZRank pruner uses this to calculate metric. APoZRank pruner uses this to calculate metric.
""" """
def calculate_metrics(self, data: Dict[str, List]) -> Dict[str, Tensor]:
def calculate_metrics(self, data: Dict[str, Dict[str, List[Tensor]]]) -> Dict[str, Dict[str, Tensor]]:
def reduce_func(t: Tensor) -> Tensor: def reduce_func(t: Tensor) -> Tensor:
return 1 - t.mean(dim=-1) return 1 - t.mean(dim=-1)
metrics = {} metrics = {}
target_name = 'weight' for module_name, targets_data in data.items():
for module_name, target_data in data.items(): metrics[module_name] = {}
target_data = target_data[1] / target_data[0] for target_name, target_data in targets_data.items():
scaler = self._get_scaler(module_name, target_name) target_data = target_data[1] / target_data[0]
metrics[module_name] = scaler.shrink(target_data, reduce_func) scaler = self._get_scaler(module_name, target_name)
metrics[module_name][target_name] = scaler.shrink(target_data, reduce_func)
return metrics return metrics
...@@ -127,14 +136,15 @@ class MeanRankMetricsCalculator(MetricsCalculator): ...@@ -127,14 +136,15 @@ class MeanRankMetricsCalculator(MetricsCalculator):
This metric simply calculates the average on `self.dim`, then divides by the batch_number. This metric simply calculates the average on `self.dim`, then divides by the batch_number.
MeanRank pruner uses this to calculate metric. MeanRank pruner uses this to calculate metric.
""" """
def calculate_metrics(self, data: Dict[str, List]) -> Dict[str, Tensor]: def calculate_metrics(self, data: Dict[str, Dict[str, List[Tensor]]]) -> Dict[str, Dict[str, Tensor]]:
def reduce_func(t: Tensor) -> Tensor: def reduce_func(t: Tensor) -> Tensor:
return t.mean(dim=-1) return t.mean(dim=-1)
metrics = {} metrics = {}
target_name = 'weight' for module_name, targets_data in data.items():
for module_name, target_data in data.items(): metrics[module_name] = {}
target_data = target_data[1] / target_data[0] for target_name, target_data in targets_data.items():
scaler = self._get_scaler(module_name, target_name) target_data = target_data[1] / target_data[0]
metrics[module_name] = scaler.shrink(target_data, reduce_func) scaler = self._get_scaler(module_name, target_name)
metrics[module_name][target_name] = scaler.shrink(target_data, reduce_func)
return metrics return metrics
...@@ -17,7 +17,8 @@ _logger = logging.getLogger(__name__) ...@@ -17,7 +17,8 @@ _logger = logging.getLogger(__name__)
class AMCEnv: class AMCEnv:
def __init__(self, model: Module, config_list: List[Dict], dummy_input: Tensor, total_sparsity: float, max_sparsity_per_layer: Dict[str, float], target: str = 'flops'): def __init__(self, model: Module, config_list: List[Dict], dummy_input: Tensor, total_sparsity: float,
max_sparsity_per_layer: Dict[str, float], target: str = 'flops'):
pruning_op_names = [] pruning_op_names = []
[pruning_op_names.extend(config['op_names']) for config in config_list_canonical(model, config_list)] [pruning_op_names.extend(config['op_names']) for config in config_list_canonical(model, config_list)]
self.pruning_ops = OrderedDict() self.pruning_ops = OrderedDict()
...@@ -26,7 +27,10 @@ class AMCEnv: ...@@ -26,7 +27,10 @@ class AMCEnv:
if name in pruning_op_names: if name in pruning_op_names:
op_type = type(layer).__name__ op_type = type(layer).__name__
stride = np.power(np.prod(layer.stride), 1 / len(layer.stride)) if hasattr(layer, 'stride') else 0 # type: ignore stride = np.power(np.prod(layer.stride), 1 / len(layer.stride)) if hasattr(layer, 'stride') else 0 # type: ignore
kernel_size = np.power(np.prod(layer.kernel_size), 1 / len(layer.kernel_size)) if hasattr(layer, 'kernel_size') else 1 # type: ignore if hasattr(layer, 'kernel_size'):
kernel_size = np.power(np.prod(layer.kernel_size), 1 / len(layer.kernel_size)) # type: ignore
else:
kernel_size = 1
self.pruning_ops[name] = (i, op_type, stride, kernel_size) self.pruning_ops[name] = (i, op_type, stride, kernel_size)
self.pruning_types.append(op_type) self.pruning_types.append(op_type)
self.pruning_types = list(set(self.pruning_types)) self.pruning_types = list(set(self.pruning_types))
...@@ -60,15 +64,18 @@ class AMCEnv: ...@@ -60,15 +64,18 @@ class AMCEnv:
total_current_target = sum([current_statistics[name][self.target] for name in self.pruning_op_names]) total_current_target = sum([current_statistics[name][self.target] for name in self.pruning_op_names])
previous_pruning_target = self.under_pruning_target - total_current_target previous_pruning_target = self.under_pruning_target - total_current_target
max_rest_pruning_target = sum([current_statistics[name][self.target] * self.max_sparsity_per_layer[name] for name in self.pruning_op_names[index + 1:]]) max_rest_pruning_target = sum([current_statistics[name][self.target] * self.max_sparsity_per_layer[name] \
for name in self.pruning_op_names[index + 1:]])
min_current_pruning_target = self.excepted_pruning_target - previous_pruning_target - max_rest_pruning_target min_current_pruning_target = self.excepted_pruning_target - previous_pruning_target - max_rest_pruning_target
max_current_pruning_target_1 = self.origin_statistics[op_name][self.target] * self.max_sparsity_per_layer[op_name] - (self.origin_statistics[op_name][self.target] - current_statistics[op_name][self.target]) max_current_pruning_target_1 = self.origin_statistics[op_name][self.target] * self.max_sparsity_per_layer[op_name] - \
(self.origin_statistics[op_name][self.target] - current_statistics[op_name][self.target])
max_current_pruning_target_2 = self.excepted_pruning_target - previous_pruning_target max_current_pruning_target_2 = self.excepted_pruning_target - previous_pruning_target
max_current_pruning_target = min(max_current_pruning_target_1, max_current_pruning_target_2) max_current_pruning_target = min(max_current_pruning_target_1, max_current_pruning_target_2)
min_action = min_current_pruning_target / current_statistics[op_name][self.target] min_action = min_current_pruning_target / current_statistics[op_name][self.target]
max_action = max_current_pruning_target / current_statistics[op_name][self.target] max_action = max_current_pruning_target / current_statistics[op_name][self.target]
if min_action > self.max_sparsity_per_layer[op_name]: if min_action > self.max_sparsity_per_layer[op_name]:
_logger.warning('[%s] min action > max sparsity per layer: %f > %f', op_name, min_action, self.max_sparsity_per_layer[op_name]) warn_msg = f'[{op_name}] min action > max sparsity per layer: {min_action} > {self.max_sparsity_per_layer[op_name]}'
_logger.warning(warn_msg)
action = max(0., min(max_action, max(min_action, action))) action = max(0., min(max_action, max(min_action, action)))
self.current_op_name = op_name self.current_op_name = op_name
......
...@@ -51,7 +51,7 @@ class FunctionBasedTaskGenerator(TaskGenerator): ...@@ -51,7 +51,7 @@ class FunctionBasedTaskGenerator(TaskGenerator):
self.total_iteration = total_iteration self.total_iteration = total_iteration
self.skip_first_iteration = skip_first_iteration self.skip_first_iteration = skip_first_iteration
super().__init__(origin_model, origin_config_list=origin_config_list, origin_masks=origin_masks, super().__init__(origin_model, origin_config_list=origin_config_list, origin_masks=origin_masks,
log_dir=log_dir, keep_intermediate_result=keep_intermediate_result) log_dir=log_dir, keep_intermediate_result=keep_intermediate_result, best_result_mode='latest')
def reset(self, model: Module, config_list: List[Dict] = [], masks: Dict[str, Dict[str, Tensor]] = {}): def reset(self, model: Module, config_list: List[Dict] = [], masks: Dict[str, Dict[str, Tensor]] = {}):
self.current_iteration = 1 if self.skip_first_iteration else 0 self.current_iteration = 1 if self.skip_first_iteration else 0
...@@ -78,10 +78,14 @@ class FunctionBasedTaskGenerator(TaskGenerator): ...@@ -78,10 +78,14 @@ class FunctionBasedTaskGenerator(TaskGenerator):
# get current2origin_sparsity and compact2origin_sparsity # get current2origin_sparsity and compact2origin_sparsity
origin_model = torch.load(self._origin_model_path) origin_model = torch.load(self._origin_model_path)
current2origin_sparsity, compact2origin_sparsity, _ = compute_sparsity(origin_model, compact_model, compact_model_masks, self.target_sparsity) current2origin_sparsity, compact2origin_sparsity, _ = compute_sparsity(origin_model, compact_model, compact_model_masks,
_logger.debug('\nTask %s total real sparsity compared with original model is:\n%s', str(task_result.task_id), json_tricks.dumps(current2origin_sparsity, indent=4)) self.target_sparsity)
debug_msg = f'\nTask {task_result.task_id} total real sparsity compared with original model is:\n' + \
f'{json_tricks.dumps(current2origin_sparsity, indent=4)}'
_logger.debug(debug_msg)
if task_result.task_id != 'origin': if task_result.task_id != 'origin':
self._tasks[task_result.task_id].state['current2origin_sparsity'] = current2origin_sparsity task = self._tasks[task_result.task_id]
task.state['current2origin_sparsity'] = current2origin_sparsity
# if reach the total_iteration, no more task will be generated # if reach the total_iteration, no more task will be generated
if self.current_iteration > self.total_iteration: if self.current_iteration > self.total_iteration:
...@@ -116,7 +120,8 @@ class AGPTaskGenerator(FunctionBasedTaskGenerator): ...@@ -116,7 +120,8 @@ class AGPTaskGenerator(FunctionBasedTaskGenerator):
for target, mo in zip(target_sparsity, compact2origin_sparsity): for target, mo in zip(target_sparsity, compact2origin_sparsity):
ori_sparsity = (1 - (1 - iteration / self.total_iteration) ** 3) * target['total_sparsity'] ori_sparsity = (1 - (1 - iteration / self.total_iteration) ** 3) * target['total_sparsity']
sparsity = max(0.0, (ori_sparsity - mo['total_sparsity']) / (1 - mo['total_sparsity'])) sparsity = max(0.0, (ori_sparsity - mo['total_sparsity']) / (1 - mo['total_sparsity']))
assert 0 <= sparsity <= 1, 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity']) err_msg = 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
assert 0 <= sparsity <= 1, err_msg
config_list.append(deepcopy(target)) config_list.append(deepcopy(target))
config_list[-1]['total_sparsity'] = sparsity config_list[-1]['total_sparsity'] = sparsity
return config_list return config_list
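To illustrate the AGP sparsity schedule computed in the loop above, a standalone calculation of the per-iteration sparsity after rescaling by the sparsity the compact model already carries (all numbers are arbitrary):

def agp_step_sparsity(iteration, total_iteration, target_total_sparsity, model_sparsity):
    # mirrors the computation above: cubic ramp towards the target, then rescale
    # by the sparsity already present in the compact model
    ori_sparsity = (1 - (1 - iteration / total_iteration) ** 3) * target_total_sparsity
    return max(0.0, (ori_sparsity - model_sparsity) / (1 - model_sparsity))

# e.g. 10 iterations towards 90% sparsity; the model is already 30% sparse before iteration 5
print(round(agp_step_sparsity(5, 10, 0.9, 0.3), 4))   # 0.6964: prune ~70% of the remaining weights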
...@@ -128,7 +133,8 @@ class LinearTaskGenerator(FunctionBasedTaskGenerator): ...@@ -128,7 +133,8 @@ class LinearTaskGenerator(FunctionBasedTaskGenerator):
for target, mo in zip(target_sparsity, compact2origin_sparsity): for target, mo in zip(target_sparsity, compact2origin_sparsity):
ori_sparsity = iteration / self.total_iteration * target['total_sparsity'] ori_sparsity = iteration / self.total_iteration * target['total_sparsity']
sparsity = max(0.0, (ori_sparsity - mo['total_sparsity']) / (1 - mo['total_sparsity'])) sparsity = max(0.0, (ori_sparsity - mo['total_sparsity']) / (1 - mo['total_sparsity']))
assert 0 <= sparsity <= 1, 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity']) err_msg = 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
assert 0 <= sparsity <= 1, err_msg
config_list.append(deepcopy(target)) config_list.append(deepcopy(target))
config_list[-1]['total_sparsity'] = sparsity config_list[-1]['total_sparsity'] = sparsity
return config_list return config_list
...@@ -149,16 +155,18 @@ class LotteryTicketTaskGenerator(FunctionBasedTaskGenerator): ...@@ -149,16 +155,18 @@ class LotteryTicketTaskGenerator(FunctionBasedTaskGenerator):
# The following is the formula in paper. # The following is the formula in paper.
# ori_sparsity = (target['total_sparsity'] * 100) ** (iteration / self.total_iteration) / 100 # ori_sparsity = (target['total_sparsity'] * 100) ** (iteration / self.total_iteration) / 100
sparsity = max(0.0, (ori_sparsity - mo['total_sparsity']) / (1 - mo['total_sparsity'])) sparsity = max(0.0, (ori_sparsity - mo['total_sparsity']) / (1 - mo['total_sparsity']))
assert 0 <= sparsity <= 1, 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity']) err_msg = 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
assert 0 <= sparsity <= 1, err_msg
config_list.append(deepcopy(target)) config_list.append(deepcopy(target))
config_list[-1]['total_sparsity'] = sparsity config_list[-1]['total_sparsity'] = sparsity
return config_list return config_list
class SimulatedAnnealingTaskGenerator(TaskGenerator): class SimulatedAnnealingTaskGenerator(TaskGenerator):
def __init__(self, origin_model: Optional[Module], origin_config_list: Optional[List[Dict]], origin_masks: Dict[str, Dict[str, Tensor]] = {}, def __init__(self, origin_model: Optional[Module], origin_config_list: Optional[List[Dict]],
start_temperature: float = 100, stop_temperature: float = 20, cool_down_rate: float = 0.9, origin_masks: Dict[str, Dict[str, Tensor]] = {}, start_temperature: float = 100, stop_temperature: float = 20,
perturbation_magnitude: float = 0.35, log_dir: Union[str, Path] = '.', keep_intermediate_result: bool = False): cool_down_rate: float = 0.9, perturbation_magnitude: float = 0.35, log_dir: Union[str, Path] = '.',
keep_intermediate_result: bool = False):
""" """
Parameters Parameters
---------- ----------
...@@ -188,7 +196,7 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator): ...@@ -188,7 +196,7 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
self.perturbation_magnitude = perturbation_magnitude self.perturbation_magnitude = perturbation_magnitude
super().__init__(origin_model, origin_masks=origin_masks, origin_config_list=origin_config_list, super().__init__(origin_model, origin_masks=origin_masks, origin_config_list=origin_config_list,
log_dir=log_dir, keep_intermediate_result=keep_intermediate_result) log_dir=log_dir, keep_intermediate_result=keep_intermediate_result, best_result_mode='maximize')
def reset(self, model: Module, config_list: List[Dict] = [], masks: Dict[str, Dict[str, Tensor]] = {}): def reset(self, model: Module, config_list: List[Dict] = [], masks: Dict[str, Dict[str, Tensor]] = {}):
self.current_temperature = self.start_temperature self.current_temperature = self.start_temperature
...@@ -196,7 +204,10 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator): ...@@ -196,7 +204,10 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
# TODO: replace with validation here # TODO: replace with validation here
for config in config_list: for config in config_list:
if 'sparsity' in config or 'sparsity_per_layer' in config: if 'sparsity' in config or 'sparsity_per_layer' in config:
_logger.warning('Only `total_sparsity` can be differentially allocated sparse ratio to each layer, `sparsity` or `sparsity_per_layer` will allocate fixed sparse ratio to layers. Make sure you know what this will lead to, otherwise please use `total_sparsity`.') warn_msg = 'Only `total_sparsity` can be differentially allocated sparse ratio to each layer, ' + \
'`sparsity` or `sparsity_per_layer` will allocate fixed sparse ratio to layers. ' + \
'Make sure you know what this will lead to, otherwise please use `total_sparsity`.'
_logger.warning(warn_msg)
self.weights_numel, self.masked_rate = get_model_weights_numel(model, config_list, masks) self.weights_numel, self.masked_rate = get_model_weights_numel(model, config_list, masks)
self.target_sparsity_list = config_list_canonical(model, config_list) self.target_sparsity_list = config_list_canonical(model, config_list)
...@@ -259,11 +270,11 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator): ...@@ -259,11 +270,11 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
num_weights = sorted([self.weights_numel[op_name] for op_name in op_names]) num_weights = sorted([self.weights_numel[op_name] for op_name in op_names])
sparsity = sorted(random_sparsity) sparsity = sorted(random_sparsity)
# calculate the scale # calculate the scale
total_weights = np.sum(num_weights) total_weights = np.sum(num_weights)
total_weights_pruned = np.sum([int(num_weight * sparsity[idx]) for idx, num_weight in enumerate(num_weights)]) total_weights_pruned = np.sum([int(num_weight * sparsity[idx]) for idx, num_weight in enumerate(num_weights)])
if total_weights_pruned == 0: if total_weights_pruned == 0:
return None return None
......