Unverified commit ed455174 authored by J-shang, committed by GitHub

[Compression] evaluator - step2 (#4992)

parent a689e619
@@ -11,7 +11,7 @@ from torch.nn import Module
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
-from nni.common.serializer import Traceable, is_traceable
+from nni.common.serializer import is_traceable

 __all__ = ['OptimizerConstructHelper', 'LRSchedulerConstructHelper']

@@ -86,7 +86,8 @@ class OptimizerConstructHelper(ConstructHelper):
             'Please use nni.trace to wrap the optimizer class before initializing the optimizer.'
         assert isinstance(optimizer_trace, Optimizer), \
             'It is not an instance of torch.optim.Optimizer.'
-        return OptimizerConstructHelper(model, optimizer_trace.trace_symbol, *optimizer_trace.trace_args, **optimizer_trace.trace_kwargs)  # type: ignore
+        return OptimizerConstructHelper(model, optimizer_trace.trace_symbol, *optimizer_trace.trace_args,  # type: ignore
+                                        **optimizer_trace.trace_kwargs)  # type: ignore

 class LRSchedulerConstructHelper(ConstructHelper):

@@ -115,4 +116,5 @@ class LRSchedulerConstructHelper(ConstructHelper):
             'Please use nni.trace to wrap the lr scheduler class before initializing the scheduler.'
         assert isinstance(lr_scheduler_trace, _LRScheduler), \
             'It is not an instance of torch.optim.lr_scheduler._LRScheduler.'
-        return LRSchedulerConstructHelper(lr_scheduler_trace.trace_symbol, *lr_scheduler_trace.trace_args, **lr_scheduler_trace.trace_kwargs)  # type: ignore
+        return LRSchedulerConstructHelper(lr_scheduler_trace.trace_symbol, *lr_scheduler_trace.trace_args,  # type: ignore
+                                          **lr_scheduler_trace.trace_kwargs)  # type: ignore
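Both ``from_trace`` class methods above only accept instances whose classes were wrapped with ``nni.trace``; otherwise the ``is_traceable`` assertions fail. A minimal sketch of the expected call pattern (``model`` is assumed to exist)::

    import nni
    import torch
    from torch.optim.lr_scheduler import ExponentialLR

    # wrapping with nni.trace records trace_symbol / trace_args / trace_kwargs
    traced_optimizer = nni.trace(torch.optim.SGD)(model.parameters(), lr=0.01)
    traced_scheduler = nni.trace(ExponentialLR)(traced_optimizer, 0.1)

    opt_helper = OptimizerConstructHelper.from_trace(model, traced_optimizer)
    lrs_helper = LRSchedulerConstructHelper.from_trace(traced_scheduler)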
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
_EVALUATOR_DOCSTRING = r"""NNI will use the evaluator to intervene in the model training process,
so as to perform training-aware model compression.
All training-aware model compression will use the evaluator as the entry for intervention training in the future.
Usually you just need to wrap some classes with ``nni.trace`` or package the training process as a function to initialize the evaluator.
Please refer ... for a full tutorial on how to initialize a ``evaluator``.
The following are two simple examples, if you use pytorch_lightning, please refer to :class:`nni.compression.pytorch.LightningEvaluator`,
if you use native pytorch, please refer to :class:`nni.compression.pytorch.TorchEvaluator`::
# LightningEvaluator example
import pytorch_lightning
lightning_trainer = nni.trace(pytorch_lightning.Trainer)(max_epochs=1, max_steps=50, logger=TensorBoardLogger(...))
lightning_data_module = nni.trace(pytorch_lightning.LightningDataModule)(...)
from nni.compression.pytorch import LightningEvaluator
evaluator = LightningEvaluator(lightning_trainer, lightning_data_module)
# TorchEvaluator example
import torch
import torch.nn.functional as F
def training_model(model, optimizer, criterion, lr_scheduler, max_steps, max_epochs, *args, **kwargs):
    # max_steps and max_epochs might be None, which means unlimited training time,
    # so we need to set a default termination condition here (by default, total_epochs=10, total_steps=100000).
    total_epochs = max_epochs if max_epochs else 10
    total_steps = max_steps if max_steps else 100000
    current_step = 0

    # init dataloader
    train_dataloader = ...

    for epoch in range(total_epochs):
        ...
        for input_data, target in train_dataloader:
            optimizer.zero_grad()
            result = model(input_data)
            loss = criterion(result, target)
            loss.backward()
            optimizer.step()
            current_step += 1
            if current_step >= total_steps:
                return
        lr_scheduler.step()
traced_optimizer = nni.trace(torch.optim.SGD)(model.parameters(), lr=0.01)
criterion = F.nll_loss
from nni.compression.pytorch import TorchEvaluator
evaluator = TorchEvaluator(training_func=training_model, optimizers=traced_optimizer, criterion=criterion)
"""
@@ -73,16 +73,16 @@ class TensorHook(Hook):
         return hook
     """

-    def __init__(self, target: Tensor, target_name: str, hook_factory: Callable[[List], Callable[[Tensor], Any]]):
+    def __init__(self, target: Tensor, target_name: str, hook_factory: Callable[[List], Callable[[Tensor], Tensor | None]]):
         assert isinstance(target, Tensor)
         super().__init__(target, target_name, hook_factory)

-    def _register(self, hook_func: Callable[[Tensor], Any]) -> RemovableHandle:
+    def _register(self, hook_func: Callable[[Tensor], Tensor | None]) -> RemovableHandle:
         return self.target.register_hook(hook_func)  # type: ignore

 class ModuleHook(Hook):
-    def __init__(self, target: Module, target_name: str, hook_factory: Callable[[List], Callable[[Module, Tensor, Tensor], Any]]):
+    def __init__(self, target: Module, target_name: str, hook_factory: Callable[[List], Callable[[Module, Any, Any], Any]]):
         assert isinstance(target, Module)
         super().__init__(target, target_name, hook_factory)

@@ -97,7 +97,7 @@ class ForwardHook(ModuleHook):
         return hook
     """

-    def _register(self, hook_func: Callable[[Module, Tensor, Tensor], Any]):
+    def _register(self, hook_func: Callable[[Module, Tuple[Any], Any], Any]):
         return self.target.register_forward_hook(hook_func)  # type: ignore

@@ -111,7 +111,7 @@ class BackwardHook(ModuleHook):
         return hook
     """

-    def _register(self, hook_func: Callable[[Module, Tensor, Tensor], Any]):
+    def _register(self, hook_func: Callable[[Module, Tuple[Tensor] | Tensor, Tuple[Tensor] | Tensor], Any]):
         return self.target.register_backward_hook(hook_func)  # type: ignore
@@ -148,7 +148,8 @@ class Evaluator:
     def bind_model(self, model: Module | pl.LightningModule, param_names_map: Dict[str, str] | None = None):
         """
-        Bind the model suitable for this ``Evaluator`` to use the evaluator's abilities of model modification, model training, and model evaluation.
+        Bind the model suitable for this ``Evaluator`` to use the evaluator's abilities of model modification,
+        model training, and model evaluation.

         Parameters
         ----------

@@ -246,10 +247,12 @@ class Evaluator:
     def evaluate(self) -> float | None | Tuple[float, Any] | Tuple[None, Any]:
         """
         NNI assumes the evaluation function the user passed in returns a float number or a dict as the metric.
-        If the evaluation function returned a dict, take the value with dict key ``default`` as the first element of ``evaluate`` returned value,
+        If the evaluation function returned a dict, take the value with dict key ``default``
+        as the first element of the ``evaluate`` returned value,
         and put the dict as the second element of the returned value.
         For any other type of metric returned by the evaluation function, ``evaluate`` will return it directly
-        (it should be a float, but NNI does not prevent other types from being returned, this will handle by the object calling ``evaluate``).
+        (it should be a float, but NNI does not prevent other types from being returned;
+        this will be handled by the object calling ``evaluate``).
         """
         # Note that the first item of the returned value will be used as the default metric used by NNI.
         raise NotImplementedError
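Concretely, the dict contract described above reduces to the following (a hedged sketch of what a subclass implementation does; the metric values are assumed)::

    metric = {'default': 0.93, 'loss': 0.21}       # returned by the user's evaluating function
    nni_used_metric = metric.get('default', None)  # first element of the returned tuple
    result = (nni_used_metric, metric)             # -> (0.93, {'default': 0.93, 'loss': 0.21})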
@@ -287,9 +290,11 @@ class LightningEvaluator(Evaluator):
     def __init__(self, trainer: pl.Trainer, data_module: pl.LightningDataModule,
                  dummy_input: Any | None = None):
-        err_msg = 'Only support traced {}, please use nni.trace({}) to initialize the trainer.'
-        assert isinstance(trainer, pl.Trainer) and is_traceable(trainer), err_msg.format('pytorch_lightning.Trainer', 'pytorch_lightning.Trainer')
-        assert isinstance(data_module, pl.LightningDataModule) and is_traceable(data_module), err_msg.format('pytorch_lightning.LightningDataModule', 'pytorch_lightning.LightningDataModule')
+        err_msg_p = 'Only support traced {}, please use nni.trace({}) to initialize the trainer.'
+        err_msg = err_msg_p.format('pytorch_lightning.Trainer', 'pytorch_lightning.Trainer')
+        assert isinstance(trainer, pl.Trainer) and is_traceable(trainer), err_msg
+        err_msg = err_msg_p.format('pytorch_lightning.LightningDataModule', 'pytorch_lightning.LightningDataModule')
+        assert isinstance(data_module, pl.LightningDataModule) and is_traceable(data_module), err_msg
         self.trainer = trainer
         self.data_module = data_module
         self._dummy_input = dummy_input
@@ -314,18 +319,20 @@ class LightningEvaluator(Evaluator):
         optimizers_lr_schedulers: Any = pure_model.configure_optimizers()
         # 1. None - Fit will run without any optimizer.
         if optimizers_lr_schedulers is None:
-            err_msg = 'NNI does not support `LightningModule.configure_optimizers` returned None, '
-            err_msg += 'if you have a reason why you must, please file an issue at https://github.com/microsoft/nni/issues'
+            err_msg = 'NNI does not support `LightningModule.configure_optimizers` returning None; ' + \
+                      'if you have a reason why you must, please file an issue at https://github.com/microsoft/nni/issues'
             raise ValueError(err_msg)
         # 2. Single optimizer.
-        # 3. Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.
+        # 3. Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose
+        #    value is a single LR scheduler or lr_scheduler_config.
         elif isinstance(optimizers_lr_schedulers, (Optimizer, dict)):
             optimizers_lr_schedulers = [optimizers_lr_schedulers]

         err_msg = f'Got a wrong returned value type of `LightningModule.configure_optimizers`: {type(optimizers_lr_schedulers).__name__}'
         assert isinstance(optimizers_lr_schedulers, (list, tuple)), err_msg

-        # 4. Two lists - the first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).
+        # 4. Two lists - the first list has multiple optimizers,
+        #    and the second has multiple LR schedulers (or multiple lr_scheduler_config).
         if isinstance(optimizers_lr_schedulers[0], (list, tuple)):
             optimizers, lr_schedulers = optimizers_lr_schedulers
             self._optimizer_helpers = [OptimizerConstructHelper.from_trace(pure_model, optimizer) for optimizer in optimizers]
@@ -364,7 +371,8 @@ class LightningEvaluator(Evaluator):
         self._initialization_complete = True

     def bind_model(self, model: pl.LightningModule, param_names_map: Dict[str, str] | None = None):
-        assert self._initialization_complete is True, 'Evaluator initialization is not complete, please call `_init_optimizer_helpers` before bind model.'
+        err_msg = 'Evaluator initialization is not complete, please call `_init_optimizer_helpers` before binding a model.'
+        assert self._initialization_complete is True, err_msg
         assert isinstance(model, pl.LightningModule)
         if self.model is not None:
             _logger.warning('Already bound a model, will unbind it before binding a new model.')
@@ -397,7 +405,8 @@ class LightningEvaluator(Evaluator):
         if self._opt_returned_dicts:
             def new_configure_optimizers(_):  # type: ignore
                 optimizers = [opt_helper.call(self.model, self._param_names_map) for opt_helper in self._optimizer_helpers]  # type: ignore
-                lr_schedulers = [lrs_helper.call(optimizers[self._lrs_opt_map[i]]) for i, lrs_helper in enumerate(self._lr_scheduler_helpers)]
+                lr_schedulers = [lrs_helper.call(optimizers[self._lrs_opt_map[i]])
+                                 for i, lrs_helper in enumerate(self._lr_scheduler_helpers)]
                 opt_lrs_dicts = deepcopy(self._opt_returned_dicts)
                 for opt_lrs_dict in opt_lrs_dicts:
                     opt_lrs_dict['optimizer'] = optimizers[opt_lrs_dict['optimizer']]
@@ -407,7 +416,8 @@ class LightningEvaluator(Evaluator):
         elif self._lr_scheduler_helpers:
             def new_configure_optimizers(_):  # type: ignore
                 optimizers = [opt_helper.call(self.model, self._param_names_map) for opt_helper in self._optimizer_helpers]  # type: ignore
-                lr_schedulers = [lrs_helper.call(optimizers[self._lrs_opt_map[i]]) for i, lrs_helper in enumerate(self._lr_scheduler_helpers)]
+                lr_schedulers = [lrs_helper.call(optimizers[self._lrs_opt_map[i]])
+                                 for i, lrs_helper in enumerate(self._lr_scheduler_helpers)]
                 return optimizers, lr_schedulers
         else:
             def new_configure_optimizers(_):
@@ -442,7 +452,8 @@ class LightningEvaluator(Evaluator):
         assert isinstance(self.model, pl.LightningModule)

         class OptimizerCallback(Callback):
-            def on_before_optimizer_step(self, trainer: pl.Trainer, pl_module: pl.LightningModule, optimizer: Optimizer, opt_idx: int) -> None:
+            def on_before_optimizer_step(self, trainer: pl.Trainer, pl_module: pl.LightningModule,
+                                         optimizer: Optimizer, opt_idx: int) -> None:
                 for task in before_step_tasks:
                     task()
@@ -486,10 +497,12 @@ class LightningEvaluator(Evaluator):
     def evaluate(self) -> Tuple[float | None, List[Dict[str, float]]]:
         """
-        NNI will use metric with key ``default`` for evaluating model, please make sure you have this key in your ``Trainer.test()`` returned metric dicts.
-        If ``Trainer.test()`` returned list contains multiple dicts with key ``default``, NNI will take their average as the final metric.
-        E.g., if ``Trainer.test()`` returned ``[{'default': 0.8, 'loss': 2.3}, {'default': 0.6, 'loss': 2.4}, {'default': 0.7, 'loss': 2.3}]``,
-        NNI will take the final metric ``(0.8 + 0.6 + 0.7) / 3 = 0.7``.
+        NNI will use the metric with key ``default`` for evaluating the model,
+        please make sure you have this key in the metric dicts returned by ``Trainer.test()``.
+        If the list returned by ``Trainer.test()`` contains multiple dicts with key ``default``,
+        NNI will take their average as the final metric.
+        E.g., if ``Trainer.test()`` returned ``[{'default': 0.8, 'loss': 2.3}, {'default': 0.6, 'loss': 2.4}]``,
+        NNI will take the final metric ``(0.8 + 0.6) / 2 = 0.7``.
         """
         assert isinstance(self.model, pl.LightningModule)
         # reset trainer

@@ -514,9 +527,11 @@ class LightningEvaluator(Evaluator):
             raise e
+_OPTIMIZERS = Union[Optimizer, List[Optimizer]]
 _CRITERION = Callable[[Any, Any], Any]
+_SCHEDULERS = Union[None, _LRScheduler, List[_LRScheduler]]
 _EVALUATING_FUNC = Callable[[Module], Union[float, Dict]]
-_TRAINING_FUNC = Callable[[Module, Union[Optimizer, List[Optimizer]], _CRITERION, Union[None, _LRScheduler, List[_LRScheduler]], Optional[int], Optional[int]], None]
+_TRAINING_FUNC = Callable[[Module, _OPTIMIZERS, _CRITERION, _SCHEDULERS, Optional[int], Optional[int]], None]
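For reference, a function conforming to the ``_TRAINING_FUNC`` alias would look like this sketch (names assumed, body elided)::

    def my_training_func(model: Module, optimizers: _OPTIMIZERS, criterion: _CRITERION,
                         schedulers: _SCHEDULERS = None, max_steps: Optional[int] = None,
                         max_epochs: Optional[int] = None) -> None:
        # run the entire optimization training loop here
        ...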
 class TorchEvaluator(Evaluator):

@@ -528,8 +543,10 @@ class TorchEvaluator(Evaluator):
     ----------
     training_func
         The training function is used to train the model; note that this is an entire optimization training loop.
-        It should have three required parameters [model, optimizers, criterion] and three optional parameters [schedulers, max_steps, max_epochs].
-        ``optimizers`` can be an instance of ``torch.optim.Optimizer`` or a list of ``torch.optim.Optimizer``, it belongs to the ``optimizers`` pass to ``TorchEvaluator``.
+        It should have three required parameters [model, optimizers, criterion]
+        and three optional parameters [schedulers, max_steps, max_epochs].
+        ``optimizers`` can be an instance of ``torch.optim.Optimizer`` or a list of ``torch.optim.Optimizer``;
+        it belongs to the ``optimizers`` passed to ``TorchEvaluator``.
         ``criterion`` and ``schedulers`` also belong to the ``criterion`` and ``schedulers`` passed to ``TorchEvaluator``.
         ``max_steps`` and ``max_epochs`` are used to control the training duration.
@@ -574,7 +591,8 @@ class TorchEvaluator(Evaluator):
         Optional. The traced _LRScheduler instance, whose lr scheduler class is wrapped by nni.trace.
         E.g. ``traced_lr_scheduler = nni.trace(ExponentialLR)(optimizer, 0.1)``.
     dummy_input
-        Optional. The dummy_input is used to trace the graph, the same with ``example_inputs`` in ``torch.jit.trace(func, example_inputs, ...)``.
+        Optional. The dummy_input is used to trace the graph,
+        the same as ``example_inputs`` in ``torch.jit.trace(func, example_inputs, ...)``.
     evaluating_func
         Optional. A function that takes the model as input and returns the evaluation metric.
         The return value can be a single float or a tuple (float, Any).
@@ -634,14 +652,16 @@ class TorchEvaluator(Evaluator):
         self._lr_scheduler_helpers = [LRSchedulerConstructHelper.from_trace(lr_scheduler) for lr_scheduler in self._tmp_lr_schedulers]
         optimizer_ids_map = {id(optimizer): i for i, optimizer in enumerate(self._tmp_optimizers)}
         # record i-th lr_scheduler scheduling j-th optimizer lr
-        self._lrs_opt_map = {i: optimizer_ids_map[id(lr_scheduler.optimizer)] for i, lr_scheduler in enumerate(self._tmp_lr_schedulers)}  # type: ignore
+        self._lrs_opt_map = {i: optimizer_ids_map[id(lr_scheduler.optimizer)]  # type: ignore
+                             for i, lr_scheduler in enumerate(self._tmp_lr_schedulers)}  # type: ignore

         delattr(self, '_tmp_optimizers')
         delattr(self, '_tmp_lr_schedulers')
         self._initialization_complete = True
     def bind_model(self, model: Module, param_names_map: Dict[str, str] | None = None):
-        assert self._initialization_complete is True, 'Evaluator initialization is not complete, please call `_init_optimizer_helpers` before bind model.'
+        err_msg = 'Evaluator initialization is not complete, please call `_init_optimizer_helpers` before binding a model.'
+        assert self._initialization_complete is True, err_msg
         assert isinstance(model, Module)
         if self.model is not None:
             _logger.warning('Already bound a model, will unbind it before binding a new model.')
@@ -651,7 +671,8 @@ class TorchEvaluator(Evaluator):
         self._param_names_map = param_names_map
         # initialize optimizers & lr_schedulers for the bound model here
         self._optimizers = [helper.call(model, param_names_map) for helper in self._optimizer_helpers]
-        self._lr_schedulers = [lrs_helper.call(self._optimizers[self._lrs_opt_map[i]]) for i, lrs_helper in enumerate(self._lr_scheduler_helpers)]
+        self._lr_schedulers = [lrs_helper.call(self._optimizers[self._lrs_opt_map[i]])
+                               for i, lrs_helper in enumerate(self._lr_scheduler_helpers)]
         self._first_optimizer_step = self._optimizers[0].step
     def unbind_model(self):

@@ -717,7 +738,8 @@ class TorchEvaluator(Evaluator):
         if isinstance(metric, dict):
             nni_used_metric = metric.get('default', None)
             if nni_used_metric is None:
-                warn_msg = f'Evaluation function returns a dict metric without key `default`, will return None as the model evaluation metric value.'
+                warn_msg = 'Evaluation function returns a dict metric without key `default`, ' + \
+                           'will return None as the model evaluation metric value.'
                 _logger.warning(warn_msg)
             return nni_used_metric, metric
         else:
...
@@ -229,7 +229,8 @@ def compute_sparsity(origin_model: Module, compact_model: Module, compact_model_
     return current2origin_sparsity, compact2origin_sparsity, mask2compact_sparsity

-def get_model_weights_numel(model: Module, config_list: List[Dict], masks: Dict[str, Dict[str, Tensor]] = {}) -> Tuple[Dict[str, int], Dict[str, float]]:
+def get_model_weights_numel(model: Module, config_list: List[Dict],
+                            masks: Dict[str, Dict[str, Tensor]] = {}) -> Tuple[Dict[str, int], Dict[str, float]]:
     """
     Count the number of weight elements of the layers in config_list.
     If masks is not empty, the masked weights will not be counted.
...
@@ -53,18 +53,24 @@ class Scaling:
        # for the `-1` in kernel_size, then expand size (4, 3, 1) to size (4, 6, 2).
    kernel_padding_mode
        'front' or 'back', default is 'front'.
-        If set 'front', for a given tensor when shrinking, padding `1` at front of kernel_size until `len(tensor.shape) == len(kernel_size)`;
-        for a given expand size when expanding, padding `1` at front of kernel_size until `len(expand_size) == len(kernel_size)`.
-        If set 'back', for a given tensor when shrinking, padding `-1` at back of kernel_size until `len(tensor.shape) == len(kernel_size)`;
-        for a given expand size when expanding, padding `-1` at back of kernel_size until `len(expand_size) == len(kernel_size)`.
+        If set to 'front', for a given tensor when shrinking,
+        pad `1` at the front of kernel_size until `len(tensor.shape) == len(kernel_size)`;
+        for a given expand size when expanding,
+        pad `1` at the front of kernel_size until `len(expand_size) == len(kernel_size)`.
+        If set to 'back', for a given tensor when shrinking,
+        pad `-1` at the back of kernel_size until `len(tensor.shape) == len(kernel_size)`;
+        for a given expand size when expanding,
+        pad `-1` at the back of kernel_size until `len(expand_size) == len(kernel_size)`.
    """
    def __init__(self, kernel_size: List[int], kernel_padding_mode: Literal['front', 'back'] = 'front') -> None:
        self.kernel_size = kernel_size
-        assert kernel_padding_mode in ['front', 'back'], f"kernel_padding_mode should be one of ['front', 'back'], but get kernel_padding_mode={kernel_padding_mode}."
+        err_msg = f"kernel_padding_mode should be one of ['front', 'back'], but got kernel_padding_mode={kernel_padding_mode}."
+        assert kernel_padding_mode in ['front', 'back'], err_msg
        self.kernel_padding_mode = kernel_padding_mode

-    def _padding(self, _list: List[int], length: int, padding_value: int = -1, padding_mode: Literal['front', 'back'] = 'back') -> List[int]:
+    def _padding(self, _list: List[int], length: int, padding_value: int = -1,
+                 padding_mode: Literal['front', 'back'] = 'back') -> List[int]:
        """
        Padding the `_list` to a specific length with `padding_value`.
@@ -144,10 +150,12 @@ class Scaling:
            assert b % a == 0, f'Can not expand tensor with {target.shape} to {expand_size} with kernel size {kernel_size}.'
            _expand_size.append(b // a)
            _expand_size.append(a)
-        new_target: Tensor = reduce(lambda t, dim: t.unsqueeze(dim), [new_target] + [2 * _ + 1 for _ in range(len(expand_size))])  # type: ignore
+        new_target: Tensor = reduce(lambda t, dim: t.unsqueeze(dim),
+                                    [new_target] + [2 * _ + 1 for _ in range(len(expand_size))])  # type: ignore
        # step 3: expanding the new target to _expand_size and reshape to expand_size.
-        # Note that we can also give an interface for how to expand the tensor, like `reduce_func` in `_shrink`, currently we don't have that need.
+        # Note that we can also give an interface for how to expand the tensor, like `reduce_func` in `_shrink`,
+        # currently we don't have that need.
        result = new_target.expand(_expand_size).reshape(expand_size).clone()
        return result
...
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+from nni.algorithms.compression.v2.pytorch import TorchEvaluator, LightningEvaluator
 from .speedup import ModelSpeedup
 from .compressor import Compressor, Pruner, Quantizer
 from .utils.apply_compression import apply_compression_results
@@ -3,6 +3,7 @@
     "nni/algorithms/compression/pytorch",
     "nni/algorithms/compression/tensorflow",
     "nni/algorithms/compression/v2/pytorch/base/pruner.py",
+    "nni/algorithms/compression/v2/pytorch/pruning/amc_pruner.py",
     "nni/algorithms/feature_engineering",
     "nni/algorithms/hpo",
     "nni/algorithms/nas",
...
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from pathlib import Path
from typing import Any, Dict, List
import torch
from .device import device
from .simple_mnist import SimpleLightningModel, SimpleTorchModel
from .utils import unfold_config_list
log_dir = Path(__file__).parent.parent / 'logs'
def create_model(model_type: str):
torch_config_list = [{'op_types': ['Linear'], 'sparsity': 0.5},
{'op_names': ['conv1', 'conv2', 'conv3'], 'sparsity': 0.5},
{'op_names': ['fc2'], 'exclude': True}]
lightning_config_list = [{'op_types': ['Linear'], 'sparsity': 0.5},
{'op_names': ['model.conv1', 'model.conv2', 'model.conv3'], 'sparsity': 0.5},
{'op_names': ['model.fc2'], 'exclude': True}]
if model_type == 'lightning':
model = SimpleLightningModel()
config_list = lightning_config_list
dummy_input = torch.rand(8, 1, 28, 28)
elif model_type == 'pytorch':
model = SimpleTorchModel().to(device)
config_list = torch_config_list
dummy_input = torch.rand(8, 1, 28, 28, device=device)
else:
raise ValueError(f'wrong model_type: {model_type}')
return model, config_list, dummy_input
def validate_masks(masks: Dict[str, Dict[str, torch.Tensor]], model: torch.nn.Module, config_list: List[Dict[str, Any]],
is_global: bool = False):
config_dict = unfold_config_list(model, config_list)
# validate if all configured layers have generated mask.
mismatched_op_names = set(config_dict.keys()).symmetric_difference(masks.keys())
    assert not mismatched_op_names, f'mismatched op_names: {mismatched_op_names}'
target_name = 'weight'
total_masked_numel = 0
total_target_numel = 0
for module_name, target_masks in masks.items():
mask = target_masks[target_name]
assert mask.numel() == (mask == 0).sum().item() + (mask == 1).sum().item(), f'{module_name} {target_name} mask has values other than 0 and 1.'
if not is_global:
            expected_sparsity = config_dict[module_name].get('sparsity', config_dict[module_name].get('total_sparsity'))
            real_sparsity = (mask == 0).sum().item() / mask.numel()
            err_msg = f'{module_name} {target_name} expected sparsity: {expected_sparsity}, but real sparsity: {real_sparsity}'
            assert expected_sparsity * 0.9 < real_sparsity < expected_sparsity * 1.1, err_msg
else:
total_masked_numel += (mask == 0).sum().item()
total_target_numel += mask.numel()
if is_global:
        # all configs share one sparsity setting in global mode, take it from the first config
        global_config = next(iter(config_dict.values()))
        expected_sparsity = global_config.get('sparsity', global_config.get('total_sparsity'))
        real_sparsity = total_masked_numel / total_target_numel
        err_msg = f'expected global sparsity: {expected_sparsity}, but real global sparsity: {real_sparsity}.'
        assert expected_sparsity * 0.9 < real_sparsity < expected_sparsity * 1.1, err_msg
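Typical usage, mirroring the tests later in this diff (``pruner`` is assumed to be an initialized pruner)::

    _, masks = pruner.compress()
    pruner._unwrap_model()
    validate_masks(masks, model, config_list)                  # per-layer sparsity check
    validate_masks(masks, model, config_list, is_global=True)  # pooled check for global mode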
def validate_dependency_aware(model_type: str, masks: Dict[str, Dict[str, torch.Tensor]]):
# only for simple_mnist model
if model_type == 'lightning':
assert torch.equal(masks['model.conv2']['weight'].mean([1, 2, 3]), masks['model.conv3']['weight'].mean([1, 2, 3]))
if model_type == 'pytorch':
assert torch.equal(masks['conv2']['weight'].mean([1, 2, 3]), masks['conv3']['weight'].mean([1, 2, 3]))
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .simple_lightning_model import SimpleLightningModel, MNISTDataModule
from .simple_torch_model import SimpleTorchModel, training_model, evaluating_model, finetuning_model
from .simple_evaluator import create_lighting_evaluator, create_pytorch_evaluator
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from pathlib import Path
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import ExponentialLR
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
import nni
from nni.algorithms.compression.v2.pytorch import LightningEvaluator, TorchEvaluator
from .simple_torch_model import training_model, evaluating_model
from .simple_lightning_model import MNISTDataModule
from ..common import device
def create_lighting_evaluator() -> LightningEvaluator:
pl_trainer = nni.trace(pl.Trainer)(
accelerator='auto',
devices=1,
max_epochs=1,
max_steps=50,
logger=TensorBoardLogger(Path(__file__).parent.parent / 'lightning_logs', name="resnet"),
)
pl_trainer.num_sanity_val_steps = 0
pl_data = nni.trace(MNISTDataModule)(data_dir='data/mnist')
evaluator = LightningEvaluator(pl_trainer, pl_data, dummy_input=torch.rand(8, 1, 28, 28))
return evaluator
def create_pytorch_evaluator(model: torch.nn.Module) -> TorchEvaluator:
optimizer = nni.trace(torch.optim.SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
lr_scheduler = nni.trace(ExponentialLR)(optimizer, 0.1)
evaluator = TorchEvaluator(training_model, optimizer, F.nll_loss, lr_scheduler,
dummy_input=torch.rand(8, 1, 28, 28, device=device), evaluating_func=evaluating_model)
return evaluator
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import random_split, DataLoader
from torchmetrics.functional import accuracy
from torchvision.datasets import MNIST
from torchvision import transforms
import nni
from .simple_torch_model import SimpleTorchModel
class SimpleLightningModel(pl.LightningModule):
def __init__(self):
super().__init__()
self.model = SimpleTorchModel()
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = F.nll_loss(logits, y)
self.log("train_loss", loss)
return loss
def evaluate(self, batch, stage=None):
x, y = batch
logits = self(x)
loss = F.nll_loss(logits, y)
preds = torch.argmax(logits, dim=1)
acc = accuracy(preds, y)
if stage:
self.log(f"default", loss, prog_bar=False)
self.log(f"{stage}_loss", loss, prog_bar=True)
self.log(f"{stage}_acc", acc, prog_bar=True)
def validation_step(self, batch, batch_idx):
self.evaluate(batch, "val")
def test_step(self, batch, batch_idx):
self.evaluate(batch, "test")
def configure_optimizers(self):
optimizer = nni.trace(torch.optim.SGD)(
self.parameters(),
lr=0.01,
momentum=0.9,
weight_decay=5e-4,
)
scheduler_dict = {
"scheduler": nni.trace(ExponentialLR)(
optimizer,
0.1,
),
"interval": "epoch",
}
return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}
class MNISTDataModule(pl.LightningDataModule):
def __init__(self, data_dir: str = "./"):
super().__init__()
        self.data_dir = data_dir
self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
def prepare_data(self):
# download
MNIST(self.data_dir, train=True, download=True)
MNIST(self.data_dir, train=False, download=True)
def setup(self, stage: str | None = None):
# Assign train/val datasets for use in dataloaders
if stage == "fit" or stage is None:
mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
# Assign test dataset for use in dataloader(s)
if stage == "test" or stage is None:
self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
if stage == "predict" or stage is None:
self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)
def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=32)
def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=32)
def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=32)
def predict_dataloader(self):
return DataLoader(self.mnist_predict, batch_size=32)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from pathlib import Path
from typing import Callable
import torch
from torch.nn import Module
import torch.nn.functional as F
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from ..device import device
class SimpleTorchModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 16, 3)
self.bn1 = torch.nn.BatchNorm2d(16)
self.conv2 = torch.nn.Conv2d(16, 8, 3, groups=4)
self.bn2 = torch.nn.BatchNorm2d(8)
self.conv3 = torch.nn.Conv2d(16, 8, 3)
self.bn3 = torch.nn.BatchNorm2d(8)
self.fc1 = torch.nn.Linear(8 * 24 * 24, 100)
self.fc2 = torch.nn.Linear(100, 10)
def forward(self, x: torch.Tensor):
x = self.bn1(self.conv1(x))
x = self.bn2(self.conv2(x)) + self.bn3(self.conv3(x))
x = self.fc2(self.fc1(x.reshape(x.shape[0], -1)))
return F.log_softmax(x, -1)
def training_model(model: Module, optimizer: Optimizer, criterion: Callable, scheduler: _LRScheduler | None = None,
                   max_steps: int | None = None, max_epochs: int | None = None, device: torch.device = device):
model.train()
# prepare data
MNIST(root='data/mnist', train=True, download=True)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
mnist_train = MNIST(root='data/mnist', train=True, transform=transform)
train_dataloader = DataLoader(mnist_train, batch_size=32)
max_epochs = max_epochs if max_epochs else 1
max_steps = max_steps if max_steps else 50
current_steps = 0
# training
for _ in range(max_epochs):
for x, y in train_dataloader:
optimizer.zero_grad()
x, y = x.to(device), y.to(device)
logits = model(x)
loss: torch.Tensor = criterion(logits, y)
loss.backward()
optimizer.step()
current_steps += 1
if max_steps and current_steps == max_steps:
return
if scheduler is not None:
scheduler.step()
def finetuning_model(model: Module):
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
training_model(model, optimizer, F.nll_loss)
def evaluating_model(model: Module, device: torch.device = device):
model.eval()
# prepare data
MNIST(root='data/mnist', train=False, download=True)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
mnist_test = MNIST(root='data/mnist', train=False, transform=transform)
test_dataloader = DataLoader(mnist_test, batch_size=32)
# testing
correct = 0
with torch.no_grad():
for x, y in test_dataloader:
x, y = x.to(device), y.to(device)
logits = model(x)
preds = torch.argmax(logits, dim=1)
correct += preds.eq(y.view_as(preds)).sum().item()
return correct / len(mnist_test)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
######################################################################################
# NOTE: copy from branch wrapper-refactor, will rm this file in this or next release.#
######################################################################################
from copy import deepcopy
import logging
from typing import Any, Dict, List
from torch.nn import Module
_logger = logging.getLogger(__name__)
def _unfold_op_partial_names(model: Module, config_list: List[Dict]) -> List[Dict]:
config_list = deepcopy(config_list)
full_op_names = [op_name for op_name, _ in model.named_modules()]
for config in config_list:
op_names = config.pop('op_names', [])
op_partial_names = config.pop('op_partial_names', [])
for op_partial_name in op_partial_names:
op_names.extend([op_name for op_name in full_op_names if op_partial_name in op_name])
config['op_names'] = list(set(op_names))
return config_list
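For illustration, with modules named ``conv1``, ``conv2`` and ``fc1``, a config using ``op_partial_names`` unfolds as follows (assumed example; ``set()`` makes the order unspecified)::

    config_list = [{'op_partial_names': ['conv'], 'sparsity': 0.5}]
    unfolded = _unfold_op_partial_names(model, config_list)
    # -> [{'sparsity': 0.5, 'op_names': ['conv1', 'conv2']}]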
def unfold_config_list(model: Module, config_list: List[Dict]) -> Dict[str, Dict[str, Any]]:
'''
Unfold config_list to op_names level, return a config_dict {op_name: config}.
'''
config_list = _unfold_op_partial_names(model=model, config_list=config_list)
config_dict = {}
for config in config_list:
for key in ['op_types', 'op_names', 'exclude_op_names']:
config.setdefault(key, [])
op_names = []
for module_name, module in model.named_modules():
module_type = type(module).__name__
if (module_type in config['op_types'] or module_name in config['op_names']) and module_name not in config['exclude_op_names']:
op_names.append(module_name)
config_template = deepcopy(config)
for key in ['op_types', 'op_names', 'exclude_op_names']:
config_template.pop(key, [])
for op_name in op_names:
if op_name in config_dict:
warn_msg = f'{op_name} duplicate definition of config, replace old config:\n' + \
f'{config_dict[op_name]}\n' + \
f'with new config:\n{config_template}\n'
_logger.warning(warn_msg)
config_dict[op_name] = deepcopy(config_template)
return config_dict
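A hedged usage sketch against the ``SimpleTorchModel`` defined earlier in this diff::

    model = SimpleTorchModel()
    config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.5},
                   {'op_names': ['fc2'], 'sparsity': 0.75}]
    config_dict = unfold_config_list(model, config_list)
    # -> {'conv1': {'sparsity': 0.5}, 'conv2': {'sparsity': 0.5},
    #     'conv3': {'sparsity': 0.5}, 'fc2': {'sparsity': 0.75}}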
@@ -3,193 +3,23 @@

 from __future__ import annotations

-from pathlib import Path
-from typing import Callable

 import pytest

-import pytorch_lightning as pl
-from pytorch_lightning.loggers import TensorBoardLogger
 import torch
-from torch.nn import Module
-import torch.nn.functional as F
-from torch.optim import Optimizer
-from torch.optim.lr_scheduler import ExponentialLR, _LRScheduler
-from torch.utils.data import random_split, DataLoader
-from torchmetrics.functional import accuracy
-from torchvision.datasets import MNIST
-from torchvision import transforms

-import nni
 from nni.algorithms.compression.v2.pytorch.utils.evaluator import (
-    TorchEvaluator,
-    LightningEvaluator,
     TensorHook,
     ForwardHook,
     BackwardHook,
 )

+from ..assets.device import device
+from ..assets.simple_mnist import (
+    SimpleLightningModel,
+    SimpleTorchModel,
+    create_lighting_evaluator,
+    create_pytorch_evaluator
+)

-class SimpleTorchModel(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.conv1 = torch.nn.Conv2d(1, 16, 3)
-        self.bn1 = torch.nn.BatchNorm2d(16)
-        self.conv2 = torch.nn.Conv2d(16, 8, 3, groups=4)
-        self.bn2 = torch.nn.BatchNorm2d(8)
-        self.conv3 = torch.nn.Conv2d(16, 8, 3)
-        self.bn3 = torch.nn.BatchNorm2d(8)
-        self.fc1 = torch.nn.Linear(8 * 24 * 24, 100)
-        self.fc2 = torch.nn.Linear(100, 10)
-    def forward(self, x: torch.Tensor):
-        x = self.bn1(self.conv1(x))
-        x = self.bn2(self.conv2(x)) + self.bn3(self.conv3(x))
-        x = self.fc2(self.fc1(x.reshape(x.shape[0], -1)))
-        return F.log_softmax(x, -1)

-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

-def training_model(model: Module, optimizer: Optimizer, criterion: Callable, scheduler: _LRScheduler,
-                   max_steps: int | None = None, max_epochs: int | None = None):
-    model.train()

-    # prepare data
-    data_dir = Path(__file__).parent / 'data'
-    MNIST(data_dir, train=True, download=True)
-    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
-    mnist_train = MNIST(data_dir, train=True, transform=transform)
-    train_dataloader = DataLoader(mnist_train, batch_size=32)

-    max_epochs = max_epochs if max_epochs else 1
-    max_steps = max_steps if max_steps else 10
-    current_steps = 0

-    # training
-    for _ in range(max_epochs):
-        for x, y in train_dataloader:
-            optimizer.zero_grad()
-            x, y = x.to(device), y.to(device)
-            logits = model(x)
-            loss: torch.Tensor = criterion(logits, y)
-            loss.backward()
-            optimizer.step()
-            current_steps += 1
-            if max_steps and current_steps == max_steps:
-                return
-        scheduler.step()

-def evaluating_model(model: Module):
-    model.eval()

-    # prepare data
-    data_dir = Path(__file__).parent / 'data'
-    MNIST(data_dir, train=False, download=True)
-    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
-    mnist_test = MNIST(data_dir, train=False, transform=transform)
-    test_dataloader = DataLoader(mnist_test, batch_size=32)

-    # testing
-    correct = 0
-    with torch.no_grad():
-        for x, y in test_dataloader:
-            x, y = x.to(device), y.to(device)
-            logits = model(x)
-            preds = torch.argmax(logits, dim=1)
-            correct += preds.eq(y.view_as(preds)).sum().item()
-    return correct / len(mnist_test)

-class SimpleLightningModel(pl.LightningModule):
-    def __init__(self):
-        super().__init__()
-        self.model = SimpleTorchModel()
-        self.count = 0

-    def forward(self, x):
-        print(self.count)
-        self.count += 1
-        return self.model(x)

-    def training_step(self, batch, batch_idx):
-        x, y = batch
-        logits = self(x)
-        loss = F.nll_loss(logits, y)
-        self.log("train_loss", loss)
-        return loss

-    def evaluate(self, batch, stage=None):
-        x, y = batch
-        logits = self(x)
-        loss = F.nll_loss(logits, y)
-        preds = torch.argmax(logits, dim=1)
-        acc = accuracy(preds, y)

-        if stage:
-            self.log(f"{stage}_loss", loss, prog_bar=True)
-            self.log(f"{stage}_acc", acc, prog_bar=True)

-    def validation_step(self, batch, batch_idx):
-        self.evaluate(batch, "val")

-    def test_step(self, batch, batch_idx):
-        self.evaluate(batch, "test")

-    def configure_optimizers(self):
-        optimizer = nni.trace(torch.optim.SGD)(
-            self.parameters(),
-            lr=0.01,
-            momentum=0.9,
-            weight_decay=5e-4,
-        )
-        scheduler_dict = {
-            "scheduler": nni.trace(ExponentialLR)(
-                optimizer,
-                0.1,
-            ),
-            "interval": "epoch",
-        }
-        return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}

-class MNISTDataModule(pl.LightningDataModule):
-    def __init__(self, data_dir: str = "./"):
-        super().__init__()
-        self.data_dir = data_dir
-        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

-    def prepare_data(self):
-        # download
-        MNIST(self.data_dir, train=True, download=True)
-        MNIST(self.data_dir, train=False, download=True)

-    def setup(self, stage: str | None = None):
-        # Assign train/val datasets for use in dataloaders
-        if stage == "fit" or stage is None:
-            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
-            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

-        # Assign test dataset for use in dataloader(s)
-        if stage == "test" or stage is None:
-            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

-        if stage == "predict" or stage is None:
-            self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)

-    def train_dataloader(self):
-        return DataLoader(self.mnist_train, batch_size=32)

-    def val_dataloader(self):
-        return DataLoader(self.mnist_val, batch_size=32)

-    def test_dataloader(self):
-        return DataLoader(self.mnist_test, batch_size=32)

-    def predict_dataloader(self):
-        return DataLoader(self.mnist_predict, batch_size=32)
 optimizer_before_step_flag = False

@@ -237,41 +67,20 @@ def assert_flags():
     assert loss_flag, 'Evaluator patch loss failed.'
-def create_lighting_evaluator():
-    pl_model = SimpleLightningModel()
-    pl_trainer = nni.trace(pl.Trainer)(
-        max_epochs=1,
-        max_steps=10,
-        logger=TensorBoardLogger(Path(__file__).parent / 'lightning_logs', name="resnet"),
-    )
-    pl_trainer.num_sanity_val_steps = 0
-    pl_data = nni.trace(MNISTDataModule)(data_dir=Path(__file__).parent / 'data')
-    evaluator = LightningEvaluator(pl_trainer, pl_data)
-    evaluator._init_optimizer_helpers(pl_model)
-    return evaluator

-def create_pytorch_evaluator():
-    model = SimpleTorchModel()
-    optimizer = nni.trace(torch.optim.SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
-    lr_scheduler = nni.trace(ExponentialLR)(optimizer, 0.1)
-    evaluator = TorchEvaluator(training_model, optimizer, F.nll_loss, lr_scheduler, evaluating_func=evaluating_model)
-    evaluator._init_optimizer_helpers(model)
-    return evaluator
@pytest.mark.parametrize("evaluator_type", ['lightning', 'pytorch']) @pytest.mark.parametrize("evaluator_type", ['lightning', 'pytorch'])
def test_evaluator(evaluator_type: str): def test_evaluator(evaluator_type: str):
if evaluator_type == 'lightning': if evaluator_type == 'lightning':
evaluator = create_lighting_evaluator()
model = SimpleLightningModel() model = SimpleLightningModel()
evaluator = create_lighting_evaluator()
evaluator._init_optimizer_helpers(model)
evaluator.bind_model(model) evaluator.bind_model(model)
tensor_hook = TensorHook(model.model.conv1.weight, 'model.conv1.weight', tensor_hook_factory) tensor_hook = TensorHook(model.model.conv1.weight, 'model.conv1.weight', tensor_hook_factory)
forward_hook = ForwardHook(model.model.conv1, 'model.conv1', forward_hook_factory) forward_hook = ForwardHook(model.model.conv1, 'model.conv1', forward_hook_factory)
backward_hook = BackwardHook(model.model.conv1, 'model.conv1', backward_hook_factory) backward_hook = BackwardHook(model.model.conv1, 'model.conv1', backward_hook_factory)
elif evaluator_type == 'pytorch': elif evaluator_type == 'pytorch':
evaluator = create_pytorch_evaluator()
model = SimpleTorchModel().to(device) model = SimpleTorchModel().to(device)
evaluator = create_pytorch_evaluator(model)
evaluator._init_optimizer_helpers(model)
evaluator.bind_model(model) evaluator.bind_model(model)
tensor_hook = TensorHook(model.conv1.weight, 'conv1.weight', tensor_hook_factory) tensor_hook = TensorHook(model.conv1.weight, 'conv1.weight', tensor_hook_factory)
forward_hook = ForwardHook(model.conv1, 'conv1', forward_hook_factory) forward_hook = ForwardHook(model.conv1, 'conv1', forward_hook_factory)
...@@ -296,4 +105,4 @@ def test_evaluator(evaluator_type: str): ...@@ -296,4 +105,4 @@ def test_evaluator(evaluator_type: str):
evaluator.finetune() evaluator.finetune()
assert_flags() assert_flags()
assert all([len(hook.buffer) == 10 for hook in [tensor_hook, forward_hook, backward_hook]]) assert all([len(hook.buffer) == 50 for hook in [tensor_hook, forward_hook, backward_hook]])
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import pytest
import torch
import torch.nn.functional as F
import nni
from nni.compression.pytorch.pruning import (
LinearPruner,
AGPPruner,
LotteryTicketPruner,
SimulatedAnnealingPruner,
AutoCompressPruner
)
from ..assets.common import create_model, log_dir, validate_masks, validate_dependency_aware
from ..assets.device import device
from ..assets.simple_mnist import (
create_lighting_evaluator,
create_pytorch_evaluator,
training_model,
finetuning_model,
evaluating_model
)
@pytest.mark.parametrize('model_type', ['lightning', 'pytorch'])
@pytest.mark.parametrize('using_evaluator', [True, False])
@pytest.mark.parametrize('pruning_type', ['linear', 'agp', 'lottery'])
@pytest.mark.parametrize('speedup', [True, False])
def test_functional_pruner(model_type: str, using_evaluator: bool, pruning_type: str, speedup: bool):
model, config_list, dummy_input = create_model(model_type)
if using_evaluator:
evaluator = create_lighting_evaluator() if model_type == 'lightning' else create_pytorch_evaluator(model)
if pruning_type == 'linear':
pruner = LinearPruner(model=model, config_list=config_list, pruning_algorithm='l1', total_iteration=2,
log_dir=log_dir, keep_intermediate_result=False, evaluator=evaluator, speedup=speedup,
pruning_params={'mode': 'dependency_aware', 'dummy_input': dummy_input})
elif pruning_type == 'agp':
pruner = AGPPruner(model=model, config_list=config_list, pruning_algorithm='l1', total_iteration=2,
log_dir=log_dir, keep_intermediate_result=False, evaluator=evaluator, speedup=speedup,
pruning_params={'mode': 'dependency_aware', 'dummy_input': dummy_input})
        elif pruning_type == 'lottery':
pruner = LotteryTicketPruner(model=model, config_list=config_list, pruning_algorithm='l1', total_iteration=2,
log_dir=log_dir, keep_intermediate_result=False, evaluator=evaluator, speedup=speedup,
pruning_params={'mode': 'dependency_aware', 'dummy_input': dummy_input})
else:
model.to(device)
dummy_input = dummy_input.to(device)
if pruning_type == 'linear':
pruner = LinearPruner(model=model, config_list=config_list, pruning_algorithm='l1', total_iteration=2, log_dir=log_dir,
keep_intermediate_result=False, finetuner=finetuning_model, speedup=speedup, dummy_input=dummy_input,
evaluator=None, pruning_params={'mode': 'dependency_aware', 'dummy_input': dummy_input})
elif pruning_type == 'agp':
pruner = AGPPruner(model=model, config_list=config_list, pruning_algorithm='l1', total_iteration=2, log_dir=log_dir,
keep_intermediate_result=False, finetuner=finetuning_model, speedup=speedup, dummy_input=dummy_input,
evaluator=None, pruning_params={'mode': 'dependency_aware', 'dummy_input': dummy_input})
        elif pruning_type == 'lottery':
pruner = LotteryTicketPruner(model=model, config_list=config_list, pruning_algorithm='l1', total_iteration=2, log_dir=log_dir,
keep_intermediate_result=False, finetuner=finetuning_model, speedup=speedup, dummy_input=dummy_input,
evaluator=None, pruning_params={'mode': 'dependency_aware', 'dummy_input': dummy_input})
pruner.compress()
best_task_id, best_model, best_masks, best_score, best_config_list = pruner.get_best_result()
best_model(dummy_input)
validate_masks(best_masks, best_model, config_list)
@pytest.mark.parametrize('model_type', ['lightning', 'pytorch'])
@pytest.mark.parametrize('using_evaluator', [True, False])
def test_sa_pruner(model_type: str, using_evaluator: bool):
model, config_list, dummy_input = create_model(model_type)
if using_evaluator:
evaluator = create_lighting_evaluator() if model_type == 'lightning' else create_pytorch_evaluator(model)
pruner = SimulatedAnnealingPruner(model=model, config_list=config_list, evaluator=evaluator, start_temperature=100,
stop_temperature=80, cool_down_rate=0.9, perturbation_magnitude=0.35, pruning_algorithm='l1',
pruning_params={}, log_dir=log_dir, keep_intermediate_result=False, speedup=False)
else:
model.to(device)
dummy_input = dummy_input.to(device)
pruner = SimulatedAnnealingPruner(model=model, config_list=config_list, evaluator=evaluating_model, start_temperature=100,
stop_temperature=80, cool_down_rate=0.9, perturbation_magnitude=0.35, pruning_algorithm='l1',
pruning_params={}, log_dir=log_dir, keep_intermediate_result=False, speedup=False)
pruner.compress()
best_task_id, best_model, best_masks, best_score, best_config_list = pruner.get_best_result()
best_model(dummy_input)
validate_masks(best_masks, best_model, config_list)
@pytest.mark.parametrize('model_type', ['lightning', 'pytorch'])
@pytest.mark.parametrize('using_evaluator', [True, False])
def test_auto_compress_pruner(model_type: str, using_evaluator: bool):
model, config_list, dummy_input = create_model(model_type)
if using_evaluator:
evaluator = create_lighting_evaluator() if model_type == 'lightning' else create_pytorch_evaluator(model)
admm_params = {'evaluator': evaluator, 'iterations': 2, 'training_epochs': 1, 'granularity': 'coarse-grained'}
sa_params = {'evaluator': evaluator, 'start_temperature': 100, 'stop_temperature': 80, 'pruning_algorithm': 'l1'}
pruner = AutoCompressPruner(model=model, config_list=config_list, total_iteration=2, admm_params=admm_params, sa_params=sa_params,
log_dir=log_dir, keep_intermediate_result=False, evaluator=evaluator, speedup=False)
else:
model.to(device)
dummy_input = dummy_input.to(device)
optimizer = nni.trace(torch.optim.SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
admm_params = {'trainer': training_model, 'traced_optimizer': optimizer, 'criterion': F.nll_loss, 'iterations': 2,
'training_epochs': 1, 'granularity': 'coarse-grained'}
sa_params = {'evaluator': evaluating_model, 'start_temperature': 100, 'stop_temperature': 80, 'pruning_algorithm': 'l1'}
pruner = AutoCompressPruner(model=model, config_list=config_list, total_iteration=2, admm_params=admm_params, sa_params=sa_params,
log_dir=log_dir, keep_intermediate_result=False, finetuner=finetuning_model, speedup=False,
dummy_input=dummy_input, evaluator=evaluating_model)
pruner.compress()
best_task_id, best_model, best_masks, best_score, best_config_list = pruner.get_best_result()
best_model(dummy_input)
validate_masks(best_masks, best_model, config_list)
# We still need an AMCPruner test, but it costs a lot to run; we will add it once we have a GPU pool.
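# A commented-out sketch of how that test could look, mirroring the iterative pruner
# tests above. This is only an illustration: AMCPruner's exact argument names
# (e.g. `total_episode`, `ddpg_params`) are assumptions and must be checked against
# the real API before enabling the test.
#
# @pytest.mark.parametrize('model_type', ['lightning', 'pytorch'])
# def test_amc_pruner(model_type: str):
#     model, config_list, dummy_input = create_model(model_type)
#     evaluator = create_lighting_evaluator() if model_type == 'lightning' else create_pytorch_evaluator(model)
#     pruner = AMCPruner(total_episode=5, model=model, config_list=config_list,  # hypothetical arguments
#                        dummy_input=dummy_input, evaluator=evaluator, log_dir=log_dir,
#                        keep_intermediate_result=False)
#     pruner.compress()
#     _, best_model, best_masks, _, _ = pruner.get_best_result()
#     validate_masks(best_masks, best_model, config_list)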
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import pytest
import torch
import torch.nn.functional as F
import nni
from nni.compression.pytorch.pruning import (
LevelPruner,
L1NormPruner,
L2NormPruner,
SlimPruner,
FPGMPruner,
ActivationAPoZRankPruner,
ActivationMeanRankPruner,
TaylorFOWeightPruner,
ADMMPruner,
MovementPruner
)
from ..assets.device import device
from ..assets.simple_mnist import (
create_lighting_evaluator,
create_pytorch_evaluator,
training_model
)
from ..assets.common import create_model, validate_masks, validate_dependency_aware


@pytest.mark.parametrize('model_type', ['lightning', 'pytorch'])
def test_level_pruner(model_type: str):
model, config_list, dummy_input = create_model(model_type)
pruner = LevelPruner(model=model, config_list=config_list)
_, masks = pruner.compress()
model(dummy_input)
pruner._unwrap_model()
validate_masks(masks, model, config_list)


@pytest.mark.parametrize('model_type', ['lightning', 'pytorch'])
@pytest.mark.parametrize('pruning_type', ['l1', 'l2', 'fpgm'])
@pytest.mark.parametrize('mode', ['normal', 'dependency_aware'])
def test_norm_pruner(model_type: str, pruning_type: str, mode: str):
model, config_list, dummy_input = create_model(model_type)
if pruning_type == 'l1':
pruner = L1NormPruner(model=model, config_list=config_list, mode=mode, dummy_input=dummy_input)
elif pruning_type == 'l2':
pruner = L2NormPruner(model=model, config_list=config_list, mode=mode, dummy_input=dummy_input)
elif pruning_type == 'fpgm':
pruner = FPGMPruner(model=model, config_list=config_list, mode=mode, dummy_input=dummy_input)
else:
raise ValueError(f'wrong norm: {pruning_type}')
_, masks = pruner.compress()
model(dummy_input)
pruner._unwrap_model()
validate_masks(masks, model, config_list)
if mode == 'dependency_aware':
validate_dependency_aware(model_type, masks)


@pytest.mark.parametrize('model_type', ['lightning', 'pytorch'])
@pytest.mark.parametrize('using_evaluator', [True, False])
@pytest.mark.parametrize('mode', ['global', 'normal'])
def test_slim_pruner(model_type: str, using_evaluator: bool, mode: str):
model, _, dummy_input = create_model(model_type)
config_list = [{'op_types': ['BatchNorm2d'], 'total_sparsity': 0.5}]
if using_evaluator:
evaluator = create_lighting_evaluator() if model_type == 'lightning' else create_pytorch_evaluator(model)
pruner = SlimPruner(model=model, config_list=config_list, evaluator=evaluator, training_epochs=1, scale=0.0001, mode=mode)
else:
model = model.to(device)
dummy_input = dummy_input.to(device)
optimizer = nni.trace(torch.optim.SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
pruner = SlimPruner(model=model, config_list=config_list, trainer=training_model, traced_optimizer=optimizer,
criterion=F.nll_loss, training_epochs=1, scale=0.0001, mode=mode)
_, masks = pruner.compress()
model(dummy_input)
pruner._unwrap_model()
validate_masks(masks, model, config_list, is_global=(mode == 'global'))


@pytest.mark.parametrize('model_type', ['lightning', 'pytorch'])
@pytest.mark.parametrize('pruning_type', ['apoz', 'mean', 'taylor'])
@pytest.mark.parametrize('using_evaluator', [True, False])
@pytest.mark.parametrize('mode', ['normal', 'dependency_aware'])
def test_hook_based_pruner(model_type: str, pruning_type: str, using_evaluator: bool, mode: str):
model, config_list, dummy_input = create_model(model_type)
if using_evaluator:
evaluator = create_lighting_evaluator() if model_type == 'lightning' else create_pytorch_evaluator(model)
if pruning_type == 'apoz':
pruner = ActivationAPoZRankPruner(model=model, config_list=config_list, evaluator=evaluator, training_steps=20,
activation='relu', mode=mode, dummy_input=dummy_input)
elif pruning_type == 'mean':
pruner = ActivationMeanRankPruner(model=model, config_list=config_list, evaluator=evaluator, training_steps=20,
activation='relu', mode=mode, dummy_input=dummy_input)
elif pruning_type == 'taylor':
pruner = TaylorFOWeightPruner(model=model, config_list=config_list, evaluator=evaluator, training_steps=20,
mode=mode, dummy_input=dummy_input)
else:
model = model.to(device)
dummy_input = dummy_input.to(device)
optimizer = nni.trace(torch.optim.SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
if pruning_type == 'apoz':
pruner = ActivationAPoZRankPruner(model=model, config_list=config_list, trainer=training_model, traced_optimizer=optimizer,
criterion=F.nll_loss, training_batches=20, activation='relu', mode=mode, dummy_input=dummy_input)
elif pruning_type == 'mean':
pruner = ActivationMeanRankPruner(model=model, config_list=config_list, trainer=training_model, traced_optimizer=optimizer,
criterion=F.nll_loss, training_batches=20, activation='relu', mode=mode, dummy_input=dummy_input)
elif pruning_type == 'taylor':
pruner = TaylorFOWeightPruner(model=model, config_list=config_list, trainer=training_model, traced_optimizer=optimizer,
criterion=F.nll_loss, training_batches=20, mode=mode, dummy_input=dummy_input)
_, masks = pruner.compress()
model(dummy_input)
pruner._unwrap_model()
validate_masks(masks, model, config_list)
if mode == 'dependency_aware':
validate_dependency_aware(model_type, masks)


@pytest.mark.parametrize('model_type', ['lightning', 'pytorch'])
@pytest.mark.parametrize('using_evaluator', [True, False])
@pytest.mark.parametrize('granularity', ['fine-grained', 'coarse-grained'])
def test_admm_pruner(model_type: str, using_evaluator: bool, granularity: str):
model, config_list, dummy_input = create_model(model_type)
if using_evaluator:
evaluator = create_lighting_evaluator() if model_type == 'lightning' else create_pytorch_evaluator(model)
pruner = ADMMPruner(model=model, config_list=config_list, evaluator=evaluator, iterations=2, training_epochs=1, granularity=granularity)
else:
model = model.to(device)
dummy_input = dummy_input.to(device)
optimizer = nni.trace(torch.optim.SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
pruner = ADMMPruner(model=model, config_list=config_list, trainer=training_model, traced_optimizer=optimizer, criterion=F.nll_loss,
iterations=2, training_epochs=1, granularity=granularity)
_, masks = pruner.compress()
model(dummy_input)
pruner._unwrap_model()
validate_masks(masks, model, config_list)


@pytest.mark.parametrize('model_type', ['lightning', 'pytorch'])
@pytest.mark.parametrize('using_evaluator', [True, False])
def test_movement_pruner(model_type: str, using_evaluator: bool):
model, config_list, dummy_input = create_model(model_type)
if using_evaluator:
evaluator = create_lighting_evaluator() if model_type == 'lightning' else create_pytorch_evaluator(model)
pruner = MovementPruner(model=model, config_list=config_list, evaluator=evaluator, training_epochs=1, warm_up_step=10, cool_down_beginning_step=40)
else:
model = model.to(device)
dummy_input = dummy_input.to(device)
optimizer = nni.trace(torch.optim.SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
pruner = MovementPruner(model=model, config_list=config_list, trainer=training_model, traced_optimizer=optimizer, criterion=F.nll_loss,
training_epochs=1, warm_up_step=10, cool_down_beginning_step=40)
_, masks = pruner.compress()
model(dummy_input)
pruner._unwrap_model()
validate_masks(masks, model, config_list)
@@ -8,14 +8,20 @@ import torch.nn.functional as F
 import nni

 from nni.algorithms.compression.v2.pytorch.base import Pruner
+# TODO: remove in nni v3.0.
 from nni.algorithms.compression.v2.pytorch.pruning.tools import (
     WeightDataCollector,
     WeightTrainerBasedDataCollector,
     SingleHookTrainerBasedDataCollector
 )
+from nni.algorithms.compression.v2.pytorch.pruning.tools import (
+    TargetDataCollector,
+    EvaluatorBasedTargetDataCollector,
+    EvaluatorBasedHookDataCollector
+)
 from nni.algorithms.compression.v2.pytorch.pruning.tools import (
     NormMetricsCalculator,
-    MultiDataNormMetricsCalculator,
+    HookDataNormMetricsCalculator,
     DistMetricsCalculator,
     APoZRankMetricsCalculator,
     MeanRankMetricsCalculator
@@ -84,7 +90,7 @@ class PruningToolsTestCase(unittest.TestCase):
         # Test WeightDataCollector
         data_collector = WeightDataCollector(pruner)
         data = data_collector.collect()
-        assert all(torch.equal(get_module_by_name(model, module_name)[1].weight.data, data[module_name]) for module_name in ['conv1', 'conv2'])
+        assert all(torch.equal(get_module_by_name(model, module_name)[1].weight.data, data[module_name]['weight']) for module_name in ['conv1', 'conv2'])

         # Test WeightTrainerBasedDataCollector
         def opt_after():
@@ -94,8 +100,8 @@ class PruningToolsTestCase(unittest.TestCase):
         optimizer_helper = OptimizerConstructHelper.from_trace(model, get_optimizer(model))
         data_collector = WeightTrainerBasedDataCollector(pruner, trainer, optimizer_helper, criterion, 1, opt_after_tasks=[opt_after])
         data = data_collector.collect()
-        assert all(torch.equal(get_module_by_name(model, module_name)[1].weight.data, data[module_name]) for module_name in ['conv1', 'conv2'])
-        assert all(t.numel() == (t == 1).type_as(t).sum().item() for t in data.values())
+        assert all(torch.equal(get_module_by_name(model, module_name)[1].weight.data, data[module_name]['weight']) for module_name in ['conv1', 'conv2'])
+        assert all(t['weight'].numel() == (t['weight'] == 1).type_as(t['weight']).sum().item() for t in data.values())

         # Test SingleHookTrainerBasedDataCollector
         def _collector(buffer, weight_tensor):
@@ -109,73 +115,73 @@ class PruningToolsTestCase(unittest.TestCase):
         optimizer_helper = OptimizerConstructHelper.from_trace(model, get_optimizer(model))
         data_collector = SingleHookTrainerBasedDataCollector(pruner, trainer, optimizer_helper, criterion, 2, collector_infos=[collector_info])
         data = data_collector.collect()
-        assert all(len(t) == 2 for t in data.values())
+        assert all(len(t['weight']) == 2 for t in data.values())

     def test_metrics_calculator(self):
         # Test NormMetricsCalculator
         metrics_calculator = NormMetricsCalculator(p=2, scalers=Scaling(kernel_size=[1], kernel_padding_mode='back'))
         data = {
-            '1': torch.ones(3, 3, 3),
-            '2': torch.ones(4, 4) * 2
+            '1': {'target_name': torch.ones(3, 3, 3)},
+            '2': {'target_name': torch.ones(4, 4) * 2}
         }
         result = {
-            '1': torch.ones(3) * 3,
-            '2': torch.ones(4) * 4
+            '1': {'target_name': torch.ones(3) * 3},
+            '2': {'target_name': torch.ones(4) * 4}
         }
         metrics = metrics_calculator.calculate_metrics(data)
-        assert all(torch.equal(result[k], v) for k, v in metrics.items())
+        assert all(torch.equal(result[k]['target_name'], v['target_name']) for k, v in metrics.items())

         # Test DistMetricsCalculator
         metrics_calculator = DistMetricsCalculator(p=2, scalers=Scaling(kernel_size=[1], kernel_padding_mode='back'))
         data = {
-            '1': torch.tensor([[1, 2], [4, 6]], dtype=torch.float32),
-            '2': torch.tensor([[0, 0], [1, 1]], dtype=torch.float32)
+            '1': {'target_name': torch.tensor([[1, 2], [4, 6]], dtype=torch.float32)},
+            '2': {'target_name': torch.tensor([[0, 0], [1, 1]], dtype=torch.float32)}
         }
         result = {
-            '1': torch.tensor([5, 5], dtype=torch.float32),
-            '2': torch.sqrt(torch.tensor([2, 2], dtype=torch.float32))
+            '1': {'target_name': torch.tensor([5, 5], dtype=torch.float32)},
+            '2': {'target_name': torch.sqrt(torch.tensor([2, 2], dtype=torch.float32))}
         }
         metrics = metrics_calculator.calculate_metrics(data)
-        assert all(torch.equal(result[k], v) for k, v in metrics.items())
+        assert all(torch.equal(result[k]['target_name'], v['target_name']) for k, v in metrics.items())

-        # Test MultiDataNormMetricsCalculator
-        metrics_calculator = MultiDataNormMetricsCalculator(p=1, scalers=Scaling(kernel_size=[1], kernel_padding_mode='back'))
+        # Test HookDataNormMetricsCalculator
+        metrics_calculator = HookDataNormMetricsCalculator(p=1, scalers=Scaling(kernel_size=[1], kernel_padding_mode='back'))
         data = {
-            '1': [2, torch.ones(3, 3, 3) * 2],
-            '2': [2, torch.ones(4, 4) * 2]
+            '1': {'target_name': [2, torch.ones(3, 3, 3) * 2]},
+            '2': {'target_name': [2, torch.ones(4, 4) * 2]}
         }
         result = {
-            '1': torch.ones(3) * 18,
-            '2': torch.ones(4) * 8
+            '1': {'target_name': torch.ones(3) * 18},
+            '2': {'target_name': torch.ones(4) * 8}
         }
         metrics = metrics_calculator.calculate_metrics(data)
-        assert all(torch.equal(result[k], v) for k, v in metrics.items())
+        assert all(torch.equal(result[k]['target_name'], v['target_name']) for k, v in metrics.items())

         # Test APoZRankMetricsCalculator
         metrics_calculator = APoZRankMetricsCalculator(Scaling(kernel_size=[-1, 1], kernel_padding_mode='back'))
         data = {
-            '1': [2, torch.tensor([[1, 1], [1, 1]], dtype=torch.float32)],
-            '2': [2, torch.tensor([[0, 0, 1], [0, 0, 0]], dtype=torch.float32)]
+            '1': {'target_name': [2, torch.tensor([[1, 1], [1, 1]], dtype=torch.float32)]},
+            '2': {'target_name': [2, torch.tensor([[0, 0, 1], [0, 0, 0]], dtype=torch.float32)]}
         }
         result = {
-            '1': torch.tensor([0.5, 0.5], dtype=torch.float32),
-            '2': torch.tensor([1, 1, 0.75], dtype=torch.float32)
+            '1': {'target_name': torch.tensor([0.5, 0.5], dtype=torch.float32)},
+            '2': {'target_name': torch.tensor([1, 1, 0.75], dtype=torch.float32)}
         }
         metrics = metrics_calculator.calculate_metrics(data)
-        assert all(torch.equal(result[k], v) for k, v in metrics.items())
+        assert all(torch.equal(result[k]['target_name'], v['target_name']) for k, v in metrics.items())

         # Test MeanRankMetricsCalculator
         metrics_calculator = MeanRankMetricsCalculator(Scaling(kernel_size=[-1, 1], kernel_padding_mode='back'))
         data = {
-            '1': [2, torch.tensor([[0, 1], [1, 0]], dtype=torch.float32)],
-            '2': [2, torch.tensor([[0, 0, 1], [0, 0, 0]], dtype=torch.float32)]
+            '1': {'target_name': [2, torch.tensor([[0, 1], [1, 0]], dtype=torch.float32)]},
+            '2': {'target_name': [2, torch.tensor([[0, 0, 1], [0, 0, 0]], dtype=torch.float32)]}
         }
         result = {
-            '1': torch.tensor([0.25, 0.25], dtype=torch.float32),
-            '2': torch.tensor([0, 0, 0.25], dtype=torch.float32)
+            '1': {'target_name': torch.tensor([0.25, 0.25], dtype=torch.float32)},
+            '2': {'target_name': torch.tensor([0, 0, 0.25], dtype=torch.float32)}
         }
         metrics = metrics_calculator.calculate_metrics(data)
-        assert all(torch.equal(result[k], v) for k, v in metrics.items())
+        assert all(torch.equal(result[k]['target_name'], v['target_name']) for k, v in metrics.items())

     def test_sparsity_allocator(self):
         # Test NormalSparsityAllocator
@@ -183,8 +189,8 @@ class PruningToolsTestCase(unittest.TestCase):
         config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.8}]
         pruner = Pruner(model, config_list)
         metrics = {
-            'conv1': torch.rand(5, 1, 5, 5),
-            'conv2': torch.rand(10, 5, 5, 5)
+            'conv1': {'weight': torch.rand(5, 1, 5, 5)},
+            'conv2': {'weight': torch.rand(10, 5, 5, 5)}
         }
         sparsity_allocator = NormalSparsityAllocator(pruner)
         masks = sparsity_allocator.generate_sparsity(metrics)
......
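The change running through the hunks above is a data-layout change: data collectors, metrics calculators, and sparsity allocators now exchange values keyed as {module_name: {target_name: tensor}} rather than {module_name: tensor}. A minimal runnable sketch of the two layouts (the 'conv1'/'weight' names follow the diff's own examples; the shapes are illustrative):

import torch

# Old layout: one tensor per module.
old_metrics = {
    'conv1': torch.rand(5, 1, 5, 5)
}

# New layout: each module maps target names to tensors, so one module can
# expose several compression targets to the same tooling.
new_metrics = {
    'conv1': {'weight': torch.rand(5, 1, 5, 5)}
}

# Consumers index the target explicitly, as the updated asserts do:
for module_name, targets in new_metrics.items():
    for target_name, tensor in targets.items():
        print(module_name, target_name, tuple(tensor.shape))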