"docs/archive_en_US/TrialExample/Trials.md" did not exist on "c84ba2578454ef12bbf6a3d8560f9fc27ad81038"
Unverified commit f24dc27b authored by J-shang, committed by GitHub

[Compression] block sparse refactor (#4932)

parent 00e4debb
......@@ -13,8 +13,7 @@ from torch.nn import Module
from torch.optim import Optimizer
from nni.common.serializer import Traceable
from nni.algorithms.compression.v2.pytorch.base.pruner import Pruner
from nni.algorithms.compression.v2.pytorch.utils import CompressorSchema, config_list_canonical, OptimizerConstructHelper
from ..base import Pruner
from .tools import (
DataCollector,
......@@ -38,9 +37,11 @@ from .tools import (
NormalSparsityAllocator,
BankSparsityAllocator,
GlobalSparsityAllocator,
Conv2dDependencyAwareAllocator
DependencyAwareAllocator
)
from ..utils import CompressorSchema, config_list_canonical, OptimizerConstructHelper, Scaling
_logger = logging.getLogger(__name__)
__all__ = ['LevelPruner', 'L1NormPruner', 'L2NormPruner', 'FPGMPruner', 'SlimPruner', 'ActivationPruner',
......@@ -275,12 +276,12 @@ class NormPruner(BasicPruner):
else:
self.data_collector.reset()
if self.metrics_calculator is None:
self.metrics_calculator = NormMetricsCalculator(p=self.p, dim=0)
self.metrics_calculator = NormMetricsCalculator(p=self.p, scalers=Scaling(kernel_size=[1], kernel_padding_mode='back'))
if self.sparsity_allocator is None:
if self.mode == 'normal':
self.sparsity_allocator = NormalSparsityAllocator(self, dim=0)
self.sparsity_allocator = NormalSparsityAllocator(self, Scaling(kernel_size=[1], kernel_padding_mode='back'))
elif self.mode == 'dependency_aware':
self.sparsity_allocator = Conv2dDependencyAwareAllocator(self, 0, self.dummy_input)
self.sparsity_allocator = DependencyAwareAllocator(self, self.dummy_input, Scaling(kernel_size=[1], kernel_padding_mode='back'))
else:
raise NotImplementedError('Only support mode `normal` and `dependency_aware`')
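For context, a minimal sketch (tensor shapes illustrative) of what the new `Scaling(kernel_size=[1], kernel_padding_mode='back')` granularity means: back padding extends the kernel to `[1, -1, -1, -1]` for a 4-D Conv2d weight, so every dimension except dim 0 is reduced and one metric value is kept per filter, matching the old `dim=0` behavior. The import path is the one used by the tests later in this diff.

import torch
from nni.algorithms.compression.v2.pytorch.utils.scaling import Scaling

# A Conv2d weight of shape (32, 16, 3, 3): the back-padded kernel [1, -1, -1, -1]
# reduces every dimension except dim 0, leaving one metric value per filter.
scaler = Scaling(kernel_size=[1], kernel_padding_mode='back')
weight = torch.rand(32, 16, 3, 3)
print(scaler.shrink(weight).shape)  # torch.Size([32])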
......@@ -440,12 +441,12 @@ class FPGMPruner(BasicPruner):
else:
self.data_collector.reset()
if self.metrics_calculator is None:
self.metrics_calculator = DistMetricsCalculator(p=2, dim=0)
self.metrics_calculator = DistMetricsCalculator(p=2, scalers=Scaling(kernel_size=[1], kernel_padding_mode='back'))
if self.sparsity_allocator is None:
if self.mode == 'normal':
self.sparsity_allocator = NormalSparsityAllocator(self, dim=0)
self.sparsity_allocator = NormalSparsityAllocator(self, Scaling(kernel_size=[1], kernel_padding_mode='back'))
elif self.mode == 'dependency_aware':
self.sparsity_allocator = Conv2dDependencyAwareAllocator(self, 0, self.dummy_input)
self.sparsity_allocator = DependencyAwareAllocator(self, self.dummy_input, Scaling(kernel_size=[1], kernel_padding_mode='back'))
else:
raise NotImplementedError('Only support mode `normal` and `dependency_aware`')
......@@ -688,16 +689,16 @@ class ActivationPruner(BasicPruner):
else:
self.data_collector.reset(collector_infos=[collector_info]) # type: ignore
if self.metrics_calculator is None:
self.metrics_calculator = self._get_metrics_calculator()
self.metrics_calculator = self._create_metrics_calculator()
if self.sparsity_allocator is None:
if self.mode == 'normal':
self.sparsity_allocator = NormalSparsityAllocator(self, dim=0)
self.sparsity_allocator = NormalSparsityAllocator(self, Scaling(kernel_size=[1], kernel_padding_mode='back'))
elif self.mode == 'dependency_aware':
self.sparsity_allocator = Conv2dDependencyAwareAllocator(self, 0, self.dummy_input)
self.sparsity_allocator = DependencyAwareAllocator(self, self.dummy_input, Scaling(kernel_size=[1], kernel_padding_mode='back'))
else:
raise NotImplementedError('Only support mode `normal` and `dependency_aware`')
def _get_metrics_calculator(self) -> MetricsCalculator:
def _create_metrics_calculator(self) -> MetricsCalculator:
raise NotImplementedError()
......@@ -782,8 +783,8 @@ class ActivationAPoZRankPruner(ActivationPruner):
# return a matrix where each position that is zero in `output` is one, and the others are zero.
return torch.eq(self._activation(output.detach()), torch.zeros_like(output)).type_as(output)
def _get_metrics_calculator(self) -> MetricsCalculator:
return APoZRankMetricsCalculator(dim=1)
def _create_metrics_calculator(self) -> MetricsCalculator:
return APoZRankMetricsCalculator(Scaling(kernel_size=[-1, 1], kernel_padding_mode='back'))
class ActivationMeanRankPruner(ActivationPruner):
......@@ -865,8 +866,8 @@ class ActivationMeanRankPruner(ActivationPruner):
# return the activation of `output` directly.
return self._activation(output.detach())
def _get_metrics_calculator(self) -> MetricsCalculator:
return MeanRankMetricsCalculator(dim=1)
def _create_metrics_calculator(self) -> MetricsCalculator:
return MeanRankMetricsCalculator(Scaling(kernel_size=[-1, 1], kernel_padding_mode='back'))
class TaylorFOWeightPruner(BasicPruner):
......@@ -1009,14 +1010,14 @@ class TaylorFOWeightPruner(BasicPruner):
else:
self.data_collector.reset(collector_infos=[collector_info]) # type: ignore
if self.metrics_calculator is None:
self.metrics_calculator = MultiDataNormMetricsCalculator(p=1, dim=0)
self.metrics_calculator = MultiDataNormMetricsCalculator(p=1, scalers=Scaling(kernel_size=[1], kernel_padding_mode='back'))
if self.sparsity_allocator is None:
if self.mode == 'normal':
self.sparsity_allocator = NormalSparsityAllocator(self, dim=0)
self.sparsity_allocator = NormalSparsityAllocator(self, Scaling(kernel_size=[1], kernel_padding_mode='back'))
elif self.mode == 'global':
self.sparsity_allocator = GlobalSparsityAllocator(self, dim=0)
self.sparsity_allocator = GlobalSparsityAllocator(self, Scaling(kernel_size=[1], kernel_padding_mode='back'))
elif self.mode == 'dependency_aware':
self.sparsity_allocator = Conv2dDependencyAwareAllocator(self, 0, self.dummy_input)
self.sparsity_allocator = DependencyAwareAllocator(self, self.dummy_input, Scaling(kernel_size=[1], kernel_padding_mode='back'))
else:
raise NotImplementedError('Only support mode `normal`, `global` and `dependency_aware`')
......@@ -1146,12 +1147,12 @@ class ADMMPruner(BasicPruner):
if self.granularity == 'fine-grained':
self.metrics_calculator = NormMetricsCalculator(p=1)
elif self.granularity == 'coarse-grained':
self.metrics_calculator = NormMetricsCalculator(dim=0, p=1)
self.metrics_calculator = NormMetricsCalculator(p=1, scalers=Scaling(kernel_size=[1], kernel_padding_mode='back'))
if self.sparsity_allocator is None:
if self.granularity == 'fine-grained':
self.sparsity_allocator = NormalSparsityAllocator(self)
elif self.granularity == 'coarse-grained':
self.sparsity_allocator = NormalSparsityAllocator(self, dim=0)
self.sparsity_allocator = NormalSparsityAllocator(self, Scaling(kernel_size=[1], kernel_padding_mode='back'))
def compress(self) -> Tuple[Module, Dict]:
"""
......
......@@ -25,7 +25,7 @@ from .sparsity_allocator import (
NormalSparsityAllocator,
BankSparsityAllocator,
GlobalSparsityAllocator,
Conv2dDependencyAwareAllocator
DependencyAwareAllocator
)
from .task_generator import (
AGPTaskGenerator,
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from datetime import datetime
import logging
from pathlib import Path
......@@ -13,12 +14,24 @@ from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from nni.algorithms.compression.v2.pytorch.base import Pruner, LayerInfo, Task, TaskResult
from nni.algorithms.compression.v2.pytorch.utils import OptimizerConstructHelper
from ...base import Pruner, LayerInfo, Task, TaskResult
from ...utils import OptimizerConstructHelper, Scaling
_logger = logging.getLogger(__name__)
def _get_scaler(scalers: Dict[str, Dict[str, Scaling]] | None, module_name: str, target_name: str) -> Scaling | None:
# Get the scaler for a specific target in a specific module. Return None if it is not found.
# `module_name` is not used in the current nni version; support for different modules using different scalers will come in the future.
if scalers:
default_module_scalers = scalers.get('_default', {})
default_target_scaler = default_module_scalers.get(target_name, default_module_scalers.get('_default', None))
module_scalers = scalers.get(module_name, {})
return module_scalers.get(target_name, module_scalers.get('_default', default_target_scaler))
else:
return None
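A small illustration of the lookup order implemented above; the scaler names are hypothetical and the snippet assumes it runs in this module's scope, where `_get_scaler` and `Scaling` are available.

from nni.algorithms.compression.v2.pytorch.utils.scaling import Scaling

s_global, s_weight, s_conv1 = Scaling([1]), Scaling([2]), Scaling([4])
scalers = {
    '_default': {'_default': s_global, 'weight': s_weight},
    'conv1': {'weight': s_conv1},
}
assert _get_scaler(scalers, 'conv1', 'weight') is s_conv1   # exact match
assert _get_scaler(scalers, 'conv2', 'weight') is s_weight  # module fallback
assert _get_scaler(scalers, 'conv2', 'bias') is s_global    # target fallback
assert _get_scaler(None, 'conv1', 'weight') is None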
class DataCollector:
"""
An abstract class for collecting the data needed by the compressor.
......@@ -245,49 +258,21 @@ class MetricsCalculator:
Parameters
----------
dim
The dimensions in the collected data that correspond to the dimensions of the weight under pruning.
None means a one-to-one correspondence between pruned dimensions and data, which equals setting `dim` to all data dimensions.
Only these `dim` will be kept; the other dimensions of the data will be reduced.
Example:
Suppose you want to prune a Conv2d weight at filter level, and the weight size is (32, 16, 3, 3) [out-channels, in-channels, kernel-size-1, kernel-size-2].
Then the dimension under pruning is [0], which means you want to prune the filters (out-channels).
Case 1: Directly collect the conv module weight as the data to calculate the metric.
Then the data has size (32, 16, 3, 3).
Note that dimension 0 of the data corresponds to dimension 0 of the weight under pruning.
So in this case, `dim=0` is set in `__init__`.
Case 2: Use the output of the conv module as the data to calculate the metric.
Then the data has size (batch_num, 32, feature_map_size_1, feature_map_size_2).
Note that dimension 1 of the data corresponds to dimension 0 of the weight under pruning.
So in this case, `dim=1` is set in `__init__`.
In both cases, the metric of this module has size (32,).
block_sparse_size
This describes the block size that one metric value represents. By default, None means the block size is ones(len(dim)).
Make sure len(dim) == len(block_sparse_size); each position in block_sparse_size corresponds to the same position in dim.
Example:
The weight under pruning has size (768, 768), and you want to apply block sparsity on dim=[0] with block size [64, 768];
then you can set block_sparse_size=[64]. The final metric size is (12,).
scalers
A scaler is used to scale the metrics' size: it scales the metric to the same size as the shrunk mask in the sparsity allocator.
If you want to use different scalers for different pruning targets in different modules, pass a dict `{module_name: {target_name: scaler}}`.
If the allocator meets an unspecified module name, it will try to use `scalers['_default'][target_name]` to scale its mask.
If the allocator meets an unspecified target name, it will try to use `scalers[module_name]['_default']` to scale its mask.
Passing in a single scaler instead of a `dict` of scalers is treated as passing in `{'_default': {'_default': scalers}}`.
Passing in `None` means no scaling is needed.
"""
def __init__(self, dim: Optional[Union[int, List[int]]] = None,
block_sparse_size: Optional[Union[int, List[int]]] = None):
self.dim = dim if not isinstance(dim, int) else [dim]
self.block_sparse_size = block_sparse_size if not isinstance(block_sparse_size, int) else [block_sparse_size]
if self.block_sparse_size is not None:
assert all(i >= 1 for i in self.block_sparse_size)
elif self.dim is not None:
self.block_sparse_size = [1] * len(self.dim)
if self.dim is not None:
assert all(i >= 0 for i in self.dim)
self.dim, self.block_sparse_size = (list(t) for t in zip(*sorted(zip(self.dim, self.block_sparse_size)))) # type: ignore
def __init__(self, scalers: Dict[str, Dict[str, Scaling]] | Scaling | None = None):
self.scalers: Dict[str, Dict[str, Scaling]] | None = scalers if isinstance(scalers, (dict, type(None))) else {'_default': {'_default': scalers}} # type: ignore
def _get_scaler(self, module_name: str, target_name: str) -> Scaling:
scaler = _get_scaler(self.scalers, module_name, target_name)
return scaler if scaler else Scaling([1])
def calculate_metrics(self, data: Dict) -> Dict[str, Tensor]:
"""
......@@ -307,142 +292,120 @@ class MetricsCalculator:
class SparsityAllocator:
"""
An abstract class for allocating masks based on metrics.
A base class for allocating masks based on metrics.
Parameters
----------
pruner
The pruner that is bound to this `SparsityAllocator`.
dim
The dimensions of the weight under pruning; the metric size should equal the weight size on these dimensions.
None means a one-to-one correspondence between pruned dimensions and metric, which equals setting `dim` to all dimensions of the weight under pruning.
The mask will be expanded to the weight size depending on `dim`.
Example:
The weight under pruning has size (2, 3, 4), and `dim=1` means dimension 1 of the weight is under pruning.
Then the metric should have size (3,), e.g., `metric=[0.9, 0.1, 0.8]`.
Assuming some kind of `SparsityAllocator` gets the mask `mask=[1, 0, 1]` on weight dimension 1,
this dimension mask is then expanded to the final mask `[[[1, 1, 1, 1], [0, 0, 0, 0], [1, 1, 1, 1]], [[1, 1, 1, 1], [0, 0, 0, 0], [1, 1, 1, 1]]]`.
block_sparse_size
This describes the block size that one metric value represents. By default, None means the block size is ones(len(dim)).
Make sure len(dim) == len(block_sparse_size); each position in block_sparse_size corresponds to the same position in dim.
Example:
The metric size is (12,) and block_sparse_size=[64]; then the mask is first expanded to (768,) before being expanded along `dim`.
scalers
A scaler is used to scale the masks' size. It shrinks a mask of the same size as the pruning target to the same size as the metric,
or expands a mask of the same size as the metric to the same size as the pruning target.
If you want to use different scalers for different pruning targets in different modules, pass a dict `{module_name: {target_name: scaler}}`.
If the allocator meets an unspecified module name, it will try to use `scalers['_default'][target_name]` to scale its mask.
If the allocator meets an unspecified target name, it will try to use `scalers[module_name]['_default']` to scale its mask.
Passing in a single scaler instead of a `dict` of scalers is treated as passing in `{'_default': {'_default': scalers}}`.
Passing in `None` means no scaling is needed.
continuous_mask
If set True, the mask already in the wrapper is inherited: parts that are already masked stay masked.
If set False, already-masked parts may be unmasked if their corresponding metric values increase.
"""
def __init__(self, pruner: Pruner, dim: Optional[Union[int, List[int]]] = None,
block_sparse_size: Optional[Union[int, List[int]]] = None, continuous_mask: bool = True):
def __init__(self, pruner: Pruner, scalers: Dict[str, Dict[str, Scaling]] | Scaling | None = None, continuous_mask: bool = True):
self.pruner = pruner
self.dim = dim if not isinstance(dim, int) else [dim]
self.block_sparse_size = block_sparse_size if not isinstance(block_sparse_size, int) else [block_sparse_size]
if self.block_sparse_size is not None:
assert all(i >= 1 for i in self.block_sparse_size)
elif self.dim is not None:
self.block_sparse_size = [1] * len(self.dim)
if self.dim is not None:
assert all(i >= 0 for i in self.dim)
self.dim, self.block_sparse_size = (list(t) for t in zip(*sorted(zip(self.dim, self.block_sparse_size)))) # type: ignore
self.scalers: Dict[str, Dict[str, Scaling]] | None = scalers if isinstance(scalers, (dict, type(None))) else {'_default': {'_default': scalers}} # type: ignore
self.continuous_mask = continuous_mask
def generate_sparsity(self, metrics: Dict) -> Dict[str, Dict[str, Tensor]]:
def _get_scaler(self, module_name: str, target_name: str) -> Scaling | None:
return _get_scaler(self.scalers, module_name, target_name)
def _expand_mask(self, module_name: str, target_name: str, mask: Tensor) -> Tensor:
# Expand the shrunk mask to the pruning target size.
scaler = self._get_scaler(module_name=module_name, target_name=target_name)
if scaler:
wrapper = self.pruner.get_modules_wrapper()[module_name]
return scaler.expand(mask, getattr(wrapper, f'{target_name}_mask').shape)
else:
return mask.clone()
def _shrink_mask(self, module_name: str, target_name: str, mask: Tensor) -> Tensor:
# Shrink the mask by the scaler; the shrunk mask usually has the same size as the metric.
scaler = self._get_scaler(module_name=module_name, target_name=target_name)
if scaler:
mask = (scaler.shrink(mask) != 0).type_as(mask)
return mask
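For intuition, a minimal sketch of the `(scaler.shrink(mask) != 0)` rule above: a block in the shrunk mask stays 1 as long as any value it covers is still unmasked.

import torch
from nni.algorithms.compression.v2.pytorch.utils.scaling import Scaling

mask = torch.tensor([1., 0., 0., 0.])
scaler = Scaling([2])                       # blocks of two values
shrunk = (scaler.shrink(mask) != 0).type_as(mask)
print(shrunk)                               # tensor([1., 0.])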
def _continuous_mask(self, new_masks: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
# Keep already-masked positions masked by taking the element-wise minimum with the old mask.
target_name = 'weight'
for module_name, target_mask in new_masks.items():
wrapper = self.pruner.get_modules_wrapper()[module_name]
old_target_mask = getattr(wrapper, f'{target_name}_mask', None)
if old_target_mask is not None:
new_masks[module_name][target_name] = torch.min(target_mask[target_name], old_target_mask)
return new_masks
def common_target_masks_generation(self, metrics: Dict[str, Tensor]) -> Dict[str, Dict[str, Tensor]]:
"""
Generate masks for metrics-dependent targets.
Parameters
----------
metrics
A metric dict. The key is the module name, the value is its metric.
The format is {module_name: weight_metric}.
The metric of `weight` usually has the same size as the shrunk mask.
Returns
-------
Dict[str, Dict[str, Tensor]]
The format is {module_name: {target_name: mask}}.
The returned masks have the same size as their targets.
"""
raise NotImplementedError()
def _expand_mask(self, name: str, mask: Tensor) -> Dict[str, Tensor]:
def special_target_masks_generation(self, masks: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
"""
Some pruning targets' mask generation depends on other targets, e.g., the bias mask depends on the weight mask.
This function is used to generate these masks, and it is called at the end of `generate_sparsity`.
Parameters
----------
name
The masked module name.
mask
The reduced mask with `self.dim` and `self.block_sparse_size`.
masks
The format is {module_name: {target_name: mask}}.
It is usually the return value of `common_target_masks_generation`.
"""
for module_name, module_masks in masks.items():
# generate the bias mask; this may move to the wrapper in the future
weight_mask = module_masks.get('weight', None)
wrapper = self.pruner.get_modules_wrapper()[module_name]
old_bias_mask = getattr(wrapper, 'bias_mask', None)
if weight_mask is not None and old_bias_mask is not None and weight_mask.shape[0] == old_bias_mask.shape[0]:
# keep dim 0 and reduce all other dims by sum
reduce_dims = [reduce_dim for reduce_dim in range(1, len(weight_mask.shape))]
# count the number of unmasked values along dim 0 (output channels) of the weight
unmasked_num_on_dim0 = weight_mask.sum(reduce_dims) if reduce_dims else weight_mask
module_masks['bias'] = (unmasked_num_on_dim0 != 0).type_as(old_bias_mask)
return masks
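A hedged numeric sketch of the bias rule above: the weight mask is reduced over all non-output dimensions, and a bias entry survives only if its output channel keeps at least one weight.

import torch

# two output channels; channel 0 keeps some weights, channel 1 is fully masked
weight_mask = torch.tensor([[1., 0.], [0., 0.]])
unmasked_num_on_dim0 = weight_mask.sum([1])        # tensor([1., 0.])
bias_mask = (unmasked_num_on_dim0 != 0).float()    # tensor([1., 0.])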
Returns
-------
Dict[str, Tensor]
The key is `weight` or `bias`; the value is the final mask.
"""
weight_mask = mask.clone()
if self.block_sparse_size is not None:
# expand the mask with block_sparse_size
expand_size = list(weight_mask.size())
reshape_size = list(weight_mask.size())
for i, block_width in reversed(list(enumerate(self.block_sparse_size))):
weight_mask = weight_mask.unsqueeze(i + 1)
expand_size.insert(i + 1, block_width)
reshape_size[i] *= block_width
weight_mask = weight_mask.expand(expand_size).reshape(reshape_size)
wrapper = self.pruner.get_modules_wrapper()[name]
weight_size = wrapper.weight.data.size() # type: ignore
if self.dim is None:
assert weight_mask.size() == weight_size
expand_mask = {'weight': weight_mask}
else:
# expand mask to weight size with dim
assert len(weight_mask.size()) == len(self.dim)
assert all(weight_size[j] == weight_mask.size(i) for i, j in enumerate(self.dim))
idxs = list(range(len(weight_size)))
[idxs.pop(i) for i in reversed(self.dim)]
for i in idxs:
weight_mask = weight_mask.unsqueeze(i)
expand_mask = {'weight': weight_mask.expand(weight_size).clone()}
# NOTE: assume we only mask the output, so the mask and bias have a one-to-one correspondence.
# If we support more kinds of masks, this place needs refactoring.
if wrapper.bias_mask is not None and weight_mask.size() == wrapper.bias_mask.size(): # type: ignore
expand_mask['bias'] = weight_mask.clone()
return expand_mask
def _compress_mask(self, mask: Tensor) -> Tensor:
"""
This function will reduce the mask with `self.dim` and `self.block_sparse_size`.
e.g., a mask tensor with size [50, 60, 70], self.dim is (0, 1), self.block_sparse_size is [10, 10].
Then, the reduced mask size is [50 / 10, 60 / 10] => [5, 6].
def generate_sparsity(self, metrics: Dict) -> Dict[str, Dict[str, Tensor]]:
"""
The main function of `SparsityAllocator`, generating a set of masks based on the given metrics.
Parameters
----------
name
The masked module name.
mask
The entire mask has the same size with weight.
metrics
A metric dict with format {module_name: weight_metric}
Returns
-------
Tensor
Reduced mask.
"""
if self.dim is None or len(mask.size()) == 1:
mask = mask.clone()
else:
mask_dim = list(range(len(mask.size())))
for dim in self.dim:
mask_dim.remove(dim)
mask = torch.sum(mask, dim=mask_dim)
if self.block_sparse_size is not None:
# operation like pooling
lower_case_letters = 'abcdefghijklmnopqrstuvwxyz'
ein_expression = ''
for i, step in enumerate(self.block_sparse_size):
mask = mask.unfold(i, step, step)
ein_expression += lower_case_letters[i]
ein_expression = '...{},{}'.format(ein_expression, ein_expression)
mask = torch.einsum(ein_expression, mask, torch.ones(self.block_sparse_size).to(mask.device))
return (mask != 0).type_as(mask)
Dict[str, Dict[str, Tensor]]
The mask format is {module_name: {target_name: mask}}.
"""
masks = self.common_target_masks_generation(metrics)
masks = self.special_target_masks_generation(masks)
if self.continuous_mask:
masks = self._continuous_mask(masks)
return masks
class TaskGenerator:
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Dict, List, Optional, Union
from __future__ import annotations
from typing import Dict, List
import torch
from torch import Tensor
from .base import MetricsCalculator
from ...utils import Scaling
__all__ = ['NormMetricsCalculator', 'MultiDataNormMetricsCalculator', 'DistMetricsCalculator',
'APoZRankMetricsCalculator', 'MeanRankMetricsCalculator', 'StraightMetricsCalculator']
......@@ -28,49 +31,28 @@ class NormMetricsCalculator(MetricsCalculator):
"""
Calculate the specified norm for each tensor in data.
The L1, L2, Level, and Slim pruners use this to calculate the metric.
"""
def __init__(self, dim: Optional[Union[int, List[int]]] = None, p: Optional[Union[int, float]] = None):
"""
Parameters
----------
dim
The dimensions in the collected data that correspond to the dimensions of the weight under pruning.
None means a one-to-one correspondence between pruned dimensions and data, which equals setting `dim` to all data dimensions.
Only these `dim` will be kept; the other dimensions of the data will be reduced.
Example:
Suppose you want to prune a Conv2d weight at filter level, and the weight size is (32, 16, 3, 3) [out-channels, in-channels, kernel-size-1, kernel-size-2].
Then the dimension under pruning is [0], which means you want to prune the filters (out-channels).
Case 1: Directly collect the conv module weight as the data to calculate the metric.
Then the data has size (32, 16, 3, 3).
Note that dimension 0 of the data corresponds to dimension 0 of the weight under pruning.
So in this case, `dim=0` is set in `__init__`.
Case 2: Use the output of the conv module as the data to calculate the metric.
Then the data has size (batch_num, 32, feature_map_size_1, feature_map_size_2).
Note that dimension 1 of the data corresponds to dimension 0 of the weight under pruning.
So in this case, `dim=1` is set in `__init__`.
In both cases, the metric of this module has size (32,).
p
The order of norm. None means Frobenius norm.
scalers
Please view the base class `MetricsCalculator` docstring.
"""
super().__init__(dim=dim)
def __init__(self, p: int | float | None = None, scalers: Dict[str, Dict[str, Scaling]] | Scaling | None = None):
super().__init__(scalers=scalers)
self.p = p if p is not None else 'fro'
def calculate_metrics(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
def reduce_func(t: Tensor) -> Tensor:
return t.norm(p=self.p, dim=-1) # type: ignore
metrics = {}
for name, tensor in data.items():
keeped_dim = list(range(len(tensor.size()))) if self.dim is None else self.dim
across_dim = list(range(len(tensor.size())))
[across_dim.pop(i) for i in reversed(keeped_dim)]
if len(across_dim) == 0:
metrics[name] = tensor.abs()
else:
metrics[name] = tensor.norm(p=self.p, dim=across_dim) # type: ignore
target_name = 'weight'
for module_name, target_data in data.items():
scaler = self._get_scaler(module_name, target_name)
metrics[module_name] = scaler.shrink(target_data, reduce_func)
return metrics
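As a usage sketch (import paths taken from the updated tests later in this diff): the L2 norm of each output channel of an all-ones (3, 3, 3) tensor covers nine ones, so each metric value is sqrt(9) = 3.

import torch
from nni.algorithms.compression.v2.pytorch.pruning.tools import NormMetricsCalculator
from nni.algorithms.compression.v2.pytorch.utils.scaling import Scaling

calc = NormMetricsCalculator(p=2, scalers=Scaling(kernel_size=[1], kernel_padding_mode='back'))
print(calc.calculate_metrics({'fc': torch.ones(3, 3, 3)}))
# {'fc': tensor([3., 3., 3.])}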
......@@ -90,66 +72,32 @@ class DistMetricsCalculator(MetricsCalculator):
"""
Calculate, for each element, the sum of the specified distances to all other elements in each tensor in data.
The FPGM pruner uses this to calculate the metric.
"""
def __init__(self, p: float, dim: Union[int, List[int]]):
"""
Parameters
----------
dim
The dimensions in the collected data that correspond to the dimensions of the weight under pruning.
None means a one-to-one correspondence between pruned dimensions and data, which equals setting `dim` to all data dimensions.
Only these `dim` will be kept; the other dimensions of the data will be reduced.
Example:
Suppose you want to prune a Conv2d weight at filter level, and the weight size is (32, 16, 3, 3) [out-channels, in-channels, kernel-size-1, kernel-size-2].
Then the dimension under pruning is [0], which means you want to prune the filters (out-channels).
Case 1: Directly collect the conv module weight as the data to calculate the metric.
Then the data has size (32, 16, 3, 3).
Note that dimension 0 of the data corresponds to dimension 0 of the weight under pruning.
So in this case, `dim=0` is set in `__init__`.
Case 2: Use the output of the conv module as the data to calculate the metric.
Then the data has size (batch_num, 32, feature_map_size_1, feature_map_size_2).
Note that dimension 1 of the data corresponds to dimension 0 of the weight under pruning.
So in this case, `dim=1` is set in `__init__`.
In both cases, the metric of this module has size (32,).
p
The order of norm.
The order of norm. None means Frobenius norm.
scalers
Please view the base class `MetricsCalculator` docstring.
"""
super().__init__(dim=dim)
self.p = p
def __init__(self, p: int | float | None = None, scalers: Dict[str, Dict[str, Scaling]] | Scaling | None = None):
super().__init__(scalers=scalers)
self.p = p if p is not None else 'fro'
def calculate_metrics(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
def reduce_func(t: Tensor) -> Tensor:
reshape_data = t.reshape(-1, t.shape[-1])
metric = torch.zeros(reshape_data.shape[0], device=reshape_data.device)
for i in range(reshape_data.shape[0]):
metric[i] = (reshape_data - reshape_data[i]).norm(p=self.p, dim=-1).sum() # type: ignore
return metric.reshape(t.shape[:-1])
metrics = {}
for name, tensor in data.items():
keeped_dim = list(range(len(tensor.size()))) if self.dim is None else self.dim
reorder_dim = list(keeped_dim)
reorder_dim.extend([i for i in range(len(tensor.size())) if i not in keeped_dim])
reorder_tensor = tensor.permute(*reorder_dim).clone()
metric = torch.ones(*reorder_tensor.size()[:len(keeped_dim)], device=reorder_tensor.device)
across_dim = list(range(len(keeped_dim), len(reorder_dim)))
idxs = metric.nonzero(as_tuple=False)
for idx in idxs:
other = reorder_tensor
for i in idx:
other = other[i]
other = other.clone()
if len(across_dim) == 0:
dist_sum = torch.abs(reorder_tensor - other).sum()
else:
dist_sum = torch.norm((reorder_tensor - other), p=self.p, dim=across_dim).sum() # type: ignore
# NOTE: this place need refactor when support layer level pruning.
tmp_metric = metric
for i in idx[:-1]:
tmp_metric = tmp_metric[i]
tmp_metric[idx[-1]] = dist_sum
metrics[name] = metric
target_name = 'weight'
for module_name, target_data in data.items():
scaler = self._get_scaler(module_name, target_name)
metrics[module_name] = scaler.shrink(target_data, reduce_func)
return metrics
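A worked example using the same data as the updated test below: the two flattened filters (1, 2) and (4, 6) lie at Euclidean distance 5 from each other, so each filter's summed distance to all others is 5.

import torch
from nni.algorithms.compression.v2.pytorch.pruning.tools import DistMetricsCalculator
from nni.algorithms.compression.v2.pytorch.utils.scaling import Scaling

calc = DistMetricsCalculator(p=2, scalers=Scaling(kernel_size=[1], kernel_padding_mode='back'))
print(calc.calculate_metrics({'fc': torch.tensor([[1., 2.], [4., 6.]])}))
# {'fc': tensor([5., 5.])}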
......@@ -161,19 +109,15 @@ class APoZRankMetricsCalculator(MetricsCalculator):
The APoZRank pruner uses this to calculate the metric.
"""
def calculate_metrics(self, data: Dict[str, List]) -> Dict[str, Tensor]:
def reduce_func(t: Tensor) -> Tensor:
return 1 - t.mean(dim=-1)
metrics = {}
for name, (num, zero_counts) in data.items():
keeped_dim = list(range(len(zero_counts.size()))) if self.dim is None else self.dim
across_dim = list(range(len(zero_counts.size())))
[across_dim.pop(i) for i in reversed(keeped_dim)]
# The element number on each keeped_dim in zero_counts
total_size = num
for dim, dim_size in enumerate(zero_counts.size()):
if dim not in keeped_dim:
total_size *= dim_size
_apoz = torch.sum(zero_counts, dim=across_dim).type_as(zero_counts) / total_size
# NOTE: the metric is (1 - apoz) because we assume that smaller metric values are more likely to be pruned.
metrics[name] = torch.ones_like(_apoz) - _apoz
target_name = 'weight'
for module_name, target_data in data.items():
target_data = target_data[1] / target_data[0]
scaler = self._get_scaler(module_name, target_name)
metrics[module_name] = scaler.shrink(target_data, reduce_func)
return metrics
......@@ -183,11 +127,14 @@ class MeanRankMetricsCalculator(MetricsCalculator):
This metric simply calculates the average of the collected activations, then divides by the batch number.
The MeanRank pruner uses this to calculate the metric.
"""
def calculate_metrics(self, data: Dict[str, List[Tensor]]) -> Dict[str, Tensor]:
def calculate_metrics(self, data: Dict[str, List]) -> Dict[str, Tensor]:
def reduce_func(t: Tensor) -> Tensor:
return t.mean(dim=-1)
metrics = {}
for name, (num, activation_sum) in data.items():
keeped_dim = list(range(len(activation_sum.size()))) if self.dim is None else self.dim
across_dim = list(range(len(activation_sum.size())))
[across_dim.pop(i) for i in reversed(keeped_dim)]
metrics[name] = torch.mean(activation_sum, across_dim) / num
target_name = 'weight'
for module_name, target_data in data.items():
target_data = target_data[1] / target_data[0]
scaler = self._get_scaler(module_name, target_name)
metrics[module_name] = scaler.shrink(target_data, reduce_func)
return metrics
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import math
from __future__ import annotations
import itertools
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Union
import numpy as np
import torch
from torch import Tensor
from nni.algorithms.compression.v2.pytorch.base import Pruner
from nni.common.graph_utils import TorchModuleGraph
from nni.compression.pytorch.utils.shape_dependency import ChannelDependency, GroupDependency
from .base import SparsityAllocator
from ...base import Pruner
from ...utils import Scaling
class NormalSparsityAllocator(SparsityAllocator):
"""
This allocator simply prunes the weights with smaller metric values at layer level.
This allocator directly masks the locations in each pruning target that have lower metric values.
"""
def generate_sparsity(self, metrics: Dict[str, Tensor]) -> Dict[str, Dict[str, Tensor]]:
def common_target_masks_generation(self, metrics: Dict[str, Tensor]) -> Dict[str, Dict[str, Tensor]]:
masks = {}
for name, wrapper in self.pruner.get_modules_wrapper().items():
# TODO: Support more target types in wrapper & config list refactor
target_name = 'weight'
for module_name, target_metric in metrics.items():
masks[module_name] = {}
wrapper = self.pruner.get_modules_wrapper()[module_name]
sparsity_rate = wrapper.config['total_sparsity']
assert name in metrics, 'Metric of {} is not calculated.'.format(name)
# We assume the metric values are all positive right now.
metric = metrics[name]
if self.continuous_mask:
metric *= self._compress_mask(wrapper.weight_mask) # type: ignore
prune_num = int(sparsity_rate * metric.numel())
if prune_num == 0:
threshold = metric.min() - 1
prune_num = int(sparsity_rate * target_metric.numel())
if prune_num != 0:
threshold = torch.topk(target_metric.reshape(-1), prune_num, largest=False)[0].max()
shrinked_mask = torch.gt(target_metric, threshold).type_as(target_metric)
else:
threshold = torch.topk(metric.view(-1), prune_num, largest=False)[0].max()
mask = torch.gt(metric, threshold).type_as(metric)
masks[name] = self._expand_mask(name, mask)
if self.continuous_mask:
masks[name]['weight'] *= wrapper.weight_mask
# target_metric should have the same size as shrinked_mask
shrinked_mask = torch.ones_like(target_metric)
masks[module_name][target_name] = self._expand_mask(module_name, target_name, shrinked_mask)
return masks
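In isolation, the thresholding above behaves as in this sketch (the metric values are made up): with total_sparsity 0.5 over four metric values, the two smallest are masked.

import torch

target_metric = torch.tensor([0.9, 0.1, 0.8, 0.4])
sparsity_rate = 0.5
prune_num = int(sparsity_rate * target_metric.numel())                                 # 2
threshold = torch.topk(target_metric.reshape(-1), prune_num, largest=False)[0].max()   # 0.4
shrinked_mask = torch.gt(target_metric, threshold).type_as(target_metric)
print(shrinked_mask)                                                                   # tensor([1., 0., 1., 0.])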
class BankSparsityAllocator(SparsityAllocator):
"""
In the bank pruner, all values in the weight are divided into sub-blocks whose shapes align with balance_gran.
Each sub-block has the same sparsity, which equals the overall sparsity.
This allocator prunes the weight at block granularity.
"""
def __init__(self, pruner: Pruner, balance_gran: list):
super().__init__(pruner)
self.balance_gran = balance_gran
......@@ -54,199 +56,166 @@ class BankSparsityAllocator(SparsityAllocator):
assert isinstance(gran, int) and gran > 0, 'All values in list balance_gran \
should be of type int and greater than zero'
def generate_sparsity(self, metrics: Dict[str, Tensor]) -> Dict[str, Dict[str, Tensor]]:
def common_target_masks_generation(self, metrics: Dict[str, Tensor]) -> Dict[str, Dict[str, Tensor]]:
masks = {}
for name, wrapper in self.pruner.get_modules_wrapper().items():
# TODO: Support more target types in wrapper & config list refactor
target_name = 'weight'
for module_name, target_metric in metrics.items():
masks[module_name] = {}
wrapper = self.pruner.get_modules_wrapper()[module_name]
sparsity_rate = wrapper.config['total_sparsity']
assert name in metrics, 'Metric of {} is not calculated.'.format(name)
# We assume the metric values are all positive right now.
metric = metrics[name]
if self.continuous_mask:
metric *= self._compress_mask(wrapper.weight_mask) # type: ignore
n_dim = len(metric.shape)
n_dim = len(target_metric.shape)
assert n_dim >= len(self.balance_gran), 'The dimension of balance_gran should not exceed the dimension of the metric'
# pad balance_gran at the front so its length matches the metric dimension
balance_gran = [1] * (n_dim - len(self.balance_gran)) + self.balance_gran
for i, j in zip(metric.shape, balance_gran):
assert i % j == 0, 'Length of {} weight is not aligned with balance granularity'.format(name)
for i, j in zip(target_metric.shape, balance_gran):
assert i % j == 0, 'Length of {} {} is not aligned with balance granularity'.format(module_name, target_name)
mask = torch.zeros(metric.shape).type_as(metric)
loop_iters = [range(int(i / j)) for i, j in zip(metric.shape, balance_gran)]
# FIXME: The following code needs refactoring; do it after the scaling refactor is done.
shrinked_mask = torch.ones(target_metric.shape).type_as(target_metric)
loop_iters = [range(int(i / j)) for i, j in zip(target_metric.shape, balance_gran)]
for iter_params in itertools.product(*loop_iters):
index_str_list = [f"{iter_param * gran}:{(iter_param+1) * gran}"\
for iter_param, gran in zip(iter_params, balance_gran)]
index_str = ",".join(index_str_list)
sub_metric_str = "metric[{}]".format(index_str)
sub_mask_str = "mask[{}] = mask_bank".format(index_str)
metric_bank = eval(sub_metric_str)
sub_metric_str = "target_metric[{}]".format(index_str)
sub_mask_str = "shrinked_mask[{}] = mask_bank".format(index_str)
metric_bank: Tensor = eval(sub_metric_str)
prune_num = int(sparsity_rate * metric_bank.numel())
if prune_num == 0:
threshold = metric_bank.min() - 1
else:
threshold = torch.topk(metric_bank.reshape(-1), prune_num, largest=False)[0].max()
# mask_bank will be used in exec(sub_mask_str)
if prune_num != 0:
threshold = torch.topk(metric_bank.reshape(-1), prune_num, largest=False)[0].max()
mask_bank = torch.gt(metric_bank, threshold).type_as(metric_bank)
else:
mask_bank = torch.ones_like(metric_bank)
exec(sub_mask_str)
masks[name] = self._expand_mask(name, mask)
if self.continuous_mask:
masks[name]['weight'] *= wrapper.weight_mask
masks[module_name][target_name] = self._expand_mask(module_name, target_name, shrinked_mask)
return masks
class GlobalSparsityAllocator(SparsityAllocator):
"""
This allocator prunes the weights with smaller metric values at group level.
This means all layers in a group sort their metrics together.
The layers that share the same config in config_list form a group.
This allocator sorts all metrics as a whole and masks the locations of pruning targets with lower metric values.
"""
def generate_sparsity(self, metrics: Dict) -> Dict[str, Dict[str, Tensor]]:
def common_target_masks_generation(self, metrics: Dict[str, Tensor]) -> Dict[str, Dict[str, Tensor]]:
masks = {}
# {group_index: {layer_name: metric}}
grouped_metrics = {idx: {name: metrics[name] for name in names}
for idx, names in self.pruner.generate_module_groups().items()}
for _, group_metric_dict in grouped_metrics.items():
threshold, sub_thresholds = self._calculate_threshold(group_metric_dict)
for name, metric in group_metric_dict.items():
mask = torch.gt(metric, min(threshold, sub_thresholds[name])).type_as(metric)
masks[name] = self._expand_mask(name, mask)
if self.continuous_mask:
masks[name]['weight'] *= self.pruner.get_modules_wrapper()[name].weight_mask
if not metrics:
return masks
def _calculate_threshold(self, group_metric_dict: Dict[str, Tensor]) -> Tuple[float, Dict[str, float]]:
# TODO: support more target types in wrapper & config list refactor
target_name = 'weight'
# validate that all wrappers are set with the same sparsity
# TODO: move validation logic to pruner
global_sparsity_rate = self.pruner.get_modules_wrapper()[list(metrics.keys())[0]].config['total_sparsity']
for module_name, target_metric in metrics.items():
wrapper = self.pruner.get_modules_wrapper()[module_name]
assert global_sparsity_rate == wrapper.config['total_sparsity']
# find the largest metric value among all metrics
max_metric_value = list(metrics.values())[0].max()
for module_name, target_metric in metrics.items():
max_metric_value = max_metric_value if max_metric_value >= target_metric.max() else target_metric.max()
# prevent each module from being over-pruned; the cap ratio is 'max_sparsity_per_layer'
for module_name, target_metric in metrics.items():
wrapper = self.pruner.get_modules_wrapper()[module_name]
max_sparsity = wrapper.config.get('max_sparsity_per_layer', {}).get(module_name, 0.99)
assert 0 <= max_sparsity <= 1
old_target_mask: Tensor = getattr(wrapper, f'{target_name}_mask')
expand_times = old_target_mask.numel() // target_metric.numel()
max_pruning_numel = int(max_sparsity * target_metric.numel()) * expand_times
threshold = torch.topk(target_metric.reshape(-1), max_pruning_numel, largest=False)[0].max()
metrics[module_name] = torch.where(target_metric <= threshold, target_metric, max_metric_value)
# build the global metric & calculate the global threshold
metric_list = []
sub_thresholds = {}
total_weight_num = 0
temp_wrapper_config = self.pruner.get_modules_wrapper()[list(group_metric_dict.keys())[0]].config
total_sparsity = temp_wrapper_config['total_sparsity']
max_sparsity_per_layer = temp_wrapper_config.get('max_sparsity_per_layer', {})
for name, metric in group_metric_dict.items():
wrapper = self.pruner.get_modules_wrapper()[name]
# We assume the metric values are all positive right now.
if self.continuous_mask:
metric = metric * self._compress_mask(wrapper.weight_mask) # type: ignore
layer_weight_num = wrapper.weight.data.numel() # type: ignore
total_weight_num += layer_weight_num
expend_times = int(layer_weight_num / metric.numel())
retention_ratio = 1 - max_sparsity_per_layer.get(name, 1)
retention_numel = math.ceil(retention_ratio * layer_weight_num)
removed_metric_num = math.ceil(retention_numel / (wrapper.weight_mask.numel() / metric.numel())) # type: ignore
stay_metric_num = metric.numel() - removed_metric_num
if stay_metric_num <= 0:
sub_thresholds[name] = metric.min().item() - 1
continue
# Exclude the weight parts that must be kept
stay_metric = torch.topk(metric.view(-1), stay_metric_num, largest=False)[0]
sub_thresholds[name] = stay_metric.max()
if expend_times > 1:
stay_metric = stay_metric.expand(int(layer_weight_num / metric.numel()), stay_metric_num).contiguous().view(-1)
metric_list.append(stay_metric)
total_prune_num = int(total_sparsity * total_weight_num)
if total_prune_num == 0:
threshold = torch.cat(metric_list).min().item() - 1
else:
threshold = torch.topk(torch.cat(metric_list).view(-1), total_prune_num, largest=False)[0].max().item()
return threshold, sub_thresholds
for module_name, target_metric in metrics.items():
wrapper = self.pruner.get_modules_wrapper()[module_name]
old_target_mask: Tensor = getattr(wrapper, f'{target_name}_mask')
expand_times = old_target_mask.numel() // target_metric.numel()
metric_list.append(target_metric.reshape(-1).unsqueeze(0).expand(expand_times, -1).reshape(-1))
global_metric = torch.cat(metric_list)
max_pruning_num = int((global_metric != max_metric_value).sum().item())
total_pruning_num = min(int(global_sparsity_rate * global_metric.numel()), max_pruning_num)
global_threshold = torch.topk(global_metric.reshape(-1), total_pruning_num, largest=False)[0].max()
# generate masks for each target
for module_name, target_metric in metrics.items():
masks[module_name] = {}
wrapper = self.pruner.get_modules_wrapper()[module_name]
shrinked_mask = torch.gt(target_metric, global_threshold).type_as(target_metric)
masks[module_name][target_name] = self._expand_mask(module_name, target_name, shrinked_mask)
return masks
class Conv2dDependencyAwareAllocator(SparsityAllocator):
class DependencyAwareAllocator(NormalSparsityAllocator):
"""
A dependency-aware allocator specific to Conv2d modules.
A dependency-aware allocator for Conv2d & Linear modules.
It first generates a public mask for the modules that have dependencies,
then generates the non-public part of the mask for each module.
For other module types, the mask is generated the same way as in `NormalSparsityAllocator`.
"""
def __init__(self, pruner: Pruner, dim: int, dummy_input: Any):
assert isinstance(dim, int), 'Only support single dim in Conv2dDependencyAwareAllocator.'
super().__init__(pruner, dim=dim)
self.dummy_input = dummy_input
def __init__(self, pruner: Pruner, dummy_input: Any, scalers: Dict[str, Dict[str, Scaling]] | Scaling | None = None):
# Scaling(kernel_size=[1], kernel_padding_mode='back') means output channel pruning.
scalers = scalers if scalers else Scaling(kernel_size=[1], kernel_padding_mode='back')
super().__init__(pruner, scalers=scalers)
self.channel_dependency, self.group_dependency = self._get_dependency(dummy_input)
def _get_dependency(self):
graph = self.pruner.generate_graph(dummy_input=self.dummy_input)
def _get_dependency(self, dummy_input: Any):
# get the channel dependency and group dependency
# channel dependency format: [[module_name1, module_name2], [module_name3], ...]
# group dependency format: {module_name: group_num}
self.pruner._unwrap_model()
self.channel_depen = ChannelDependency(model=self.pruner.bound_model, dummy_input=self.dummy_input, traced_model=graph.trace).dependency_sets
self.group_depen = GroupDependency(model=self.pruner.bound_model, dummy_input=self.dummy_input, traced_model=graph.trace).dependency_sets
graph = TorchModuleGraph(model=self.pruner.bound_model, dummy_input=dummy_input)
channel_dependency = ChannelDependency(model=self.pruner.bound_model, dummy_input=dummy_input, traced_model=graph.trace).dependency_sets
group_dependency = GroupDependency(model=self.pruner.bound_model, dummy_input=dummy_input, traced_model=graph.trace).dependency_sets
self.pruner._wrap_model()
def generate_sparsity(self, metrics: Dict) -> Dict[str, Dict[str, Tensor]]:
self._get_dependency()
masks = {}
grouped_metrics = {}
grouped_names = set()
# combine metrics with channel dependency
for idx, names in enumerate(self.channel_depen):
grouped_metric = {name: metrics[name] for name in names if name in metrics}
grouped_names.update(grouped_metric.keys())
if self.continuous_mask:
for name, metric in grouped_metric.items():
metric *= self._compress_mask(self.pruner.get_modules_wrapper()[name].weight_mask) # type: ignore
if len(grouped_metric) > 0:
grouped_metrics[idx] = grouped_metric
# ungrouped metrics stand alone as a group
ungrouped_names = set(metrics.keys()).difference(grouped_names)
for name in ungrouped_names:
idx += 1 # type: ignore
grouped_metrics[idx] = {name: metrics[name]}
# generate masks
for _, group_metric_dict in grouped_metrics.items():
group_metric = self._group_metric_calculate(group_metric_dict)
sparsities = {name: self.pruner.get_modules_wrapper()[name].config['total_sparsity'] for name in group_metric_dict.keys()}
min_sparsity = min(sparsities.values())
# generate group mask
conv2d_groups, group_mask = [], []
for name in group_metric_dict.keys():
if name in self.group_depen:
conv2d_groups.append(self.group_depen[name])
else:
# not in group_depen means not a Conv2d layer, in this case, assume the group number is 1
conv2d_groups.append(1)
max_conv2d_group = np.lcm.reduce(conv2d_groups)
pruned_per_conv2d_group = int(group_metric.numel() / max_conv2d_group * min_sparsity)
conv2d_group_step = int(group_metric.numel() / max_conv2d_group)
for gid in range(max_conv2d_group):
_start = gid * conv2d_group_step
_end = (gid + 1) * conv2d_group_step
if pruned_per_conv2d_group > 0:
threshold = torch.topk(group_metric[_start: _end], pruned_per_conv2d_group, largest=False)[0].max()
conv2d_group_mask = torch.gt(group_metric[_start:_end], threshold).type_as(group_metric)
else:
conv2d_group_mask = torch.ones(conv2d_group_step, device=group_metric.device)
group_mask.append(conv2d_group_mask)
group_mask = torch.cat(group_mask, dim=0)
# generate final mask
for name, metric in group_metric_dict.items():
# We assume the metric values are all positive right now.
metric = metric * group_mask
pruned_num = int(sparsities[name] * len(metric))
if pruned_num == 0:
threshold = metric.min() - 1
return channel_dependency, group_dependency
def _metric_fuse(self, metrics: Union[Dict[str, Tensor], List[Tensor]]) -> Tensor:
# Sum all metric values at the same position.
metrics = list(metrics.values()) if isinstance(metrics, dict) else metrics
assert all(metrics[0].size() == metric.size() for metric in metrics), 'Metrics size do not match.'
fused_metric = torch.zeros_like(metrics[0])
for metric in metrics:
fused_metric += metric
return fused_metric
def common_target_masks_generation(self, metrics: Dict[str, Tensor]) -> Dict[str, Dict[str, Tensor]]:
# generate public part for modules that have dependencies
for module_names in self.channel_dependency:
sub_metrics = {module_name: metrics[module_name] for module_name in module_names if module_name in metrics}
if not sub_metrics:
continue
fused_metric = self._metric_fuse(sub_metrics)
sparsity_rates = {module_name: self.pruner.get_modules_wrapper()[module_name].config['total_sparsity'] for module_name in sub_metrics.keys()}
min_sparsity_rate = min(sparsity_rates.values())
group_nums = [self.group_dependency.get(module_name, 1) for module_name in sub_metrics.keys()]
max_group_nums = int(np.lcm.reduce(group_nums))
pruned_numel_per_group = int(fused_metric.numel() // max_group_nums * min_sparsity_rate)
group_step = fused_metric.shape[0] // max_group_nums
# get the public part of the mask for the modules with dependencies
sub_masks = []
for gid in range(max_group_nums):
_start = gid * group_step
_end = (gid + 1) * group_step
if pruned_numel_per_group > 0:
threshold = torch.topk(fused_metric[_start: _end].reshape(-1), pruned_numel_per_group, largest=False)[0].max()
sub_mask = torch.gt(fused_metric[_start:_end], threshold).type_as(fused_metric)
else:
threshold = torch.topk(metric, pruned_num, largest=False)[0].max()
mask = torch.gt(metric, threshold).type_as(metric)
masks[name] = self._expand_mask(name, mask)
if self.continuous_mask:
masks[name]['weight'] *= self.pruner.get_modules_wrapper()[name].weight_mask
return masks
sub_mask = torch.ones_like(fused_metric[_start:_end])
sub_masks.append(sub_mask)
dependency_mask = torch.cat(sub_masks, dim=0)
def _group_metric_calculate(self, group_metrics: Union[Dict[str, Tensor], List[Tensor]]) -> Tensor:
"""
Add all metric values at the same position in one group.
"""
group_metrics = list(group_metrics.values()) if isinstance(group_metrics, dict) else group_metrics
assert all(group_metrics[0].size() == group_metric.size() for group_metric in group_metrics), 'Metrics size do not match.'
group_sum_metric = torch.zeros(group_metrics[0].size(), device=group_metrics[0].device)
for group_metric in group_metrics:
group_sum_metric += group_metric
return group_sum_metric
# set the metric values at the positions pruned by the public mask to the minimum value
for module_name, target_metric in sub_metrics.items():
min_value = target_metric.min()
metrics[module_name] = torch.where(dependency_mask != 0, target_metric, min_value)
return super().common_target_masks_generation(metrics)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .attr import (
get_nested_attr,
set_nested_attr
)
from .config_validation import CompressorSchema
from .constructor_helper import *
from .pruning import (
config_list_canonical,
unfold_config_list,
......@@ -12,4 +17,4 @@ from .pruning import (
get_model_weights_numel,
get_module_by_name
)
from .constructor_helper import *
from .scaling import Scaling
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from functools import reduce
from typing import Any, overload
@overload
def get_nested_attr(__o: object, __name: str) -> Any:
...
@overload
def get_nested_attr(__o: object, __name: str, __default: Any) -> Any:
...
def get_nested_attr(__o: object, __name: str, *args) -> Any:
"""
Get a nested named attribute from an object by a `.`-separated name.
get_nested_attr(x, 'y.z') is equivalent to getattr(getattr(x, 'y'), 'z') and to x.y.z.
"""
def _getattr(__o, __name):
return getattr(__o, __name, *args)
return reduce(_getattr, [__o] + __name.split('.')) # type: ignore
def set_nested_attr(__obj: object, __name: str, __value: Any):
"""
Set the nested named attribute on the given object to the specified value by a `.`-separated name.
set_nested_attr(x, 'y.z', v) is equivalent to setattr(getattr(x, 'y'), 'z', v) and to x.y.z = v.
"""
pre, _, post = __name.rpartition('.')
return setattr(get_nested_attr(__obj, pre) if pre else __obj, post, __value)
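For example (a sketch with a throwaway class):

class _Obj:
    pass

x = _Obj()
x.y = _Obj()
set_nested_attr(x, 'y.z', 1)
assert get_nested_attr(x, 'y.z') == 1
assert get_nested_attr(x, 'y.missing', 'fallback') == 'fallback'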
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from functools import reduce
from typing import Callable, List, overload
from typing_extensions import Literal
import torch
from torch import Tensor
class Scaling:
"""
In the process of generating masks, a large number of operations like pooling or upsampling are involved.
This class provides tensor-related scaling functions for a given scaling kernel.
Similar to a convolutional kernel, the scaling kernel moves over the tensor and applies an operation.
The scaling kernel in this class is defined by two parts: the kernel size and the scaling functions (shrink and expand).
Parameters
----------
kernel_size
kernel_size is the scale; it determines how large a region of a tensor shrinks to one value,
and how large a region one value expands to.
`-1` can be used to indicate a full step in that dimension;
the dimension where -1 is located will be reduced or unsqueezed during scaling.
Example::
kernel_size = [2, -1]
# For a given 2D-tensor with size (4, 3),
[[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
[10, 11, 12]]
# shrinking it by shrink function, its size becomes (2,) after shrinking:
[shrink([[1, 2, 3], [4, 5, 6]]), shrink([[7, 8, 9], [10, 11, 12]])]
# expanding it by expand function with a given expand size,
# if the expand function is repeating the values, and the expand size is (4, 6, 2):
[[[1, 1],
[1, 1],
[2, 2],
[2, 2],
[3, 3],
[3, 3]],
...
[9, 9]]]
# note that the original tensor with size (4, 3) will unsqueeze to size (4, 3, 1) at first
# for the `-1` in kernel_size, then expand size (4, 3, 1) to size (4, 6, 2).
kernel_padding_mode
'front' or 'back', default is 'front'.
If set to 'front': when shrinking a given tensor, `1` is padded at the front of kernel_size until `len(tensor.shape) == len(kernel_size)`;
when expanding to a given expand size, `1` is padded at the front of kernel_size until `len(expand_size) == len(kernel_size)`.
If set to 'back': when shrinking a given tensor, `-1` is padded at the back of kernel_size until `len(tensor.shape) == len(kernel_size)`;
when expanding to a given expand size, `-1` is padded at the back of kernel_size until `len(expand_size) == len(kernel_size)`.
"""
def __init__(self, kernel_size: List[int], kernel_padding_mode: Literal['front', 'back'] = 'front') -> None:
self.kernel_size = kernel_size
assert kernel_padding_mode in ['front', 'back'], f"kernel_padding_mode should be one of ['front', 'back'], but got kernel_padding_mode={kernel_padding_mode}."
self.kernel_padding_mode = kernel_padding_mode
def _padding(self, _list: List[int], length: int, padding_value: int = -1, padding_mode: Literal['front', 'back'] = 'back') -> List[int]:
"""
Pad `_list` to a specific length with `padding_value`.
Parameters
----------
_list
The list of int values to be padded.
length
The length to pad to.
padding_value
The padding value; should be an int.
padding_mode
If `padding_mode` is `'front'`, the padding is applied at the front of the list.
If `padding_mode` is `'back'`, the padding is applied at the back of the list.
Returns
-------
List[int]
The padded list.
"""
assert len(_list) <= length
padding = [padding_value for _ in range(length - len(_list))]
if padding_mode == 'front':
new_list = padding + list(_list)
elif padding_mode == 'back':
new_list = list(_list) + padding
else:
raise ValueError(f'Unsupported padding mode: {padding_mode}.')
return new_list
def _shrink(self, target: Tensor, kernel_size: List[int], reduce_func: Callable[[Tensor], Tensor] | None = None) -> Tensor:
"""
Main logic for shrinking the target. Subclasses can override this function to customize behavior.
The simple default implementation sums all values covered by the kernel.
"""
# step 1: put the part covered by the kernel to the end of the converted target.
# e.g., target size is [10, 20], kernel_size is [2, 4], then new_target size is [5, 5, 8].
reshape_size = []
final_size = []
reduced_dims = []
for (dim, step) in enumerate(kernel_size):
if step == -1:
step = target.shape[dim]
reduced_dims.insert(0, dim)
assert target.shape[dim] % step == 0
reshape_size.append(target.shape[dim] // step)
final_size.append(target.shape[dim] // step)
reshape_size.append(step)
permute_dims = [2 * _ for _ in range(len(kernel_size))] + [2 * _ + 1 for _ in range(len(kernel_size))]
converted_target = target.reshape(reshape_size).permute(permute_dims).reshape(final_size + [-1])
# step 2: reduce the last dim of converted_target in a certain way, by default converted_target.sum(-1).
result = reduce_func(converted_target) if reduce_func else converted_target.sum(-1)
# step 3: reduce the dims where kernel_size is -1.
# e.g., target size is [10, 40], kernel_size is [-1, 4], result size is [1, 10], then reduce result to size [10].
result = reduce(lambda t, dim: t.squeeze(dim), [result] + reduced_dims) # type: ignore
return result
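Tracing the comment above with a concrete sketch: shrinking a (10, 20) tensor with kernel [2, 4] reshapes to (5, 2, 5, 4), permutes to (5, 5, 2, 4), flattens each kernel window into (5, 5, 8), and sums the last dim.

import torch
from nni.algorithms.compression.v2.pytorch.utils.scaling import Scaling

t = torch.ones(10, 20)
print(Scaling([2, 4]).shrink(t).shape)   # torch.Size([5, 5]); each value sums a 2x4 block, here 8.0
print(Scaling([-1, 4]).shrink(t).shape)  # torch.Size([5]); dim 0 is fully reduced and squeezed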
def _expand(self, target: Tensor, kernel_size: List[int], expand_size: List[int]) -> Tensor:
"""
Main logic for expanding the target to a specific size. Subclasses can override this function to customize behavior.
The simple default implementation repeats each value to fill the kernel size.
"""
# step 1: unsqueeze the target tensor where -1 is located in kernel_size.
unsqueezed_dims = [dim for (dim, step) in enumerate(kernel_size) if step == -1]
new_target: Tensor = reduce(lambda t, dim: t.unsqueeze(dim), [target] + unsqueezed_dims) # type: ignore
# step 2: build the _expand_size and unsqueeze target tensor on each dim
_expand_size = []
for a, b in zip(kernel_size, expand_size):
if a == -1:
_expand_size.append(1)
_expand_size.append(b)
else:
assert b % a == 0, f'Can not expand tensor with {target.shape} to {expand_size} with kernel size {kernel_size}.'
_expand_size.append(b // a)
_expand_size.append(a)
new_target: Tensor = reduce(lambda t, dim: t.unsqueeze(dim), [new_target] + [2 * _ + 1 for _ in range(len(expand_size))]) # type: ignore
# step 3: expand the new target to _expand_size and reshape it to expand_size.
# Note that we could also expose an interface for how to expand the tensor, like `reduce_func` in `_shrink`; currently there is no need for that.
result = new_target.expand(_expand_size).reshape(expand_size).clone()
return result
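And the mirror image for `_expand`, as a sketch: with the default front padding the kernel [5] becomes [1, 5], so each value is repeated five times along the last dimension.

import torch
from nni.algorithms.compression.v2.pytorch.utils.scaling import Scaling

small = torch.tensor([[1., 2.]])           # shape (1, 2)
big = Scaling([5]).expand(small, [1, 10])  # front padding -> kernel [1, 5]
print(big)                                 # tensor([[1., 1., 1., 1., 1., 2., 2., 2., 2., 2.]])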
def shrink(self, target: Tensor, reduce_func: Callable[[Tensor], Tensor] | None = None) -> Tensor:
# Canonicalize kernel_size to target size length at first.
# If kernel_padding_mode is 'front', padding 1 at the front of `self.kernel_size`.
# e.g., padding kernel_size [2, 2] to [1, 2, 2] when target size length is 3.
# If kernel_padding_mode is 'back', padding -1 at the back of `self.kernel_size`.
# e.g., padding kernel_size [1] to [1, -1, -1] when target size length is 3.
if self.kernel_padding_mode == 'front':
kernel_size = self._padding(self.kernel_size, len(target.shape), 1, 'front')
elif self.kernel_padding_mode == 'back':
kernel_size = self._padding(self.kernel_size, len(target.shape), -1, 'back')
else:
raise ValueError(f'Unsupported kernel padding mode: {self.kernel_padding_mode}.')
return self._shrink(target, kernel_size, reduce_func)
def expand(self, target: Tensor, expand_size: List[int]):
# Similar to `self.shrink`, canonicalize kernel_size to the expand_size length first.
if self.kernel_padding_mode == 'front':
kernel_size = self._padding(self.kernel_size, len(expand_size), 1, 'front')
elif self.kernel_padding_mode == 'back':
kernel_size = self._padding(self.kernel_size, len(expand_size), -1, 'back')
else:
raise ValueError(f'Unsupported kernel padding mode: {self.kernel_padding_mode}.')
return self._expand(target, kernel_size, expand_size)
@overload
def validate(self, target: List[int]):
...
@overload
def validate(self, target: Tensor):
...
def validate(self, target: List[int] | Tensor):
"""
Validate that the target tensor can be scaled shape-losslessly.
That means the shape will not change after `shrink` followed by `expand`.
"""
target = target if isinstance(target, Tensor) else torch.rand(target)
if self.expand((self.shrink(target)), list(target.shape)).shape != target.shape:
raise ValueError(f'The tensor with shape {target.shape} can not be shape-losslessly scaled with ' +
f'kernel size {self.kernel_size} and kernel_padding_mode {self.kernel_padding_mode}.')
......@@ -26,6 +26,7 @@ from nni.algorithms.compression.v2.pytorch.pruning.tools import (
)
from nni.algorithms.compression.v2.pytorch.pruning.tools.base import HookCollectorInfo
from nni.algorithms.compression.v2.pytorch.utils import get_module_by_name
from nni.algorithms.compression.v2.pytorch.utils.scaling import Scaling
from nni.algorithms.compression.v2.pytorch.utils.constructor_helper import OptimizerConstructHelper
......@@ -112,7 +113,7 @@ class PruningToolsTestCase(unittest.TestCase):
def test_metrics_calculator(self):
# Test NormMetricsCalculator
metrics_calculator = NormMetricsCalculator(dim=0, p=2)
metrics_calculator = NormMetricsCalculator(p=2, scalers=Scaling(kernel_size=[1], kernel_padding_mode='back'))
data = {
'1': torch.ones(3, 3, 3),
'2': torch.ones(4, 4) * 2
......@@ -125,7 +126,7 @@ class PruningToolsTestCase(unittest.TestCase):
assert all(torch.equal(result[k], v) for k, v in metrics.items())
# Test DistMetricsCalculator
metrics_calculator = DistMetricsCalculator(dim=0, p=2)
metrics_calculator = DistMetricsCalculator(p=2, scalers=Scaling(kernel_size=[1], kernel_padding_mode='back'))
data = {
'1': torch.tensor([[1, 2], [4, 6]], dtype=torch.float32),
'2': torch.tensor([[0, 0], [1, 1]], dtype=torch.float32)
......@@ -138,7 +139,7 @@ class PruningToolsTestCase(unittest.TestCase):
assert all(torch.equal(result[k], v) for k, v in metrics.items())
# Test MultiDataNormMetricsCalculator
metrics_calculator = MultiDataNormMetricsCalculator(dim=0, p=1)
metrics_calculator = MultiDataNormMetricsCalculator(p=1, scalers=Scaling(kernel_size=[1], kernel_padding_mode='back'))
data = {
'1': [2, torch.ones(3, 3, 3) * 2],
'2': [2, torch.ones(4, 4) * 2]
......@@ -151,7 +152,7 @@ class PruningToolsTestCase(unittest.TestCase):
assert all(torch.equal(result[k], v) for k, v in metrics.items())
# Test APoZRankMetricsCalculator
metrics_calculator = APoZRankMetricsCalculator(dim=1)
metrics_calculator = APoZRankMetricsCalculator(Scaling(kernel_size=[-1, 1], kernel_padding_mode='back'))
data = {
'1': [2, torch.tensor([[1, 1], [1, 1]], dtype=torch.float32)],
'2': [2, torch.tensor([[0, 0, 1], [0, 0, 0]], dtype=torch.float32)]
......@@ -164,7 +165,7 @@ class PruningToolsTestCase(unittest.TestCase):
assert all(torch.equal(result[k], v) for k, v in metrics.items())
# Test MeanRankMetricsCalculator
metrics_calculator = MeanRankMetricsCalculator(dim=1)
metrics_calculator = MeanRankMetricsCalculator(Scaling(kernel_size=[-1, 1], kernel_padding_mode='back'))
data = {
'1': [2, torch.tensor([[0, 1], [1, 0]], dtype=torch.float32)],
'2': [2, torch.tensor([[0, 0, 1], [0, 0, 0]], dtype=torch.float32)]
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import pytest
import torch
from nni.algorithms.compression.v2.pytorch.utils.scaling import Scaling
def test_scaling():
data = torch.tensor([_ for _ in range(100)]).reshape(10, 10)
scaler = Scaling([5], kernel_padding_mode='front')
shrinked_data = scaler.shrink(data)
assert list(shrinked_data.shape) == [10, 2]
expanded_data = scaler.expand(data, [10, 50])
assert list(expanded_data.shape) == [10, 50]
scaler = Scaling([5, 5], kernel_padding_mode='back')
shrinked_data = scaler.shrink(data)
assert list(shrinked_data.shape) == [2, 2]
expanded_data = scaler.expand(data, [50, 50, 10])
assert list(expanded_data.shape) == [50, 50, 10]
scaler.validate([10, 10, 10])
if __name__ == '__main__':
test_scaling()