Unverified commit 9b0bc378 authored by J-shang, committed by GitHub

[#3507 follow up] update doc (#3688)

parent 916267d8
@@ -54,11 +54,11 @@ To enable the dependency-aware mode for ``L1FilterPruner``\ :
 # for FPGMPruner
 # pruner = FPGMPruner(model, config_list, dependency_aware=True, dummy_input=dummy_input)
 # for ActivationAPoZRankFilterPruner
-# pruner = ActivationAPoZRankFilterPruner(model, config_list, statistics_batch_num=1, , dependency_aware=True, dummy_input=dummy_input)
+# pruner = ActivationAPoZRankFilterPruner(model, config_list, optimizer, trainer, criterion, sparsifying_training_batches=1, dependency_aware=True, dummy_input=dummy_input)
 # for ActivationMeanRankFilterPruner
-# pruner = ActivationMeanRankFilterPruner(model, config_list, statistics_batch_num=1, dependency_aware=True, dummy_input=dummy_input)
+# pruner = ActivationMeanRankFilterPruner(model, config_list, optimizer, trainer, criterion, sparsifying_training_batches=1, dependency_aware=True, dummy_input=dummy_input)
 # for TaylorFOWeightFilterPruner
-# pruner = TaylorFOWeightFilterPruner(model, config_list, statistics_batch_num=1, dependency_aware=True, dummy_input=dummy_input)
+# pruner = TaylorFOWeightFilterPruner(model, config_list, optimizer, trainer, criterion, sparsifying_training_batches=1, dependency_aware=True, dummy_input=dummy_input)
 pruner.compress()
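For readers following this change: the commented calls now pass ``optimizer``, ``trainer`` and ``criterion``, and the batch-count argument is named ``sparsifying_training_batches``. A minimal sketch of the updated dependency-aware call, assuming ``model``, ``config_list``, ``optimizer``, a ``trainer(model, optimizer, criterion, epoch)`` function, ``criterion`` and a ``dummy_input`` tensor are prepared as elsewhere in these docs:

.. code-block:: python

   from nni.algorithms.compression.pytorch.pruning import TaylorFOWeightFilterPruner

   # Signature as updated by this commit; all arguments other than the pruner
   # itself are assumed to be defined by the surrounding example.
   pruner = TaylorFOWeightFilterPruner(model, config_list, optimizer, trainer, criterion,
                                       sparsifying_training_batches=1,
                                       dependency_aware=True, dummy_input=dummy_input)
   pruner.compress()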
...
@@ -29,8 +29,7 @@ Compressor is the base class for pruner and quntizer, it provides a unified inte
 'op_types': ['Conv2d', 'Linear'],
 }]
-optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
-pruner = LevelPruner(model, configure_list, optimizer)
+pruner = LevelPruner(model, configure_list)
 model = pruner.compress()
 # model is ready for pruning, now start finetune the model,
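As a usage note on the simplified call above, ``LevelPruner`` no longer takes an optimizer here. A short sketch under the assumption of an existing ``model`` and a configuration shaped like the one in this hunk (the sparsity value below is only a representative choice, not taken from the commit):

.. code-block:: python

   from nni.algorithms.compression.pytorch.pruning import LevelPruner

   configure_list = [{
       'sparsity': 0.5,                     # representative value, assumption
       'op_types': ['Conv2d', 'Linear'],
   }]
   pruner = LevelPruner(model, configure_list)
   model = pruner.compress()
   # model now carries pruning masks; fine-tune it with a normal training loop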
@@ -103,7 +102,8 @@ Users can also remove this collector like this:
 Pruner
 ------
-A pruner receives ``model``\ , ``config_list`` and ``optimizer`` as arguments. It prunes the model per the ``config_list`` during training loop by adding a hook on ``optimizer.step()``.
+A pruner receives ``model`` and ``config_list`` as arguments.
+Some pruners, such as ``TaylorFOWeightFilterPruner``, prune the model per the ``config_list`` during the training loop by adding a hook on ``optimizer.step()``.
 Pruner class is a subclass of Compressor, so it contains everything in the Compressor class and some additional components only for pruning, it contains:
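To make the ``optimizer.step()`` hook concrete, here is an illustrative, self-contained sketch of the general pattern. It is not NNI's exact implementation (NNI wires this up internally, e.g. via ``patch_optimizer``), and the collected statistic is a placeholder:

.. code-block:: python

   import torch

   model = torch.nn.Linear(4, 2)
   optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
   stats = {'steps': 0}

   def collect_pruning_statistics():
       # Placeholder for what a pruner does on each step, e.g. accumulating
       # the gradient or activation statistics used to rank filters.
       stats['steps'] += 1

   original_step = optimizer.step

   def patched_step(*args, **kwargs):
       collect_pruning_statistics()
       return original_step(*args, **kwargs)

   optimizer.step = patched_step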
...
@@ -71,7 +71,7 @@ PyTorch code
 from nni.algorithms.compression.pytorch.pruning import SlimPruner
 config_list = [{ 'sparsity': 0.8, 'op_types': ['BatchNorm2d'] }]
-pruner = SlimPruner(model, config_list)
+pruner = SlimPruner(model, config_list, optimizer, trainer, criterion)
 pruner.compress()
 User configuration for Slim Pruner
@@ -269,7 +269,7 @@ PyTorch code
 'sparsity': 0.5,
 'op_types': ['Conv2d']
 }]
-pruner = ActivationAPoZRankFilterPruner(model, config_list, statistics_batch_num=1)
+pruner = ActivationAPoZRankFilterPruner(model, config_list, optimizer, trainer, criterion, sparsifying_training_batches=1)
 pruner.compress()
 Note: ActivationAPoZRankFilterPruner is used to prune convolutional layers within deep neural networks, therefore the ``op_types`` field supports only convolutional layers.
@@ -304,7 +304,7 @@ PyTorch code
 'sparsity': 0.5,
 'op_types': ['Conv2d']
 }]
-pruner = ActivationMeanRankFilterPruner(model, config_list, statistics_batch_num=1)
+pruner = ActivationMeanRankFilterPruner(model, config_list, optimizer, trainer, criterion, sparsifying_training_batches=1)
 pruner.compress()
 Note: ActivationMeanRankFilterPruner is used to prune convolutional layers within deep neural networks, therefore the ``op_types`` field supports only convolutional layers.
@@ -344,7 +344,7 @@ PyTorch code
 'sparsity': 0.5,
 'op_types': ['Conv2d']
 }]
-pruner = TaylorFOWeightFilterPruner(model, config_list, statistics_batch_num=1)
+pruner = TaylorFOWeightFilterPruner(model, config_list, optimizer, trainer, criterion, sparsifying_training_batches=1)
 pruner.compress()
 User configuration for TaylorFOWeightFilter Pruner
@@ -389,7 +389,7 @@ PyTorch code
 # optimizer.step(), so an optimizer is required to prune the model.
 optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
-pruner = AGPPruner(model, config_list, optimizer, pruning_algorithm='level')
+pruner = AGPPruner(model, config_list, optimizer, trainer, criterion, pruning_algorithm='level')
 pruner.compress()
 AGP pruner uses ``LevelPruner`` algorithms to prune the weight by default, however you can set ``pruning_algorithm`` parameter to other values to use other pruning algorithms:
@@ -404,14 +404,6 @@ AGP pruner uses ``LevelPruner`` algorithms to prune the weight by default, howev
 * ``apoz``\ : ActivationAPoZRankFilterPruner
 * ``mean_activation``\ : ActivationMeanRankFilterPruner
-You should add code below to update epoch number when you finish one epoch in your training code.
-PyTorch code
-.. code-block:: python
-pruner.update_epoch(epoch)
 User configuration for AGP Pruner
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
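With this change the example stops instructing users to call ``pruner.update_epoch(epoch)`` by hand; the pruner now receives ``trainer`` and ``criterion`` and drives the epochs itself. A hedged sketch of the updated usage, assuming ``model``, ``config_list`` and a ``train_loader`` already exist:

.. code-block:: python

   import torch
   from nni.algorithms.compression.pytorch.pruning import AGPPruner

   def trainer(model, optimizer, criterion, epoch):
       # Trainer signature documented in this commit: model, optimizer, criterion, epoch.
       model.train()
       for data, target in train_loader:
           optimizer.zero_grad()
           loss = criterion(model(data), target)
           loss.backward()
           optimizer.step()

   criterion = torch.nn.CrossEntropyLoss()
   optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
   pruner = AGPPruner(model, config_list, optimizer, trainer, criterion, pruning_algorithm='level')
   pruner.compress()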
@@ -620,7 +612,7 @@ PyTorch code
 'op_types': ['Conv2d'],
 'op_names': ['conv2']
 }]
-pruner = ADMMPruner(model, config_list, trainer=trainer, num_iterations=30, epochs=5)
+pruner = ADMMPruner(model, config_list, trainer, num_iterations=30, epochs_per_iteration=5)
 pruner.compress()
 You can view :githublink:`example <examples/model_compress/pruning/auto_pruners_torch.py>` for more information.
...
@@ -31,17 +31,16 @@ The specification of configuration can be found `here <./Tutorial.rst#specify-th
 Step2. Choose a pruner and compress the model
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-First instantiate the chosen pruner with your model and configuration as arguments, then invoke ``compress()`` to compress your model. Note that, some algorithms may check gradients for compressing, so we also define an optimizer and pass it to the pruner.
+First instantiate the chosen pruner with your model and configuration as arguments, then invoke ``compress()`` to compress your model. Note that some algorithms may check gradients when compressing, so we may also define an optimizer and pass it to the pruner.
 .. code-block:: python
 from nni.algorithms.compression.pytorch.pruning import LevelPruner
-optimizer_finetune = torch.optim.SGD(model.parameters(), lr=0.01)
-pruner = LevelPruner(model, config_list, optimizer_finetune)
+pruner = LevelPruner(model, config_list)
 model = pruner.compress()
-Then, you can train your model using traditional training approach (e.g., SGD), pruning is applied transparently during the training. Some pruners (e.g., L1FilterPruner, FPGMPruner) prune once at the beginning, the following training can be seen as fine-tune. Some pruners (e.g., AGPPruner) prune your model iteratively, the masks are adjusted epoch by epoch during training.
+Some pruners (e.g., L1FilterPruner, FPGMPruner) prune once; other pruners (e.g., AGPPruner) prune your model iteratively, and the masks are adjusted epoch by epoch during training.
 Note that, ``pruner.compress`` simply adds masks on model weights, it does not include fine-tuning logic. If users want to fine tune the compressed model, they need to write the fine tune logic by themselves after ``pruner.compress``.
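Because ``pruner.compress()`` only applies masks, the fine-tuning step stays in user code. A minimal sketch of such a loop, assuming a classification task with an existing ``train_loader`` and a user-chosen number of epochs:

.. code-block:: python

   import torch

   model = pruner.compress()
   optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
   criterion = torch.nn.CrossEntropyLoss()

   for epoch in range(10):              # number of fine-tuning epochs is up to the user
       for data, target in train_loader:
           optimizer.zero_grad()
           loss = criterion(model(data), target)
           loss.backward()              # gradients flow through the masked weights
           optimizer.step()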
...
@@ -231,10 +231,10 @@ def main(args):
 kw_args['criterion'] = criterion
 if args.pruner in ('mean_activation', 'apoz', 'taylorfo'):
-kw_args['sparsity_training_epochs'] = 1
+kw_args['sparsifying_training_batches'] = 1
 if args.pruner == 'slim':
-kw_args['sparsity_training_epochs'] = 5
+kw_args['sparsifying_training_epochs'] = 5
 if args.pruner == 'agp':
 kw_args['pruning_algorithm'] = 'l1'
...
@@ -34,6 +34,9 @@ class AutoCompressPruner(Pruner):
 Function used for the first subproblem of ADMM Pruner.
 Users should write this function as a normal function to train the Pytorch model
 and include `model, optimizer, criterion, epoch` as function arguments.
+criterion: function
+Function used to calculate the loss between the target and the output. By default, we use CrossEntropyLoss.
+For example, you can use ``torch.nn.CrossEntropyLoss()`` as input.
 evaluator : function
 function to evaluate the pruned model.
 This function should include `model` as the only parameter, and returns a scalar value.
@@ -80,7 +83,7 @@ class AutoCompressPruner(Pruner):
 PATH to store temporary experiment data.
 """
-def __init__(self, model, config_list, trainer, criterion, evaluator, dummy_input,
+def __init__(self, model, config_list, trainer, evaluator, dummy_input, criterion=torch.nn.CrossEntropyLoss(),
 num_iterations=3, optimize_mode='maximize', base_algo='l1',
 # SimulatedAnnealing related
 start_temperature=100, stop_temperature=20, cool_down_rate=0.9, perturbation_magnitude=0.35,
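Based on the signature above, a hedged sketch of constructing the pruner with the relocated ``criterion`` keyword; the import path mirrors the other pruners in these docs, and ``train_loader`` / ``validate`` are assumed helpers, not part of this commit:

.. code-block:: python

   import torch
   from nni.algorithms.compression.pytorch.pruning import AutoCompressPruner

   def trainer(model, optimizer, criterion, epoch):
       model.train()
       for data, target in train_loader:
           optimizer.zero_grad()
           criterion(model(data), target).backward()
           optimizer.step()

   def evaluator(model):
       # Must take the model as its only parameter and return a scalar score.
       return validate(model)

   pruner = AutoCompressPruner(model, config_list, trainer, evaluator, dummy_input,
                               criterion=torch.nn.CrossEntropyLoss(),
                               num_iterations=3, base_algo='l1')
   pruner.compress()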
...
@@ -40,6 +40,7 @@ class IterativePruner(DependencyAwarePruner):
 and include `model, optimizer, criterion, epoch` as function arguments.
 criterion: function
 Function used to calculate the loss between the target and the output.
+For example, you can use ``torch.nn.CrossEntropyLoss()`` as input.
 num_iterations: int
 Total number of iterations in pruning process. We will calculate mask at the end of an iteration.
 epochs_per_iteration: Union[int, list]
@@ -59,8 +60,11 @@ class IterativePruner(DependencyAwarePruner):
 assert len(epochs_per_iteration) == num_iterations, 'num_iterations should equal to the length of epochs_per_iteration'
 self.epochs_per_iteration = epochs_per_iteration
 else:
+assert num_iterations > 0, 'num_iterations should >= 1'
 self.epochs_per_iteration = [epochs_per_iteration] * num_iterations
+self._validate_iteration_params()
 self._trainer = trainer
 self._criterion = criterion
@@ -68,6 +72,9 @@ class IterativePruner(DependencyAwarePruner):
 for wrapper in self.get_modules_wrapper():
 wrapper.if_calculated = False
+def _validate_iteration_params(self):
+assert all(num >= 0 for num in self.epochs_per_iteration), 'all epoch number need >= 0'
 def compress(self):
 training = self.bound_model.training
 self.bound_model.train()
@@ -75,6 +82,10 @@ class IterativePruner(DependencyAwarePruner):
 self._fresh_calculated()
 for epoch in range(epochs_num):
 self._trainer(self.bound_model, optimizer=self.optimizer, criterion=self._criterion, epoch=epoch)
+# NOTE: workaround for statistics_batch_num bigger than max batch number in one epoch, need refactor
+if hasattr(self.masker, 'statistics_batch_num') and hasattr(self, 'iterations'):
+if self.iterations < self.masker.statistics_batch_num:
+self.iterations = self.masker.statistics_batch_num
 self.update_mask()
 self.bound_model.train(training)
@@ -97,6 +108,7 @@ class AGPPruner(IterativePruner):
 Function to train the model
 criterion: function
 Function used to calculate the loss between the target and the output.
+For example, you can use ``torch.nn.CrossEntropyLoss()`` as input.
 num_iterations: int
 Total number of iterations in pruning process. We will calculate mask at the end of an iteration.
 epochs_per_iteration: int
@@ -245,6 +257,7 @@ class ADMMPruner(IterativePruner):
 and include `model, optimizer, criterion, epoch` as function arguments.
 criterion: function
 Function used to calculate the loss between the target and the output. By default, we use CrossEntropyLoss in ADMMPruner.
+For example, you can use ``torch.nn.CrossEntropyLoss()`` as input.
 num_iterations: int
 Total number of iterations in pruning process. We will calculate mask after we finish all iterations in ADMMPruner.
 epochs_per_iteration: int
@@ -254,7 +267,6 @@ class ADMMPruner(IterativePruner):
 base_algo : str
 Base pruning algorithm. `level`, `l1`, `l2` or `fpgm`, by default `l1`. Given the sparsity distribution among
 the ops, the assigned `base_algo` is used to decide which filters/channels/weights to prune.
-
 """
 def __init__(self, model, config_list, trainer, criterion=torch.nn.CrossEntropyLoss(),
@@ -396,7 +408,8 @@ class SlimPruner(IterativePruner):
 and include `model, optimizer, criterion, epoch` as function arguments.
 criterion : function
 Function used to calculate the loss between the target and the output.
-sparsity_training_epochs: int
+For example, you can use ``torch.nn.CrossEntropyLoss()`` as input.
+sparsifying_training_epochs: int
 The number of channel sparsity regularization training epochs before pruning.
 scale : float
 Penalty parameters for sparsification.
@@ -413,10 +426,10 @@ class SlimPruner(IterativePruner):
 should on the same device with the model.
 """
-def __init__(self, model, config_list, optimizer, trainer, criterion, sparsity_training_epochs=10, scale=0.0001,
+def __init__(self, model, config_list, optimizer, trainer, criterion, sparsifying_training_epochs=10, scale=0.0001,
 dependency_aware=False, dummy_input=None):
 super().__init__(model, config_list, optimizer=optimizer, pruning_algorithm='slim', trainer=trainer, criterion=criterion,
-num_iterations=1, epochs_per_iteration=sparsity_training_epochs, dependency_aware=dependency_aware,
+num_iterations=1, epochs_per_iteration=sparsifying_training_epochs, dependency_aware=dependency_aware,
 dummy_input=dummy_input)
 self.scale = scale
 self.patch_optimizer_before(self._callback)
@@ -459,8 +472,9 @@ class TaylorFOWeightFilterPruner(IterativePruner):
 and include `model, optimizer, criterion, epoch` as function arguments.
 criterion : function
 Function used to calculate the loss between the target and the output.
-sparsity_training_epochs: int
-The number of epochs to collect the contributions.
+For example, you can use ``torch.nn.CrossEntropyLoss()`` as input.
+sparsifying_training_batches: int
+The number of batches to collect the contributions. Note that the number needs to be less than the maximum number of batches in one epoch.
 dependency_aware: bool
 If prune the model in a dependency-aware way. If it is `True`, this pruner will
 prune the model according to the l2-norm of weights and the channel-dependency or
@@ -472,14 +486,14 @@ class TaylorFOWeightFilterPruner(IterativePruner):
 dummy_input : torch.Tensor
 The dummy input to analyze the topology constraints. Note that, the dummy_input
 should on the same device with the model.
 """
-def __init__(self, model, config_list, optimizer, trainer, criterion, sparsity_training_epochs=1, dependency_aware=False,
-dummy_input=None):
+def __init__(self, model, config_list, optimizer, trainer, criterion, sparsifying_training_batches=1,
+dependency_aware=False, dummy_input=None):
 super().__init__(model, config_list, optimizer=optimizer, pruning_algorithm='taylorfo', trainer=trainer,
-criterion=criterion, num_iterations=1, epochs_per_iteration=sparsity_training_epochs,
-dependency_aware=dependency_aware, dummy_input=dummy_input)
+criterion=criterion, statistics_batch_num=sparsifying_training_batches, num_iterations=1,
+epochs_per_iteration=1, dependency_aware=dependency_aware,
+dummy_input=dummy_input)
 def _supported_dependency_aware(self):
 return True
@@ -503,10 +517,11 @@ class ActivationAPoZRankFilterPruner(IterativePruner):
 and include `model, optimizer, criterion, epoch` as function arguments.
 criterion : function
 Function used to calculate the loss between the target and the output.
+For example, you can use ``torch.nn.CrossEntropyLoss()`` as input.
 activation: str
 The activation type.
-sparsity_training_epochs: int
-The number of epochs to statistic the activation.
+sparsifying_training_batches: int
+The number of batches to collect the contributions. Note that the number needs to be less than the maximum number of batches in one epoch.
 dependency_aware: bool
 If prune the model in a dependency-aware way. If it is `True`, this pruner will
 prune the model according to the l2-norm of weights and the channel-dependency or
@@ -522,10 +537,11 @@ class ActivationAPoZRankFilterPruner(IterativePruner):
 """
 def __init__(self, model, config_list, optimizer, trainer, criterion, activation='relu',
-sparsity_training_epochs=1, dependency_aware=False, dummy_input=None):
+sparsifying_training_batches=1, dependency_aware=False, dummy_input=None):
 super().__init__(model, config_list, pruning_algorithm='apoz', optimizer=optimizer, trainer=trainer,
 criterion=criterion, dependency_aware=dependency_aware, dummy_input=dummy_input,
-activation=activation, num_iterations=1, epochs_per_iteration=sparsity_training_epochs)
+activation=activation, statistics_batch_num=sparsifying_training_batches, num_iterations=1,
+epochs_per_iteration=1)
 self.patch_optimizer(self.update_mask)
 def _supported_dependency_aware(self):
@@ -550,10 +566,11 @@ class ActivationMeanRankFilterPruner(IterativePruner):
 and include `model, optimizer, criterion, epoch` as function arguments.
 criterion : function
 Function used to calculate the loss between the target and the output.
+For example, you can use ``torch.nn.CrossEntropyLoss()`` as input.
 activation: str
 The activation type.
-sparsity_training_epochs: int
-The number of batches to statistic the activation.
+sparsifying_training_batches: int
+The number of batches to collect the contributions. Note that the number needs to be less than the maximum number of batches in one epoch.
 dependency_aware: bool
 If prune the model in a dependency-aware way. If it is `True`, this pruner will
 prune the model according to the l2-norm of weights and the channel-dependency or
@@ -568,10 +585,11 @@ class ActivationMeanRankFilterPruner(IterativePruner):
 """
 def __init__(self, model, config_list, optimizer, trainer, criterion, activation='relu',
-sparsity_training_epochs=1, dependency_aware=False, dummy_input=None):
+sparsifying_training_batches=1, dependency_aware=False, dummy_input=None):
 super().__init__(model, config_list, pruning_algorithm='mean_activation', optimizer=optimizer, trainer=trainer,
 criterion=criterion, dependency_aware=dependency_aware, dummy_input=dummy_input,
-activation=activation, num_iterations=1, epochs_per_iteration=sparsity_training_epochs)
+activation=activation, statistics_batch_num=sparsifying_training_batches, num_iterations=1,
+epochs_per_iteration=1)
 self.patch_optimizer(self.update_mask)
 def _supported_dependency_aware(self):
...
@@ -473,7 +473,7 @@ class TaylorFOWeightFilterPrunerMasker(StructuredWeightMasker):
 def __init__(self, model, pruner, statistics_batch_num=1):
 super().__init__(model, pruner)
-self.pruner.statistics_batch_num = statistics_batch_num
+self.statistics_batch_num = statistics_batch_num
 self.pruner.iterations = 0
 self.pruner.set_wrappers_attribute("contribution", None)
 self.pruner.patch_optimizer(self.calc_contributions)
@@ -497,13 +497,13 @@ class TaylorFOWeightFilterPrunerMasker(StructuredWeightMasker):
 Calculate the estimated importance of filters as a sum of individual contribution
 based on the first order taylor expansion.
 """
-if self.pruner.iterations >= self.pruner.statistics_batch_num:
+if self.pruner.iterations >= self.statistics_batch_num:
 return
 for wrapper in self.pruner.get_modules_wrapper():
 filters = wrapper.module.weight.size(0)
 contribution = (
-wrapper.module.weight*wrapper.module.weight.grad).data.pow(2).view(filters, -1).sum(dim=1)
+wrapper.module.weight * wrapper.module.weight.grad).data.pow(2).view(filters, -1).sum(dim=1)
 if wrapper.contribution is None:
 wrapper.contribution = contribution
 else:
@@ -512,7 +512,7 @@ class TaylorFOWeightFilterPrunerMasker(StructuredWeightMasker):
 self.pruner.iterations += 1
 def get_channel_sum(self, wrapper, wrapper_idx):
-if self.pruner.iterations < self.pruner.statistics_batch_num:
+if self.pruner.iterations < self.statistics_batch_num:
 return None
 if wrapper.contribution is None:
 return None
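For reference, the per-filter contribution accumulated by this masker is the element-wise product of each weight and its gradient, squared and summed over the filter. A standalone sketch of the same expression on a toy tensor (the loss below is only a stand-in to produce gradients):

.. code-block:: python

   import torch

   weight = torch.randn(8, 3, 3, 3, requires_grad=True)   # e.g. 8 conv filters
   loss = (weight ** 2).sum()                              # stand-in loss, just to get gradients
   loss.backward()

   filters = weight.size(0)
   # Same expression as in calc_contributions: (w * dL/dw)^2 summed per filter.
   contribution = (weight * weight.grad).data.pow(2).view(filters, -1).sum(dim=1)
   print(contribution.shape)                               # torch.Size([8])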
@@ -524,6 +524,8 @@ class ActivationFilterPrunerMasker(StructuredWeightMasker):
 super().__init__(model, pruner)
 self.statistics_batch_num = statistics_batch_num
 self.pruner.hook_id = self._add_activation_collector(self.pruner)
+self.pruner.iterations = 0
+self.pruner.patch_optimizer(self._iteration_counter)
 assert activation in ['relu', 'relu6']
 if activation == 'relu':
@@ -533,6 +535,9 @@ class ActivationFilterPrunerMasker(StructuredWeightMasker):
 else:
 self.pruner.activation = None
+def _iteration_counter(self):
+self.pruner.iterations += 1
 def _add_activation_collector(self, pruner):
 def collector(collected_activation):
 def hook(module_, input_, output):
...
@@ -201,7 +201,7 @@ class CompressorTestCase(TestCase):
 model = TorchModel()
 optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
-pruner = torch_pruner.TaylorFOWeightFilterPruner(model, config_list, optimizer, trainer=None, criterion=None, sparsity_training_epochs=1)
+pruner = torch_pruner.TaylorFOWeightFilterPruner(model, config_list, optimizer, trainer=None, criterion=None, sparsifying_training_batches=1)
 x = torch.rand((1, 1, 28, 28), requires_grad=True)
 model.conv1.module.weight.data = torch.tensor(w1).float()
...