Unverified Commit 262150e7 authored by J-shang, committed by GitHub

[Bugbash] fix some pruner bugs (#5093)

parent d3e190c7
@@ -15,6 +15,8 @@ import torch.nn.functional as F
from torch.nn import Module
from torch.optim import Optimizer
+from nni.algorithms.compression.v2.pytorch.base.pruner import PrunerModuleWrapper
from ..base import Pruner
from .tools import (
@@ -55,7 +57,8 @@ from ..utils import (
Evaluator,
ForwardHook,
TensorHook,
-config_list_canonical
+config_list_canonical,
+get_output_batch_dims
)
from ..utils.docstring import _EVALUATOR_DOCSTRING
@@ -189,12 +192,12 @@ class EvaluatorBasedPruner(BasicPruner):
for key, value in def_kwargs.items():
if key not in merged_kwargs and key in arg_names:
merged_kwargs[key] = value
-diff = set(arg_names).difference(merged_kwargs.keys())
-if diff:
-raise TypeError(f"{self.__class__.__name__}.__init__() missing {len(diff)} required positional argument: {diff}")
diff = set(merged_kwargs.keys()).difference(arg_names)
if diff:
raise TypeError(f"{self.__class__.__name__}.__init__() got {len(diff)} unexpected keyword argument: {diff}")
+diff = set(arg_names).difference(merged_kwargs.keys())
+if diff:
+raise TypeError(f"{self.__class__.__name__}.__init__() missing {len(diff)} required positional argument: {diff}")
return merged_kwargs
def compress(self) -> Tuple[Module, Dict]:
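Note on this hunk: the two validation blocks swap order, so an unexpected keyword argument is now reported before a missing one. A minimal sketch of the resulting behaviour (the function and argument names below are illustrative, not NNI's API):

```python
def check_kwargs(arg_names, merged_kwargs, cls_name='SomePruner'):
    # Report unexpected keyword arguments first, then missing ones.
    diff = set(merged_kwargs.keys()).difference(arg_names)
    if diff:
        raise TypeError(f"{cls_name}.__init__() got {len(diff)} unexpected keyword argument: {diff}")
    diff = set(arg_names).difference(merged_kwargs.keys())
    if diff:
        raise TypeError(f"{cls_name}.__init__() missing {len(diff)} required positional argument: {diff}")
    return merged_kwargs

try:
    # A typo such as `training_step` now surfaces as "unexpected", not as "missing".
    check_kwargs(['evaluator', 'training_steps'], {'evaluator': object(), 'training_step': 20})
except TypeError as err:
    print(err)
```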
@@ -747,15 +750,19 @@ class ActivationPruner(EvaluatorBasedPruner):
buffer.append(0)
def collect_activation(_module: Module, _input: Tensor, output: Tensor):
-activation = self._activation_trans(output)
+# TODO: remove `if` after deprecate the old API
+if isinstance(_module, PrunerModuleWrapper):
+_module = _module.module
+batch_dims, batch_num = get_output_batch_dims(output, _module) # type: ignore
+activation = self._activation_trans(output, batch_dims)
if len(buffer) == 1:
buffer.append(torch.zeros_like(activation))
if buffer[0] < self.training_steps:
-buffer[1] += activation
-buffer[0] += 1
+buffer[1] += activation.to(buffer[1].device) # type: ignore
+buffer[0] += batch_num
return collect_activation
-def _activation_trans(self, output: Tensor) -> Tensor:
+def _activation_trans(self, output: Tensor, dim: int | list = 0) -> Tensor:
raise NotImplementedError()
def reset_tools(self):
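The hook now accumulates a per-position sum of activations plus a running sample count (taken over the batch dimensions returned by `get_output_batch_dims`), instead of adding a per-step `.mean(0)`; dividing the sum by the count later yields a mean over every collected sample, and `.to(buffer[1].device)` guards against the hook firing on a different device than the buffer. A standalone sketch of that sum-and-count pattern in plain PyTorch (ReLU stands in for `self._activation`; nothing here is NNI internals):

```python
import torch
from torch import nn

def make_collector(buffer, batch_dims=(0,)):
    def hook(module, inputs, output):
        # Sum the activation over the batch dimensions, keeping channel/spatial dims.
        act = torch.relu(output.detach()).sum(dim=list(batch_dims))
        if len(buffer) == 0:
            buffer.extend([0, torch.zeros_like(act)])
        buffer[1] += act.to(buffer[1].device)
        buffer[0] += output.shape[0]  # samples seen so far
    return hook

conv = nn.Conv2d(3, 8, kernel_size=3)
buf = []
handle = conv.register_forward_hook(make_collector(buf))
for _ in range(4):
    conv(torch.randn(2, 3, 16, 16))
handle.remove()
print((buf[1] / buf[0]).shape)  # torch.Size([8, 14, 14]): mean activation over 8 samples
```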
@@ -846,9 +853,10 @@ class ActivationAPoZRankPruner(ActivationPruner):
For detailed example please refer to :githublink:`examples/model_compress/pruning/activation_pruning_torch.py <examples/model_compress/pruning/activation_pruning_torch.py>`
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
-def _activation_trans(self, output: Tensor) -> Tensor:
+def _activation_trans(self, output: Tensor, dim: int | list = 0) -> Tensor:
+dim = [dim] if not isinstance(dim, (list, tuple)) else dim
# return a matrix that the position of zero in `output` is one, others is zero.
-return torch.eq(self._activation(output.detach()), torch.zeros_like(output)).type_as(output).mean(0)
+return torch.eq(self._activation(output.detach()), torch.zeros_like(output)).type_as(output).sum(dim=dim)
def _create_metrics_calculator(self) -> MetricsCalculator:
return APoZRankMetricsCalculator(Scaling(kernel_size=[1], kernel_padding_mode='back'))
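With the new signature the APoZ transform returns per-position counts of zero activations, summed over the batch dimensions; it only becomes the Average Percentage of Zeros once divided by the batch count accumulated in `buffer[0]`. A tiny worked example of that zero-counting step (plain PyTorch, values made up for illustration):

```python
import torch

out = torch.tensor([[0.0, 1.5],
                    [0.0, 0.0],
                    [2.0, 0.0]])   # pretend: batch of 3 samples, 2 channels
act = torch.relu(out.detach())
zero_counts = torch.eq(act, torch.zeros_like(out)).type_as(out).sum(dim=[0])
print(zero_counts)                 # tensor([2., 2.])
print(zero_counts / out.shape[0])  # APoZ per channel: tensor([0.6667, 0.6667])
```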
@@ -901,9 +909,10 @@ class ActivationMeanRankPruner(ActivationPruner):
For detailed example please refer to :githublink:`examples/model_compress/pruning/activation_pruning_torch.py <examples/model_compress/pruning/activation_pruning_torch.py>`
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
-def _activation_trans(self, output: Tensor) -> Tensor:
+def _activation_trans(self, output: Tensor, dim: int | list = 0) -> Tensor:
+dim = [dim] if not isinstance(dim, (list, tuple)) else dim
# return the activation of `output` directly.
-return self._activation(output.detach()).mean(0)
+return self._activation(output.detach()).sum(dim)
def _create_metrics_calculator(self) -> MetricsCalculator:
return MeanRankMetricsCalculator(Scaling(kernel_size=[1], kernel_padding_mode='back'))
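One consequence of moving from a per-step `.mean(0)` to summing plus counting samples: when batch sizes differ across steps, a mean of per-step means over-weights samples from small batches, while sum-and-count weights every sample equally (and it also covers outputs with more than one batch-like dimension). A small illustration of the difference, using made-up numbers:

```python
import torch

batches = [torch.tensor([[1.0], [1.0], [1.0]]),  # step 1: 3 samples, activation 1
           torch.tensor([[4.0]])]                # step 2: 1 sample, activation 4

per_step_means = torch.stack([b.mean(0) for b in batches]).mean(0)              # (1 + 4) / 2 = 2.5
sum_and_count = sum(b.sum(0) for b in batches) / sum(len(b) for b in batches)   # 7 / 4 = 1.75
print(per_step_means.item(), sum_and_count.item())
```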
@@ -62,12 +62,12 @@ class EvaluatorBasedPruningScheduler(BasePruningScheduler):
for key, value in def_kwargs.items():
if key not in merged_kwargs and key in arg_names:
merged_kwargs[key] = value
-diff = set(arg_names).difference(merged_kwargs.keys())
-if diff:
-raise TypeError(f"{self.__class__.__name__}.__init__() missing {len(diff)} required positional argument: {diff}")
diff = set(merged_kwargs.keys()).difference(arg_names)
if diff:
raise TypeError(f"{self.__class__.__name__}.__init__() got {len(diff)} unexpected keyword argument: {diff}")
+diff = set(arg_names).difference(merged_kwargs.keys())
+if diff:
+raise TypeError(f"{self.__class__.__name__}.__init__() missing {len(diff)} required positional argument: {diff}")
return merged_kwargs
@@ -98,7 +98,7 @@ class FunctionBasedTaskGenerator(TaskGenerator):
with Path(config_list_path).open('w') as f:
json_tricks.dump(new_config_list, f, indent=4)
-task = Task(task_id, model_path, masks_path, config_list_path)
+task = Task(task_id, model_path, masks_path, config_list_path, speedup=True, finetune=True, evaluate=False)
self._tasks[task_id] = task
@@ -28,6 +28,7 @@ from .pruning import (
compute_sparsity_mask2compact,
compute_sparsity,
get_model_weights_numel,
-get_module_by_name
+get_module_by_name,
+get_output_batch_dims
)
from .scaling import Scaling
@@ -2,6 +2,7 @@
# Licensed under the MIT license.
from copy import deepcopy
+import math
from typing import Dict, List, Tuple
import torch
@@ -279,3 +280,21 @@ def get_module_by_name(model, module_name):
return model, leaf_module
else:
return None, None
+def get_output_batch_dims(t: Tensor, module: Module):
+    if isinstance(module, (torch.nn.Linear, torch.nn.Bilinear)):
+        batch_nums = math.prod(t.shape[:-1])
+        batch_dims = [_ for _ in range(len(t.shape[:-1]))]
+    elif isinstance(module, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)):
+        batch_nums = math.prod(t.shape[:-2])
+        batch_dims = [_ for _ in range(len(t.shape[:-2]))]
+    elif isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
+        batch_nums = math.prod(t.shape[:-3])
+        batch_dims = [_ for _ in range(len(t.shape[:-3]))]
+    elif isinstance(module, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)):
+        batch_nums = math.prod(t.shape[:-4])
+        batch_dims = [_ for _ in range(len(t.shape[:-4]))]
+    else:
+        raise TypeError(f'Found unsupported module type in activation based pruner: {module.__class__.__name__}')
+    return batch_dims, batch_nums
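A quick sanity check of `get_output_batch_dims` on typical layer outputs: every leading dimension not consumed by the layer's own output shape is treated as batch-like. The import path below is inferred from the `__init__.py` change above and is an assumption:

```python
import torch
from torch import nn
from nni.algorithms.compression.v2.pytorch.utils import get_output_batch_dims  # assumed path

linear = nn.Linear(16, 32)
dims, nums = get_output_batch_dims(linear(torch.randn(4, 10, 16)), linear)
print(dims, nums)   # [0, 1] 40 -> 4 * 10 samples, only the feature dim is kept

conv = nn.Conv2d(3, 8, kernel_size=3)
dims, nums = get_output_batch_dims(conv(torch.randn(4, 3, 16, 16)), conv)
print(dims, nums)   # [0] 4
```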
@@ -751,6 +751,8 @@ class Quantizer(Compressor):
_setattr(self.bound_model, wrapper.module_name, wrapper.module)
super()._unwrap_model()
+# TODO: For most complex models, the information provided by input_shape is not enough to randomly initialize the complete input.
+# And nni should not be responsible for exporting the onnx model, this feature should be deprecated in quantization refactor.
def export_model_save(self, model, model_path, calibration_config=None, calibration_path=None, onnx_path=None,
input_shape=None, device=None):
"""