Unverified Commit 262150e7 authored by J-shang's avatar J-shang Committed by GitHub
Browse files

[Bugbash] fix some pruner bugs (#5093)

parent d3e190c7
...@@ -15,6 +15,8 @@ import torch.nn.functional as F ...@@ -15,6 +15,8 @@ import torch.nn.functional as F
from torch.nn import Module from torch.nn import Module
from torch.optim import Optimizer from torch.optim import Optimizer
from nni.algorithms.compression.v2.pytorch.base.pruner import PrunerModuleWrapper
from ..base import Pruner from ..base import Pruner
from .tools import ( from .tools import (
...@@ -55,7 +57,8 @@ from ..utils import ( ...@@ -55,7 +57,8 @@ from ..utils import (
Evaluator, Evaluator,
ForwardHook, ForwardHook,
TensorHook, TensorHook,
config_list_canonical config_list_canonical,
get_output_batch_dims
) )
from ..utils.docstring import _EVALUATOR_DOCSTRING from ..utils.docstring import _EVALUATOR_DOCSTRING
...@@ -189,12 +192,12 @@ class EvaluatorBasedPruner(BasicPruner): ...@@ -189,12 +192,12 @@ class EvaluatorBasedPruner(BasicPruner):
for key, value in def_kwargs.items(): for key, value in def_kwargs.items():
if key not in merged_kwargs and key in arg_names: if key not in merged_kwargs and key in arg_names:
merged_kwargs[key] = value merged_kwargs[key] = value
diff = set(arg_names).difference(merged_kwargs.keys())
if diff:
raise TypeError(f"{self.__class__.__name__}.__init__() missing {len(diff)} required positional argument: {diff}")
diff = set(merged_kwargs.keys()).difference(arg_names) diff = set(merged_kwargs.keys()).difference(arg_names)
if diff: if diff:
raise TypeError(f"{self.__class__.__name__}.__init__() got {len(diff)} unexpected keyword argument: {diff}") raise TypeError(f"{self.__class__.__name__}.__init__() got {len(diff)} unexpected keyword argument: {diff}")
diff = set(arg_names).difference(merged_kwargs.keys())
if diff:
raise TypeError(f"{self.__class__.__name__}.__init__() missing {len(diff)} required positional argument: {diff}")
return merged_kwargs return merged_kwargs
def compress(self) -> Tuple[Module, Dict]: def compress(self) -> Tuple[Module, Dict]:
...@@ -747,15 +750,19 @@ class ActivationPruner(EvaluatorBasedPruner): ...@@ -747,15 +750,19 @@ class ActivationPruner(EvaluatorBasedPruner):
buffer.append(0) buffer.append(0)
def collect_activation(_module: Module, _input: Tensor, output: Tensor): def collect_activation(_module: Module, _input: Tensor, output: Tensor):
activation = self._activation_trans(output) # TODO: remove `if` after deprecate the old API
if isinstance(_module, PrunerModuleWrapper):
_module = _module.module
batch_dims, batch_num = get_output_batch_dims(output, _module) # type: ignore
activation = self._activation_trans(output, batch_dims)
if len(buffer) == 1: if len(buffer) == 1:
buffer.append(torch.zeros_like(activation)) buffer.append(torch.zeros_like(activation))
if buffer[0] < self.training_steps: if buffer[0] < self.training_steps:
buffer[1] += activation buffer[1] += activation.to(buffer[1].device) # type: ignore
buffer[0] += 1 buffer[0] += batch_num
return collect_activation return collect_activation
def _activation_trans(self, output: Tensor) -> Tensor: def _activation_trans(self, output: Tensor, dim: int | list = 0) -> Tensor:
raise NotImplementedError() raise NotImplementedError()
def reset_tools(self): def reset_tools(self):
...@@ -846,9 +853,10 @@ class ActivationAPoZRankPruner(ActivationPruner): ...@@ -846,9 +853,10 @@ class ActivationAPoZRankPruner(ActivationPruner):
For detailed example please refer to :githublink:`examples/model_compress/pruning/activation_pruning_torch.py <examples/model_compress/pruning/activation_pruning_torch.py>` For detailed example please refer to :githublink:`examples/model_compress/pruning/activation_pruning_torch.py <examples/model_compress/pruning/activation_pruning_torch.py>`
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING) """.format(evaluator_docstring=_EVALUATOR_DOCSTRING)
def _activation_trans(self, output: Tensor) -> Tensor: def _activation_trans(self, output: Tensor, dim: int | list = 0) -> Tensor:
dim = [dim] if not isinstance(dim, (list, tuple)) else dim
# return a matrix that the position of zero in `output` is one, others is zero. # return a matrix that the position of zero in `output` is one, others is zero.
return torch.eq(self._activation(output.detach()), torch.zeros_like(output)).type_as(output).mean(0) return torch.eq(self._activation(output.detach()), torch.zeros_like(output)).type_as(output).sum(dim=dim)
def _create_metrics_calculator(self) -> MetricsCalculator: def _create_metrics_calculator(self) -> MetricsCalculator:
return APoZRankMetricsCalculator(Scaling(kernel_size=[1], kernel_padding_mode='back')) return APoZRankMetricsCalculator(Scaling(kernel_size=[1], kernel_padding_mode='back'))
...@@ -901,9 +909,10 @@ class ActivationMeanRankPruner(ActivationPruner): ...@@ -901,9 +909,10 @@ class ActivationMeanRankPruner(ActivationPruner):
For detailed example please refer to :githublink:`examples/model_compress/pruning/activation_pruning_torch.py <examples/model_compress/pruning/activation_pruning_torch.py>` For detailed example please refer to :githublink:`examples/model_compress/pruning/activation_pruning_torch.py <examples/model_compress/pruning/activation_pruning_torch.py>`
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING) """.format(evaluator_docstring=_EVALUATOR_DOCSTRING)
def _activation_trans(self, output: Tensor) -> Tensor: def _activation_trans(self, output: Tensor, dim: int | list = 0) -> Tensor:
dim = [dim] if not isinstance(dim, (list, tuple)) else dim
# return the activation of `output` directly. # return the activation of `output` directly.
return self._activation(output.detach()).mean(0) return self._activation(output.detach()).sum(dim)
def _create_metrics_calculator(self) -> MetricsCalculator: def _create_metrics_calculator(self) -> MetricsCalculator:
return MeanRankMetricsCalculator(Scaling(kernel_size=[1], kernel_padding_mode='back')) return MeanRankMetricsCalculator(Scaling(kernel_size=[1], kernel_padding_mode='back'))
......
...@@ -62,12 +62,12 @@ class EvaluatorBasedPruningScheduler(BasePruningScheduler): ...@@ -62,12 +62,12 @@ class EvaluatorBasedPruningScheduler(BasePruningScheduler):
for key, value in def_kwargs.items(): for key, value in def_kwargs.items():
if key not in merged_kwargs and key in arg_names: if key not in merged_kwargs and key in arg_names:
merged_kwargs[key] = value merged_kwargs[key] = value
diff = set(arg_names).difference(merged_kwargs.keys())
if diff:
raise TypeError(f"{self.__class__.__name__}.__init__() missing {len(diff)} required positional argument: {diff}")
diff = set(merged_kwargs.keys()).difference(arg_names) diff = set(merged_kwargs.keys()).difference(arg_names)
if diff: if diff:
raise TypeError(f"{self.__class__.__name__}.__init__() got {len(diff)} unexpected keyword argument: {diff}") raise TypeError(f"{self.__class__.__name__}.__init__() got {len(diff)} unexpected keyword argument: {diff}")
diff = set(arg_names).difference(merged_kwargs.keys())
if diff:
raise TypeError(f"{self.__class__.__name__}.__init__() missing {len(diff)} required positional argument: {diff}")
return merged_kwargs return merged_kwargs
......
...@@ -98,7 +98,7 @@ class FunctionBasedTaskGenerator(TaskGenerator): ...@@ -98,7 +98,7 @@ class FunctionBasedTaskGenerator(TaskGenerator):
with Path(config_list_path).open('w') as f: with Path(config_list_path).open('w') as f:
json_tricks.dump(new_config_list, f, indent=4) json_tricks.dump(new_config_list, f, indent=4)
task = Task(task_id, model_path, masks_path, config_list_path) task = Task(task_id, model_path, masks_path, config_list_path, speedup=True, finetune=True, evaluate=False)
self._tasks[task_id] = task self._tasks[task_id] = task
......
...@@ -28,6 +28,7 @@ from .pruning import ( ...@@ -28,6 +28,7 @@ from .pruning import (
compute_sparsity_mask2compact, compute_sparsity_mask2compact,
compute_sparsity, compute_sparsity,
get_model_weights_numel, get_model_weights_numel,
get_module_by_name get_module_by_name,
get_output_batch_dims
) )
from .scaling import Scaling from .scaling import Scaling
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# Licensed under the MIT license. # Licensed under the MIT license.
from copy import deepcopy from copy import deepcopy
import math
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
import torch import torch
...@@ -279,3 +280,21 @@ def get_module_by_name(model, module_name): ...@@ -279,3 +280,21 @@ def get_module_by_name(model, module_name):
return model, leaf_module return model, leaf_module
else: else:
return None, None return None, None
def get_output_batch_dims(t: Tensor, module: Module) -> Tuple[List[int], int]:
    """
    Infer which leading dimensions of ``t`` are batch dimensions for ``module``'s output.

    The number of trailing non-batch (feature) dimensions is fixed by the module type:
    ``Linear``/``Bilinear`` outputs end with one feature dimension, ``Conv1d`` with two
    (channel, length), ``Conv2d`` with three (channel, height, width), and ``Conv3d``
    with four (channel, depth, height, width). Transposed convolutions follow their
    forward counterparts. Every dimension before the feature dimensions is batch.

    Parameters
    ----------
    t
        The output tensor produced by ``module``.
    module
        The module that produced ``t``; must be a supported linear or (transposed)
        convolution layer.

    Returns
    -------
    Tuple[List[int], int]
        ``(batch_dims, batch_nums)``: the indices of the batch dimensions and the
        total number of samples (product of the batch dimension sizes; ``1`` when
        there are no batch dimensions, since ``math.prod(()) == 1``).

    Raises
    ------
    TypeError
        If ``module`` is not one of the supported layer types.
    """
    # One entry per supported layer family: (types, number of trailing feature dims).
    # Checked with isinstance so subclasses of the supported layers are accepted,
    # matching the behavior of the equivalent if/elif chain.
    supported_layers = (
        ((torch.nn.Linear, torch.nn.Bilinear), 1),
        ((torch.nn.Conv1d, torch.nn.ConvTranspose1d), 2),
        ((torch.nn.Conv2d, torch.nn.ConvTranspose2d), 3),
        ((torch.nn.Conv3d, torch.nn.ConvTranspose3d), 4),
    )
    for layer_types, feature_ndim in supported_layers:
        if isinstance(module, layer_types):
            batch_shape = t.shape[:-feature_ndim]
            batch_dims = list(range(len(batch_shape)))
            batch_nums = math.prod(batch_shape)
            return batch_dims, batch_nums
    raise TypeError(f'Found unsupported module type in activation based pruner: {module.__class__.__name__}')
...@@ -751,6 +751,8 @@ class Quantizer(Compressor): ...@@ -751,6 +751,8 @@ class Quantizer(Compressor):
_setattr(self.bound_model, wrapper.module_name, wrapper.module) _setattr(self.bound_model, wrapper.module_name, wrapper.module)
super()._unwrap_model() super()._unwrap_model()
# TODO: For most complex models, the information provided by input_shape is not enough to randomly initialize the complete input.
# And nni should not be responsible for exporting the onnx model, this feature should be deprecated in quantization refactor.
def export_model_save(self, model, model_path, calibration_config=None, calibration_path=None, onnx_path=None, def export_model_save(self, model, model_path, calibration_config=None, calibration_path=None, onnx_path=None,
input_shape=None, device=None): input_shape=None, device=None):
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment