Unverified Commit e5c3ac63 authored by J-shang, committed by GitHub

Compression v2 Stage 1 (#3917)

parent e219bae8
import argparse
import logging
from pathlib import Path
import torch
from torchvision import transforms, datasets
from nni.algorithms.compression.v2.pytorch import pruning
from nni.compression.pytorch import ModelSpeedup
from examples.model_compress.models.cifar10.vgg import VGG
logging.getLogger().setLevel(logging.DEBUG)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VGG().to(device)
normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('./data', train=True, transform=transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, 4),
transforms.ToTensor(),
normalize,
]), download=True),
batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('./data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
normalize,
])),
batch_size=200, shuffle=False)
criterion = torch.nn.CrossEntropyLoss()
def trainer(model, optimizer, criterion, epoch=None):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
if batch_idx % 100 == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
def evaluator(model):
model.eval()
criterion = torch.nn.NLLLoss()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += criterion(output, target).item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
acc = 100 * correct / len(test_loader.dataset)
print('Test Loss: {} Accuracy: {}%\n'.format(
test_loss, acc))
return acc
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
finetune_optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
def main(args):
if args.pre_train:
for i in range(1):
trainer(model, finetune_optimizer, criterion, epoch=i)
config_list = [{
'op_types': ['Conv2d'],
'sparsity_per_layer': 0.8
}]
kwargs = {
'model': model,
'config_list': config_list,
}
if args.pruner == 'level':
pruner = pruning.LevelPruner(**kwargs)
else:
kwargs['mode'] = args.mode
if kwargs['mode'] == 'dependency_aware':
kwargs['dummy_input'] = torch.rand(10, 3, 32, 32).to(device)
if args.pruner == 'l1norm':
pruner = pruning.L1NormPruner(**kwargs)
elif args.pruner == 'l2norm':
pruner = pruning.L2NormPruner(**kwargs)
elif args.pruner == 'fpgm':
pruner = pruning.FPGMPruner(**kwargs)
else:
kwargs['trainer'] = trainer
kwargs['optimizer'] = optimizer
kwargs['criterion'] = criterion
if args.pruner == 'slim':
kwargs['config_list'] = [{
'op_types': ['BatchNorm2d'],
'total_sparsity': 0.8,
'max_sparsity_per_layer': 0.9
}]
kwargs['training_epochs'] = 1
pruner = pruning.SlimPruner(**kwargs)
elif args.pruner == 'mean_activation':
pruner = pruning.ActivationMeanRankPruner(**kwargs)
elif args.pruner == 'apoz':
pruner = pruning.ActivationAPoZRankPruner(**kwargs)
elif args.pruner == 'taylorfo':
pruner = pruning.TaylorFOWeightPruner(**kwargs)
pruned_model, masks = pruner.compress()
pruner.show_pruned_weights()
if args.speed_up:
tmp_masks = {}
for name, mask in masks.items():
tmp_masks[name] = {}
tmp_masks[name]['weight'] = mask.get('weight_mask')
if mask.get('bias_mask') is not None:
tmp_masks[name]['bias'] = mask.get('bias_mask')
torch.save(tmp_masks, Path('./temp_masks.pth'))
pruner._unwrap_model()
ModelSpeedup(model, torch.rand(10, 3, 32, 32).to(device), Path('./temp_masks.pth')).speedup_model()
if args.finetune:
for i in range(1):
trainer(pruned_model, finetune_optimizer, criterion, epoch=i)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Example for model compression')
parser.add_argument('--pruner', type=str, default='l1norm',
choices=['level', 'l1norm', 'l2norm', 'slim',
'fpgm', 'mean_activation', 'apoz', 'taylorfo'],
help='pruner to use')
parser.add_argument('--mode', type=str, default='normal',
choices=['normal', 'dependency_aware', 'global'])
parser.add_argument('--pre-train', action='store_true', default=False,
help='Whether to pre-train the model')
parser.add_argument('--speed-up', action='store_true', default=False,
help='Whether to speed-up the pruned model')
parser.add_argument('--finetune', action='store_true', default=False,
help='Whether to finetune the pruned model')
args = parser.parse_args()
main(args)
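# A minimal invocation sketch for this example script (the file name below is an assumption,
# not taken from this diff):
#   python v2_pruners_torch.py --pre-train --pruner l1norm --mode dependency_aware --speed-up --finetune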
from .compressor import Compressor, LayerInfo
from .pruner import Pruner, PrunerModuleWrapper
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import collections
import logging
from typing import List, Dict, Optional, OrderedDict, Tuple, Any
import torch
from torch.nn import Module
from nni.common.graph_utils import TorchModuleGraph
from nni.compression.pytorch.utils import get_module_by_name
_logger = logging.getLogger(__name__)
__all__ = ['LayerInfo', 'Compressor']
class LayerInfo:
def __init__(self, name: str, module: Module):
self.module = module
self.name = name
self.type = type(module).__name__
def _setattr(model: Module, name: str, module: Module):
parent_module, _ = get_module_by_name(model, name)
if parent_module is not None:
name_list = name.split(".")
setattr(parent_module, name_list[-1], module)
else:
raise RuntimeError('{} does not exist.'.format(name))
weighted_modules = [
'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d',
'Linear', 'Bilinear',
'PReLU',
'Embedding', 'EmbeddingBag',
]
class Compressor:
"""
The abstract base pytorch compressor.
"""
def __init__(self, model: Module, config_list: List[Dict]):
"""
Parameters
----------
model
The model to be compressed.
config_list
The config list used by the compressor, usually specifying the 'op_types' or 'op_names' to be compressed.
"""
assert isinstance(model, Module)
self.is_wrapped = False
self.reset(model=model, config_list=config_list)
def reset(self, model: Module, config_list: List[Dict]):
"""
Reset the compressor with model and config_list.
Parameters
----------
model
The model to be compressed.
config_list
The config list used by the compressor, usually specifying the 'op_types' or 'op_names' to be compressed.
"""
assert isinstance(model, Module), 'Only support compressing pytorch Module, but the type of model is {}.'.format(type(model))
self.bound_model = model
self.config_list = config_list
self.validate_config(model=model, config_list=config_list)
self._unwrap_model()
self._modules_to_compress = None
self.modules_wrapper = collections.OrderedDict()
for layer, config in self._detect_modules_to_compress():
wrapper = self._wrap_modules(layer, config)
self.modules_wrapper[layer.name] = wrapper
self._wrap_model()
def _detect_modules_to_compress(self) -> List[Tuple[LayerInfo, Dict]]:
"""
Detect all modules that should be compressed, and save the result in `self._modules_to_compress`.
The model will be instrumented and the user should never edit it after calling this method.
"""
if self._modules_to_compress is None:
self._modules_to_compress = []
for name, module in self.bound_model.named_modules():
if module == self.bound_model:
continue
layer = LayerInfo(name, module)
config = self._select_config(layer)
if config is not None:
self._modules_to_compress.append((layer, config))
return self._modules_to_compress
def _select_config(self, layer: LayerInfo) -> Optional[Dict]:
"""
Find the configuration for `layer` by parsing `self.config_list`.
Parameters
----------
layer
The layer to check for a compression configuration.
Returns
-------
Optional[Dict]
The retrieved configuration for this layer; if None, this layer should not be compressed.
"""
ret = None
for config in self.config_list:
config = config.copy()
# expand config if key `default` is in config['op_types']
if 'op_types' in config and 'default' in config['op_types']:
expanded_op_types = []
for op_type in config['op_types']:
if op_type == 'default':
expanded_op_types.extend(weighted_modules)
else:
expanded_op_types.append(op_type)
config['op_types'] = expanded_op_types
# check if condition is satisfied
if 'op_types' in config and layer.type not in config['op_types']:
continue
if 'op_names' in config and layer.name not in config['op_names']:
continue
ret = config
if ret is None or 'exclude' in ret:
return None
return ret
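# Illustrative sketch (not part of the original file): with a config_list like the one below,
# `_select_config` returns the last matching config for a layer, and returns None when the
# matching config contains the 'exclude' key.
#
#   config_list = [
#       {'op_types': ['default'], 'sparsity': 0.5},     # 'default' expands to `weighted_modules`
#       {'op_names': ['features.0'], 'exclude': True},  # exclude this layer from compression
#   ]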
def get_modules_wrapper(self) -> OrderedDict[str, Module]:
"""
Returns
-------
OrderedDict[str, Module]
An ordered dict, key is the name of the module, value is the wrapper of the module.
"""
return self.modules_wrapper
def _wrap_model(self):
"""
Wrap all modules that need to be compressed.
"""
if not self.is_wrapped:
for _, wrapper in reversed(self.get_modules_wrapper().items()):
_setattr(self.bound_model, wrapper.name, wrapper)
self.is_wrapped = True
def _unwrap_model(self):
"""
Unwrap all modules that need to be compressed.
"""
if self.is_wrapped:
for _, wrapper in self.get_modules_wrapper().items():
_setattr(self.bound_model, wrapper.name, wrapper.module)
self.is_wrapped = False
def set_wrappers_attribute(self, name: str, value: Any):
"""
Register an attribute used in the wrapped module's forward method.
If the value is a torch.Tensor, it is registered as a buffer in the wrapper,
which will be saved in model.state_dict(). Otherwise, the value is just a regular attribute of the wrapper.
Parameters
----------
name
Name of the variable.
value
Value of the variable.
"""
for wrapper in self.get_modules_wrapper().values():
if isinstance(value, torch.Tensor):
wrapper.register_buffer(name, value.clone())
else:
setattr(wrapper, name, value)
def generate_graph(self, dummy_input: Any) -> TorchModuleGraph:
"""
Generate a `TorchModuleGraph` instance of `self.bound_model` based on `jit.trace`.
Parameters
----------
dummy_input
The dummy input for `jit.trace`; users should put it on the right device before passing it in.
Returns
-------
TorchModuleGraph
A `TorchModuleGraph` instance.
"""
self._unwrap_model()
graph = TorchModuleGraph(model=self.bound_model, dummy_input=dummy_input)
self._wrap_model()
return graph
def generate_module_groups(self) -> Dict[int, List[str]]:
"""
Get all module names in each config in config_list.
Returns
-------
Dict[int, List[str]]
A dict. The key is the config index in config_list, the value is the list of module names, e.g., {1: ['layer.0', 'layer.2']}.
"""
self._unwrap_model()
module_groups = {}
for name, module in self.bound_model.named_modules():
if module == self.bound_model:
continue
layer = LayerInfo(name, module)
ret = None
for idx, config in enumerate(self.config_list):
config = config.copy()
# expand config if key `default` is in config['op_types']
if 'op_types' in config and 'default' in config['op_types']:
expanded_op_types = []
for op_type in config['op_types']:
if op_type == 'default':
expanded_op_types.extend(weighted_modules)
else:
expanded_op_types.append(op_type)
config['op_types'] = expanded_op_types
# check if condition is satisfied
if 'op_types' in config and layer.type not in config['op_types']:
continue
if 'op_names' in config and layer.name not in config['op_names']:
continue
ret = (idx, config)
if ret is not None and 'exclude' not in ret[1]:
module_groups.setdefault(ret[0], [])
module_groups[ret[0]].append(name)
self._wrap_model()
return module_groups
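# Example of the return value (layer names are hypothetical): with
# config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.5}], generate_module_groups() could return
# {0: ['features.0', 'features.3']}, where 0 is the index of the matched config in config_list.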
def _wrap_modules(self, layer: LayerInfo, config: Dict):
"""
This method is implemented in the subclasses, i.e., `Pruner` and `Quantizer`
Parameters
----------
layer
the layer to instrument the compression operation
config
the configuration for compressing this layer
"""
raise NotImplementedError()
def validate_config(self, model: Module, config_list: List[Dict]):
"""
Subclass can optionally implement this method to check if config_list is valid.
Parameters
----------
model
The model to be compressed.
config_list
The config list used by the compressor, usually specifying the 'op_types' or 'op_names' to be compressed.
"""
pass
def compress(self) -> Module:
"""
Compress the model with the algorithm implemented by the subclass.
The model will be instrumented and the user should never edit it after calling this method.
`self._modules_to_compress` records all the layers to be compressed.
Returns
-------
torch.nn.Module
model with specified modules compressed.
"""
return self.bound_model
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from typing import Dict, List, Optional, Tuple
import torch
from torch import Tensor
from torch.nn import Module
from .compressor import Compressor, LayerInfo
_logger = logging.getLogger(__name__)
__all__ = ['Pruner']
class PrunerModuleWrapper(Module):
def __init__(self, module: Module, module_name: str, config: Dict, pruner: Compressor):
"""
Wrap a module to enable data parallelism, forward method customization and buffer registration.
Parameters
----------
module
The module user wants to compress.
config
The configurations that users specify for compression.
module_name
The name of the module to compress, wrapper module shares same name.
pruner
The pruner used to calculate mask.
"""
super().__init__()
# origin layer information
self.module = module
self.name = module_name
# config and pruner
self.config = config
self.pruner = pruner
# register buffer for mask
self.register_buffer("weight_mask", torch.ones(self.module.weight.shape))
if hasattr(self.module, 'bias') and self.module.bias is not None:
self.register_buffer("bias_mask", torch.ones(self.module.bias.shape))
else:
self.register_buffer("bias_mask", None)
def forward(self, *inputs):
# apply mask to weight, bias
self.module.weight.data = self.module.weight.data.mul_(self.weight_mask)
if hasattr(self.module, 'bias') and self.module.bias is not None:
self.module.bias.data = self.module.bias.data.mul_(self.bias_mask)
return self.module(*inputs)
class Pruner(Compressor):
"""
The abstract class for pruning algorithm. Inherit this class and implement the `_reset_tools` to customize a pruner.
"""
def reset(self, model: Optional[Module] = None, config_list: Optional[List[Dict]] = None):
super().reset(model=model, config_list=config_list)
def _wrap_modules(self, layer: LayerInfo, config: Dict):
"""
Create a wrapper module to replace the original one.
Parameters
----------
layer
The layer to instrument the mask.
config
The configuration for generating the mask.
"""
_logger.debug("Module detected to compress : %s.", layer.name)
wrapper = PrunerModuleWrapper(layer.module, layer.name, config, self)
assert hasattr(layer.module, 'weight'), "module %s does not have 'weight' attribute" % layer.name
# move newly registered buffers to the same device of weight
wrapper.to(layer.module.weight.device)
return wrapper
def load_masks(self, masks: Dict[str, Dict[str, Tensor]]):
"""
Load existing masks onto the wrappers. You can train the model with the loaded masks applied.
Parameters
----------
masks
The masks dict with format {'op_name': {'weight_mask': mask, 'bias_mask': mask}}.
"""
wrappers = self.get_modules_wrapper()
for name, layer_mask in masks.items():
assert name in wrappers, '{} is not in wrappers of this pruner, can not apply the mask.'.format(name)
for mask_type, mask in layer_mask.items():
assert hasattr(wrappers[name], mask_type), 'there is no attribute {} in wrapper'.format(mask_type)
setattr(wrappers[name], mask_type, mask)
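# Sketch of the expected masks format (layer name and shapes are hypothetical):
#   masks = {'features.0': {'weight_mask': torch.ones(64, 3, 3, 3), 'bias_mask': torch.ones(64)}}
#   pruner.load_masks(masks)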
def compress(self) -> Tuple[Module, Dict[str, Dict[str, Tensor]]]:
"""
Returns
-------
Tuple[Module, Dict]
Return the wrapped model and mask.
"""
return self.bound_model, {}
# NOTE: need refactor dim with supporting list
def show_pruned_weights(self, dim: int = 0):
"""
Log the simulated prune sparsity.
Parameters
----------
dim
The pruned dim.
"""
for _, wrapper in self.get_modules_wrapper().items():
weight_mask = wrapper.weight_mask
mask_size = weight_mask.size()
if len(mask_size) == 1:
index = torch.nonzero(weight_mask.abs() != 0, as_tuple=False).tolist()
else:
sum_idx = list(range(len(mask_size)))
sum_idx.remove(dim)
index = torch.nonzero(weight_mask.abs().sum(sum_idx) != 0, as_tuple=False).tolist()
_logger.info(f'simulated prune {wrapper.name} remain/total: {len(index)}/{weight_mask.size(dim)}')
def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None, device=None):
"""
Export pruned model weights, masks and the ONNX model (optional).
Parameters
----------
model_path
Path to save pruned model state_dict.
mask_path
(optional) path to save mask dict.
onnx_path
(optional) path to save onnx model.
input_shape
Input shape to onnx model.
device
Device of the model, used to place the dummy input tensor for exporting onnx file.
The tensor is placed on CPU if `device` is None.
"""
assert model_path is not None, 'model_path must be specified'
mask_dict = {}
self._unwrap_model() # used for generating correct state_dict name without wrapper state
for name, wrapper in self.get_modules_wrapper().items():
weight_mask = wrapper.weight_mask
bias_mask = wrapper.bias_mask
if weight_mask is not None:
mask_sum = weight_mask.sum().item()
mask_num = weight_mask.numel()
_logger.debug('Layer: %s Sparsity: %.4f', name, 1 - mask_sum / mask_num)
wrapper.module.weight.data = wrapper.module.weight.data.mul(weight_mask)
if bias_mask is not None:
wrapper.module.bias.data = wrapper.module.bias.data.mul(bias_mask)
# save mask to dict
mask_dict[name] = {"weight_mask": weight_mask, "bias_mask": bias_mask}
torch.save(self.bound_model.state_dict(), model_path)
_logger.info('Model state_dict saved to %s', model_path)
if mask_path is not None:
torch.save(mask_dict, mask_path)
_logger.info('Mask dict saved to %s', mask_path)
if onnx_path is not None:
assert input_shape is not None, 'input_shape must be specified to export onnx model'
# input info needed
if device is None:
device = torch.device('cpu')
input_data = torch.Tensor(*input_shape)
torch.onnx.export(self.bound_model, input_data.to(device), onnx_path)
_logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)
self._wrap_model()
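# Typical call sketch (paths and shape are placeholders):
#   pruner.export_model(model_path='pruned_model.pth', mask_path='mask.pth',
#                       onnx_path='pruned_model.onnx', input_shape=(1, 3, 32, 32),
#                       device=torch.device('cpu'))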
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from typing import List, Dict, Tuple, Callable, Optional
from schema import And, Optional as SchemaOptional
import torch
from torch import Tensor
import torch.nn as nn
from torch.nn import Module
from torch.optim import Optimizer
from nni.algorithms.compression.v2.pytorch.base.pruner import Pruner
from nni.algorithms.compression.v2.pytorch.utils.config_validation import PrunerSchema
from .tools import (
DataCollector,
HookCollectorInfo,
WeightDataCollector,
WeightTrainerBasedDataCollector,
SingleHookTrainerBasedDataCollector
)
from .tools import (
MetricsCalculator,
NormMetricsCalculator,
MultiDataNormMetricsCalculator,
DistMetricsCalculator,
APoZRankMetricsCalculator,
MeanRankMetricsCalculator
)
from .tools import (
SparsityAllocator,
NormalSparsityAllocator,
GlobalSparsityAllocator,
Conv2dDependencyAwareAllocator
)
_logger = logging.getLogger(__name__)
__all__ = ['LevelPruner', 'L1NormPruner', 'L2NormPruner', 'FPGMPruner', 'SlimPruner', 'ActivationPruner',
'ActivationAPoZRankPruner', 'ActivationMeanRankPruner', 'TaylorFOWeightPruner']
class OneShotPruner(Pruner):
def __init__(self, model: Module, config_list: List[Dict]):
self.data_collector: DataCollector = None
self.metrics_calculator: MetricsCalculator = None
self.sparsity_allocator: SparsityAllocator = None
self._convert_config_list(config_list)
super().__init__(model, config_list)
def _convert_config_list(self, config_list: List[Dict]):
"""
Convert `sparsity` in config to `sparsity_per_layer`.
"""
for config in config_list:
if 'sparsity' in config:
if 'sparsity_per_layer' in config:
raise ValueError("'sparsity' and 'sparsity_per_layer' have the same semantics, can not set both in one config.")
else:
config['sparsity_per_layer'] = config.pop('sparsity')
def reset(self, model: Optional[Module], config_list: Optional[List[Dict]]):
super().reset(model=model, config_list=config_list)
self.reset_tools()
def reset_tools(self):
"""
This function is used to reset `self.data_collector`, `self.metrics_calculator` and `self.sparsity_allocator`.
The subclass needs to implement this function to complete the pruning process.
See `compress()` to understand how NNI uses these three parts to generate the mask for the bound model.
"""
raise NotImplementedError()
def compress(self) -> Tuple[Module, Dict]:
"""
Used to generate the masks. The pruning process is divided into three stages:
`self.data_collector` collects the data used to calculate the specified metric,
`self.metrics_calculator` calculates the metric, and `self.sparsity_allocator` generates the mask depending on the metric.
Returns
-------
Tuple[Module, Dict]
Return the wrapped model and mask.
"""
data = self.data_collector.collect()
_logger.debug('Collected Data:\n%s', data)
metrics = self.metrics_calculator.calculate_metrics(data)
_logger.debug('Metrics Calculate:\n%s', metrics)
masks = self.sparsity_allocator.generate_sparsity(metrics)
_logger.debug('Masks:\n%s', masks)
self.load_masks(masks)
return self.bound_model, masks
class LevelPruner(OneShotPruner):
def __init__(self, model: Module, config_list: List[Dict]):
"""
Parameters
----------
model
Model to be pruned
config_list
Supported keys:
- sparsity : This is to specify the sparsity for each layer in this config to be compressed.
- sparsity_per_layer : Equivalent to sparsity.
- op_types : Operation types to prune.
- op_names : Operation names to prune.
- exclude : Set True to exclude the layers selected by op_types and op_names from pruning.
"""
self.mode = 'normal'
super().__init__(model, config_list)
def validate_config(self, model: Module, config_list: List[Dict]):
schema = PrunerSchema([{
SchemaOptional('sparsity_per_layer'): And(float, lambda n: 0 < n < 1),
SchemaOptional('op_types'): [str],
SchemaOptional('op_names'): [str],
SchemaOptional('exclude'): bool
}], model, _logger)
schema.validate(config_list)
def reset_tools(self):
if self.data_collector is None:
self.data_collector = WeightDataCollector(self)
else:
self.data_collector.reset()
if self.metrics_calculator is None:
self.metrics_calculator = NormMetricsCalculator()
if self.sparsity_allocator is None:
self.sparsity_allocator = NormalSparsityAllocator(self)
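# Minimal usage sketch for LevelPruner, assuming `model` is any torch.nn.Module:
#   config_list = [{'op_types': ['Linear'], 'sparsity_per_layer': 0.5}]
#   pruner = LevelPruner(model, config_list)
#   pruned_model, masks = pruner.compress()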
class NormPruner(OneShotPruner):
def __init__(self, model: Module, config_list: List[Dict], p: int,
mode: str = 'normal', dummy_input: Optional[Tensor] = None):
"""
Parameters
----------
model
Model to be pruned
config_list
Supported keys:
- sparsity : This is to specify the sparsity for each layer in this config to be compressed.
- sparsity_per_layer : Equivalent to sparsity.
- op_types : Conv2d and Linear are supported in NormPruner.
- op_names : Operation names to prune.
- exclude : Set True to exclude the layers selected by op_types and op_names from pruning.
p
The order of norm.
mode
'normal' or 'dependency_aware'.
If pruning the model in a dependency-aware way, this pruner will
prune the model according to the norm of weights and the channel or group
dependencies of the model. In this way, the pruner will force the conv layers
that have dependencies to prune the same channels, so the speedup module can better
harvest the speed benefit from the pruned model. Note that if 'dependency_aware' is set,
dummy_input cannot be None, because the pruner needs a dummy input to trace the
dependencies between the conv layers.
dummy_input
The dummy input used to analyze the topology constraints. Note that the dummy_input
should be on the same device as the model.
"""
self.p = p
self.mode = mode
self.dummy_input = dummy_input
super().__init__(model, config_list)
def validate_config(self, model: Module, config_list: List[Dict]):
schema = PrunerSchema([{
SchemaOptional('sparsity_per_layer'): And(float, lambda n: 0 < n < 1),
SchemaOptional('op_types'): ['Conv2d', 'Linear'],
SchemaOptional('op_names'): [str],
SchemaOptional('exclude'): bool
}], model, _logger)
schema.validate(config_list)
def reset_tools(self):
if self.data_collector is None:
self.data_collector = WeightDataCollector(self)
else:
self.data_collector.reset()
if self.metrics_calculator is None:
self.metrics_calculator = NormMetricsCalculator(p=self.p, dim=0)
if self.sparsity_allocator is None:
if self.mode == 'normal':
self.sparsity_allocator = NormalSparsityAllocator(self, dim=0)
elif self.mode == 'dependency_aware':
self.sparsity_allocator = Conv2dDependencyAwareAllocator(self, 0, self.dummy_input)
else:
raise NotImplementedError('Only support mode `normal` and `dependency_aware`')
class L1NormPruner(NormPruner):
def __init__(self, model: Module, config_list: List[Dict],
mode: str = 'normal', dummy_input: Optional[Tensor] = None):
"""
Parameters
----------
model
Model to be pruned
config_list
Supported keys:
- sparsity : This is to specify the sparsity for each layer in this config to be compressed.
- sparsity_per_layer : Equivalent to sparsity.
- op_types : Conv2d and Linear are supported in L1NormPruner.
- op_names : Operation names to prune.
- exclude : Set True to exclude the layers selected by op_types and op_names from pruning.
mode
'normal' or 'dependency_aware'.
If pruning the model in a dependency-aware way, this pruner will
prune the model according to the l1-norm of weights and the channel or group
dependencies of the model. In this way, the pruner will force the conv layers
that have dependencies to prune the same channels, so the speedup module can better
harvest the speed benefit from the pruned model. Note that if 'dependency_aware' is set,
dummy_input cannot be None, because the pruner needs a dummy input to trace the
dependencies between the conv layers.
dummy_input
The dummy input used to analyze the topology constraints. Note that the dummy_input
should be on the same device as the model.
"""
super().__init__(model, config_list, 1, mode, dummy_input)
class L2NormPruner(NormPruner):
def __init__(self, model: Module, config_list: List[Dict],
mode: str = 'normal', dummy_input: Optional[Tensor] = None):
"""
Parameters
----------
model
Model to be pruned
config_list
Supported keys:
- sparsity : This is to specify the sparsity for each layer in this config to be compressed.
- sparsity_per_layer : Equivalent to sparsity.
- op_types : Conv2d and Linear are supported in L2NormPruner.
- op_names : Operation names to prune.
- exclude : Set True to exclude the layers selected by op_types and op_names from pruning.
mode
'normal' or 'dependency_aware'.
If pruning the model in a dependency-aware way, this pruner will
prune the model according to the l2-norm of weights and the channel or group
dependencies of the model. In this way, the pruner will force the conv layers
that have dependencies to prune the same channels, so the speedup module can better
harvest the speed benefit from the pruned model. Note that if 'dependency_aware' is set,
dummy_input cannot be None, because the pruner needs a dummy input to trace the
dependencies between the conv layers.
dummy_input
The dummy input used to analyze the topology constraints. Note that the dummy_input
should be on the same device as the model.
"""
super().__init__(model, config_list, 2, mode, dummy_input)
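# Dependency-aware usage sketch (the input shape is an assumption for a CIFAR-10 sized model):
#   dummy_input = torch.rand(8, 3, 32, 32).to(next(model.parameters()).device)
#   pruner = L1NormPruner(model, [{'op_types': ['Conv2d'], 'sparsity_per_layer': 0.5}],
#                         mode='dependency_aware', dummy_input=dummy_input)
#   pruned_model, masks = pruner.compress()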
class FPGMPruner(OneShotPruner):
def __init__(self, model: Module, config_list: List[Dict],
mode: str = 'normal', dummy_input: Optional[Tensor] = None):
"""
Parameters
----------
model
Model to be pruned
config_list
Supported keys:
- sparsity : This is to specify the sparsity for each layer in this config to be compressed.
- sparsity_per_layer : Equivalent to sparsity.
- op_types : Conv2d and Linear are supported in FPGMPruner.
- op_names : Operation names to prune.
- exclude : Set True to exclude the layers selected by op_types and op_names from pruning.
mode
'normal' or 'dependency_aware'.
If pruning the model in a dependency-aware way, this pruner will
prune the model according to the FPGM metric of weights and the channel or group
dependencies of the model. In this way, the pruner will force the conv layers
that have dependencies to prune the same channels, so the speedup module can better
harvest the speed benefit from the pruned model. Note that if 'dependency_aware' is set,
dummy_input cannot be None, because the pruner needs a dummy input to trace the
dependencies between the conv layers.
dummy_input
The dummy input used to analyze the topology constraints. Note that the dummy_input
should be on the same device as the model.
"""
self.mode = mode
self.dummy_input = dummy_input
super().__init__(model, config_list)
def validate_config(self, model: Module, config_list: List[Dict]):
schema = PrunerSchema([{
SchemaOptional('sparsity_per_layer'): And(float, lambda n: 0 < n < 1),
SchemaOptional('op_types'): ['Conv2d', 'Linear'],
SchemaOptional('op_names'): [str],
SchemaOptional('exclude'): bool
}], model, _logger)
schema.validate(config_list)
def reset_tools(self):
if self.data_collector is None:
self.data_collector = WeightDataCollector(self)
else:
self.data_collector.reset()
if self.metrics_calculator is None:
self.metrics_calculator = DistMetricsCalculator(p=2, dim=0)
if self.sparsity_allocator is None:
if self.mode == 'normal':
self.sparsity_allocator = NormalSparsityAllocator(self, dim=0)
elif self.mode == 'dependency_aware':
self.sparsity_allocator = Conv2dDependencyAwareAllocator(self, 0, self.dummy_input)
else:
raise NotImplementedError('Only support mode `normal` and `dependency_aware`')
class SlimPruner(OneShotPruner):
def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor],
training_epochs: int, scale: float = 0.0001, mode='global'):
"""
Parameters
----------
model
Model to be pruned
config_list
Supported keys:
- sparsity : This is to specify the sparsity for each layer in this config to be compressed.
- sparsity_per_layer : Equivalent to sparsity.
- total_sparsity : This is to specify the total sparsity for all layers in this config,
each layer may have a different sparsity.
- max_sparsity_per_layer : Always used with total_sparsity. Limits the max sparsity of each layer.
- op_types : Only BatchNorm2d is supported in SlimPruner.
- op_names : Operation names to prune.
- exclude : Set True to exclude the layers selected by op_types and op_names from pruning.
trainer
A callable function used to train the model or just run inference. Takes model, optimizer and criterion as input.
The model will be trained (or run in inference) for `training_epochs` epochs.
Example::
def trainer(model: Module, optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor]):
training = model.training
model.train(mode=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
# If you don't want to update the model, you can skip `optimizer.step()`, and set train mode False.
optimizer.step()
model.train(mode=training)
optimizer
The optimizer instance used in the trainer. Note that this optimizer might be patched while collecting data,
so do not use this optimizer anywhere else.
criterion
The criterion function used in the trainer. Takes the model output and the target value as input, and returns the loss.
training_epochs
The number of epochs to train the model in order to sparsify the BN weights.
mode
'normal' or 'global'.
If pruning the model in a global way, all layer weights within the same config are considered together.
That means a single layer may not reach or exceed the sparsity set in the config,
but the total pruned weights meet the sparsity setting.
"""
self.mode = mode
self.trainer = trainer
self.optimizer = optimizer
self.criterion = criterion
self.training_epochs = training_epochs
self._scale = scale
super().__init__(model, config_list)
def validate_config(self, model: Module, config_list: List[Dict]):
schema = PrunerSchema([{
SchemaOptional('sparsity_per_layer'): And(float, lambda n: 0 < n < 1),
SchemaOptional('total_sparsity'): And(float, lambda n: 0 < n < 1),
SchemaOptional('max_sparsity_per_layer'): And(float, lambda n: 0 < n < 1),
SchemaOptional('op_types'): ['BatchNorm2d'],
SchemaOptional('op_names'): [str],
SchemaOptional('exclude'): bool
}], model, _logger)
schema.validate(config_list)
def criterion_patch(self, criterion: Callable[[Tensor, Tensor], Tensor]) -> Callable[[Tensor, Tensor], Tensor]:
def patched_criterion(input_tensor: Tensor, target: Tensor):
sum_l1 = 0
for _, wrapper in self.get_modules_wrapper().items():
sum_l1 += torch.norm(wrapper.module.weight.data, p=1)
return criterion(input_tensor, target) + self._scale * sum_l1
return patched_criterion
def reset_tools(self):
if self.data_collector is None:
self.data_collector = WeightTrainerBasedDataCollector(self, self.trainer, self.optimizer, self.criterion,
self.training_epochs, criterion_patch=self.criterion_patch)
else:
self.data_collector.reset()
if self.metrics_calculator is None:
self.metrics_calculator = NormMetricsCalculator()
if self.sparsity_allocator is None:
if self.mode == 'normal':
self.sparsity_allocator = NormalSparsityAllocator(self)
elif self.mode == 'global':
self.sparsity_allocator = GlobalSparsityAllocator(self)
else:
raise NotImplementedError('Only support mode `normal` and `global`')
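# Usage sketch for SlimPruner; `trainer`, `optimizer` and `criterion` follow the signatures
# documented in the docstring above, and the sparsity values are placeholders:
#   config_list = [{'op_types': ['BatchNorm2d'], 'total_sparsity': 0.8, 'max_sparsity_per_layer': 0.9}]
#   pruner = SlimPruner(model, config_list, trainer, optimizer, criterion, training_epochs=1)
#   pruned_model, masks = pruner.compress()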
class ActivationPruner(OneShotPruner):
def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor], training_batches: int, activation: str = 'relu',
mode: str = 'normal', dummy_input: Optional[Tensor] = None):
"""
Parameters
----------
model
Model to be pruned
config_list
Supported keys:
- sparsity : This is to specify the sparsity for each layer in this config to be compressed.
- sparsity_per_layer : Equivalent to sparsity.
- op_types : Conv2d and Linear are supported in ActivationPruner.
- op_names : Operation names to prune.
- exclude : Set True to exclude the layers selected by op_types and op_names from pruning.
trainer
A callable function used to train the model or just run inference. Takes model, optimizer and criterion as input.
The model will be trained (or run in inference) for `training_epochs` epochs.
Example::
def trainer(model: Module, optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor]):
training = model.training
model.train(mode=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
# If you don't want to update the model, you can skip `optimizer.step()`, and set train mode False.
optimizer.step()
model.train(mode=training)
optimizer
The optimizer instance used in the trainer. Note that this optimizer might be patched while collecting data,
so do not use this optimizer anywhere else.
criterion
The criterion function used in the trainer. Takes the model output and the target value as input, and returns the loss.
training_batches
The number of batches used to collect activations.
mode
'normal' or 'dependency_aware'.
If pruning the model in a dependency-aware way, this pruner will
prune the model according to the activation-based metrics and the channel or group
dependencies of the model. In this way, the pruner will force the conv layers
that have dependencies to prune the same channels, so the speedup module can better
harvest the speed benefit from the pruned model. Note that if 'dependency_aware' is set,
dummy_input cannot be None, because the pruner needs a dummy input to trace the
dependencies between the conv layers.
dummy_input
The dummy input used to analyze the topology constraints. Note that the dummy_input
should be on the same device as the model.
"""
self.mode = mode
self.dummy_input = dummy_input
self.trainer = trainer
self.optimizer = optimizer
self.criterion = criterion
self.training_batches = training_batches
self._activation = self._choose_activation(activation)
super().__init__(model, config_list)
def validate_config(self, model: Module, config_list: List[Dict]):
schema = PrunerSchema([{
SchemaOptional('sparsity_per_layer'): And(float, lambda n: 0 < n < 1),
SchemaOptional('op_types'): ['Conv2d', 'Linear'],
SchemaOptional('op_names'): [str],
SchemaOptional('exclude'): bool
}], model, _logger)
schema.validate(config_list)
def _choose_activation(self, activation: str = 'relu') -> Callable:
if activation == 'relu':
return nn.functional.relu
elif activation == 'relu6':
return nn.functional.relu6
else:
raise ValueError('Unsupported activation {}'.format(activation))
def _collector(self, buffer: List) -> Callable[[Module, Tensor, Tensor], None]:
def collect_activation(_module: Module, _input: Tensor, output: Tensor):
if len(buffer) < self.training_batches:
buffer.append(self._activation(output.detach()))
return collect_activation
def reset_tools(self):
collector_info = HookCollectorInfo([layer_info for layer_info, _ in self._detect_modules_to_compress()], 'forward', self._collector)
if self.data_collector is None:
self.data_collector = SingleHookTrainerBasedDataCollector(self, self.trainer, self.optimizer, self.criterion,
1, collector_infos=[collector_info])
else:
self.data_collector.reset()
if self.metrics_calculator is None:
self.metrics_calculator = self._get_metrics_calculator()
if self.sparsity_allocator is None:
if self.mode == 'normal':
self.sparsity_allocator = NormalSparsityAllocator(self, dim=0)
elif self.mode == 'dependency_aware':
self.sparsity_allocator = Conv2dDependencyAwareAllocator(self, 0, self.dummy_input)
else:
raise NotImplementedError('Only support mode `normal` and `dependency_aware`')
def _get_metrics_calculator(self) -> MetricsCalculator:
raise NotImplementedError()
class ActivationAPoZRankPruner(ActivationPruner):
def _get_metrics_calculator(self) -> MetricsCalculator:
return APoZRankMetricsCalculator(dim=1)
class ActivationMeanRankPruner(ActivationPruner):
def _get_metrics_calculator(self) -> MetricsCalculator:
return MeanRankMetricsCalculator(dim=1)
class TaylorFOWeightPruner(OneShotPruner):
def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor], training_batches: int,
mode: str = 'normal', dummy_input: Optional[Tensor] = None):
"""
Parameters
----------
model
Model to be pruned
config_list
Supported keys:
- sparsity : This is to specify the sparsity for each layer in this config to be compressed.
- sparsity_per_layer : Equivalent to sparsity.
- total_sparsity : This is to specify the total sparsity for all layers in this config,
each layer may have a different sparsity.
- max_sparsity_per_layer : Always used with total_sparsity. Limits the max sparsity of each layer.
- op_types : Conv2d and Linear are supported in TaylorFOWeightPruner.
- op_names : Operation names to prune.
- exclude : Set True to exclude the layers selected by op_types and op_names from pruning.
trainer
A callable function used to train the model or just run inference. Takes model, optimizer and criterion as input.
The model will be trained (or run in inference) for `training_epochs` epochs.
Example::
def trainer(model: Module, optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor]):
training = model.training
model.train(mode=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
# If you don't want to update the model, you can skip `optimizer.step()`, and set train mode False.
optimizer.step()
model.train(mode=training)
optimizer
The optimizer instance used in the trainer. Note that this optimizer might be patched while collecting data,
so do not use this optimizer anywhere else.
criterion
The criterion function used in the trainer. Takes the model output and the target value as input, and returns the loss.
training_batches
The number of batches used to collect data.
mode
'normal', 'dependency_aware' or 'global'.
If pruning the model in a dependency-aware way, this pruner will
prune the model according to the taylorFO metric and the channel or group
dependencies of the model. In this way, the pruner will force the conv layers
that have dependencies to prune the same channels, so the speedup module can better
harvest the speed benefit from the pruned model. Note that if 'dependency_aware' is set,
dummy_input cannot be None, because the pruner needs a dummy input to trace the
dependencies between the conv layers.
If pruning the model in a global way, all layer weights within the same config are considered together.
That means a single layer may not reach or exceed the sparsity set in the config,
but the total pruned weights meet the sparsity setting.
dummy_input
The dummy input used to analyze the topology constraints. Note that the dummy_input
should be on the same device as the model.
"""
self.mode = mode
self.dummy_input = dummy_input
self.trainer = trainer
self.optimizer = optimizer
self.criterion = criterion
self.training_batches = training_batches
super().__init__(model, config_list)
def validate_config(self, model: Module, config_list: List[Dict]):
schema = PrunerSchema([{
SchemaOptional('sparsity_per_layer'): And(float, lambda n: 0 < n < 1),
SchemaOptional('total_sparsity'): And(float, lambda n: 0 < n < 1),
SchemaOptional('max_sparsity_per_layer'): And(float, lambda n: 0 < n < 1),
SchemaOptional('op_types'): ['Conv2d', 'Linear'],
SchemaOptional('op_names'): [str],
SchemaOptional('exclude'): bool
}], model, _logger)
schema.validate(config_list)
def _collector(self, buffer: List, weight_tensor: Tensor) -> Callable[[Module, Tensor, Tensor], None]:
def collect_taylor(grad: Tensor):
if len(buffer) < self.training_batches:
buffer.append(self._calculate_taylor_expansion(weight_tensor, grad))
return collect_taylor
def _calculate_taylor_expansion(self, weight_tensor: Tensor, grad: Tensor) -> Tensor:
return (weight_tensor.detach() * grad.detach()).data.pow(2)
def reset_tools(self):
hook_targets = {layer_info.name: layer_info.module.weight for layer_info, _ in self._detect_modules_to_compress()}
collector_info = HookCollectorInfo(hook_targets, 'tensor', self._collector)
if self.data_collector is None:
self.data_collector = SingleHookTrainerBasedDataCollector(self, self.trainer, self.optimizer, self.criterion,
1, collector_infos=[collector_info])
else:
self.data_collector.reset()
if self.metrics_calculator is None:
self.metrics_calculator = MultiDataNormMetricsCalculator(p=1, dim=0)
if self.sparsity_allocator is None:
if self.mode == 'normal':
self.sparsity_allocator = NormalSparsityAllocator(self, dim=0)
elif self.mode == 'global':
self.sparsity_allocator = GlobalSparsityAllocator(self, dim=0)
elif self.mode == 'dependency_aware':
self.sparsity_allocator = Conv2dDependencyAwareAllocator(self, 0, self.dummy_input)
else:
raise NotImplementedError('Only support mode `normal`, `global` and `dependency_aware`')
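# Usage sketch for TaylorFOWeightPruner (training_batches controls how many batches of
# first-order Taylor data are collected; the values here are placeholders):
#   pruner = TaylorFOWeightPruner(model, [{'op_types': ['Conv2d'], 'total_sparsity': 0.5}],
#                                 trainer, optimizer, criterion, training_batches=20, mode='global')
#   pruned_model, masks = pruner.compress()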
from .base import (
HookCollectorInfo,
DataCollector,
MetricsCalculator,
SparsityAllocator
)
from .data_collector import (
WeightDataCollector,
WeightTrainerBasedDataCollector,
SingleHookTrainerBasedDataCollector
)
from .metrics_calculator import (
NormMetricsCalculator,
MultiDataNormMetricsCalculator,
DistMetricsCalculator,
APoZRankMetricsCalculator,
MeanRankMetricsCalculator
)
from .sparsity_allocator import (
NormalSparsityAllocator,
GlobalSparsityAllocator,
Conv2dDependencyAwareAllocator
)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import types
from typing import List, Dict, Optional, Callable, Union
import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from nni.algorithms.compression.v2.pytorch.base import Compressor, LayerInfo
_logger = logging.getLogger(__name__)
__all__ = ['DataCollector', 'TrainerBasedDataCollector', 'HookCollectorInfo', 'MetricsCalculator', 'SparsityAllocator']
class DataCollector:
"""
An abstract class for collecting the data needed by the compressor.
"""
def __init__(self, compressor: Compressor):
"""
Parameters
----------
compressor
The compressor bound to this DataCollector.
"""
self.compressor = compressor
def reset(self):
"""
Reset the `DataCollector`.
"""
raise NotImplementedError()
def collect(self) -> Dict:
"""
Collect the data needed by the compressor, e.g., module weights or the outputs of activation functions.
Returns
-------
Dict
Usually has format like {module_name: tensor_type_data}.
"""
raise NotImplementedError()
class HookCollectorInfo:
def __init__(self, targets: Union[Dict[str, Tensor], List[LayerInfo]], hook_type: str,
collector: Union[Callable[[List, Tensor], Callable[[Tensor], None]], Callable[[List], Callable[[Module, Tensor, Tensor], None]]]):
"""
This class is used to aggregate information about which kind of hook is placed on which layers.
Parameters
----------
targets
List of LayerInfo or Dict of {layer_name: weight_tensor}, the hook targets.
hook_type
'forward' or 'backward'.
collector
A hook function generator: the input is a buffer (an empty list), or a buffer plus a tensor, and the output is a hook function.
The buffer is used to store the data collected by the hook.
"""
self.targets = targets
self.hook_type = hook_type
self.collector = collector
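# Sketch of a forward-hook collector generator: it receives an empty list `buffer` and returns a
# hook that appends whatever should be collected for that layer (names below are illustrative).
#   def activation_collector(buffer):
#       def hook(module, _input, output):
#           buffer.append(output.detach())
#       return hook
#   info = HookCollectorInfo(layer_infos, 'forward', activation_collector)  # layer_infos: List[LayerInfo]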
class TrainerBasedDataCollector(DataCollector):
"""
This class includes some trainer-based utility functions, e.g., patching the optimizer or criterion and adding hooks.
"""
def __init__(self, compressor: Compressor, trainer: Callable[[Module, Optimizer, Callable], None], optimizer: Optimizer,
criterion: Callable[[Tensor, Tensor], Tensor], training_epochs: int,
opt_before_tasks: List = [], opt_after_tasks: List = [],
collector_infos: List[HookCollectorInfo] = [], criterion_patch: Callable[[Callable], Callable] = None):
"""
Parameters
----------
compressor
The compressor bound to this DataCollector.
trainer
A callable function used to train the model or just run inference. Takes model, optimizer and criterion as input.
The model will be trained (or run in inference) for `training_epochs` epochs.
Example::
def trainer(model: Module, optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor]):
training = model.training
model.train(mode=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
# If you don't want to update the model, you can skip `optimizer.step()`, and set train mode False.
optimizer.step()
model.train(mode=training)
optimizer
The optimizer instance used in the trainer. Note that this optimizer might be patched while collecting data,
so do not use this optimizer anywhere else.
criterion
The criterion function used in the trainer. Takes the model output and the target value as input, and returns the loss.
training_epochs
The total number of times the trainer will be called.
opt_before_tasks
A list of functions that will be called one by one before the original `optimizer.step()`.
Note that these functions will be patched into `optimizer.step()`.
opt_after_tasks
A list of functions that will be called one by one after the original `optimizer.step()`.
Note that these functions will be patched into `optimizer.step()`.
collector_infos
A list of `HookCollectorInfo` instances. The hooks will be registered in `__init__`.
criterion_patch
A callable function used to patch the criterion. Takes a criterion function as input and returns a new one.
Example::
def criterion_patch(criterion: Callable[[Tensor, Tensor], Tensor]) -> Callable[[Tensor, Tensor], Tensor]:
weight = ...
def patched_criterion(output, target):
return criterion(output, target) + torch.norm(weight)
return patched_criterion
"""
super().__init__(compressor)
self.trainer = trainer
self.training_epochs = training_epochs
self._origin_optimizer = optimizer
self._origin_criterion = criterion
self._opt_before_tasks = opt_before_tasks
self._opt_after_tasks = opt_after_tasks
self._collector_infos = collector_infos
self._criterion_patch = criterion_patch
self.reset()
def reset(self):
# refresh optimizer and criterion
self.compressor._unwrap_model()
if self._origin_optimizer is not None:
optimizer_cls = self._origin_optimizer.__class__
if optimizer_cls.__name__ == 'SGD':
self.optimizer = optimizer_cls(self.compressor.bound_model.parameters(), lr=0.001)
else:
self.optimizer = optimizer_cls(self.compressor.bound_model.parameters())
self.optimizer.load_state_dict(self._origin_optimizer.state_dict())
else:
self.optimizer = None
if self._criterion_patch is not None:
self.criterion = self._criterion_patch(self._origin_criterion)
else:
self.criterion = self._origin_criterion
self.compressor._wrap_model()
# patch optimizer
self._patch_optimizer()
# hook
self._remove_all_hook()
self._hook_id = 0
self._hook_handles = {}
self._hook_buffer = {}
self._add_all_hook()
def _patch_optimizer(self):
def patch_step(old_step):
def new_step(_, *args, **kwargs):
for task in self._opt_before_tasks:
task()
# call origin optimizer step method
output = old_step(*args, **kwargs)
for task in self._opt_after_tasks:
task()
return output
return new_step
if self.optimizer is not None:
self.optimizer.step = types.MethodType(patch_step(self.optimizer.step), self.optimizer)
def _add_hook(self, collector_info: HookCollectorInfo) -> int:
self._hook_id += 1
self._hook_handles[self._hook_id] = {}
self._hook_buffer[self._hook_id] = {}
if collector_info.hook_type == 'forward':
self._add_forward_hook(self._hook_id, collector_info.targets, collector_info.collector)
elif collector_info.hook_type == 'backward':
self._add_backward_hook(self._hook_id, collector_info.targets, collector_info.collector)
elif collector_info.hook_type == 'tensor':
self._add_tensor_hook(self._hook_id, collector_info.targets, collector_info.collector)
else:
_logger.warning('Skip unsupported hook type: %s', collector_info.hook_type)
return self._hook_id
def _add_forward_hook(self, hook_id: int, layers: List[LayerInfo],
collector: Callable[[List], Callable[[Module, Tensor, Tensor], None]]):
assert all(isinstance(layer_info, LayerInfo) for layer_info in layers)
for layer in layers:
self._hook_buffer[hook_id][layer.name] = []
handle = layer.module.register_forward_hook(collector(self._hook_buffer[hook_id][layer.name]))
self._hook_handles[hook_id][layer.name] = handle
def _add_backward_hook(self, hook_id: int, layers: List[LayerInfo],
collector: Callable[[List], Callable[[Module, Tensor, Tensor], None]]):
assert all(isinstance(layer_info, LayerInfo) for layer_info in layers)
for layer in layers:
self._hook_buffer[hook_id][layer.name] = []
handle = layer.module.register_backward_hook(collector(self._hook_buffer[hook_id][layer.name]))
self._hook_handles[hook_id][layer.name] = handle
def _add_tensor_hook(self, hook_id: int, tensors: Dict[str, Tensor],
collector: Callable[[List, Tensor], Callable[[Tensor], None]]):
assert all(isinstance(tensor, Tensor) for _, tensor in tensors.items())
for layer_name, tensor in tensors.items():
self._hook_buffer[hook_id][layer_name] = []
handle = tensor.register_hook(collector(self._hook_buffer[hook_id][layer_name], tensor))
self._hook_handles[hook_id][layer_name] = handle
def _remove_hook(self, hook_id: int):
if hook_id not in self._hook_handles:
raise ValueError("%s is not a valid collector id" % str(hook_id))
for handle in self._hook_handles[hook_id].values():
handle.remove()
del self._hook_handles[hook_id]
def _add_all_hook(self):
for collector_info in self._collector_infos:
self._add_hook(collector_info)
def _remove_all_hook(self):
if hasattr(self, '_hook_handles'):
for hook_id in list(self._hook_handles.keys()):
self._remove_hook(hook_id)
class MetricsCalculator:
"""
An abstract class for calculating a kind of metric on the given data.
"""
def __init__(self, dim: Optional[Union[int, List[int]]] = None,
block_sparse_size: Optional[Union[int, List[int]]] = None):
"""
Parameters
----------
dim
The dimensions of the collected data that correspond to the dimensions of the weight under pruning.
None means a one-to-one correspondence between pruned dimensions and data dimensions, which is equal to setting `dim` to all data dimensions.
Only these `dim` will be kept; the other dimensions of the data will be reduced.
Example:
Suppose you want to prune a Conv2d weight at filter level, and the weight size is (32, 16, 3, 3) [out-channel, in-channel, kernel-size-1, kernel-size-2].
Then the dimension under pruning is [0], which means you want to prune the filters (out-channels).
Case 1: Directly collect the conv module weight as the data to calculate the metric.
Then the data has size (32, 16, 3, 3).
Note that dimension 0 of the data corresponds to dimension 0 of the weight under pruning.
So in this case, `dim=0` should be set in `__init__`.
Case 2: Use the output of the conv module as the data to calculate the metric.
Then the data has size (batch_num, 32, feature_map_size_1, feature_map_size_2).
Note that dimension 1 of the data corresponds to dimension 0 of the weight under pruning.
So in this case, `dim=1` should be set in `__init__`.
In both of these cases, the metric of this module has size (32,).
block_sparse_size
This describes the block size that one metric value represents. By default, None means the block size is ones(len(dim)).
Make sure len(dim) == len(block_sparse_size), and that each position in block_sparse_size corresponds to the same position in dim.
Example:
The weight under pruning has size (768, 768), and you want to apply block sparsity on dim=[0] with block size [64, 768];
then you can set block_sparse_size=[64]. The final metric size is (12,).
"""
self.dim = dim if not isinstance(dim, int) else [dim]
self.block_sparse_size = block_sparse_size if not isinstance(block_sparse_size, int) else [block_sparse_size]
if self.block_sparse_size is not None:
assert all(i >= 1 for i in self.block_sparse_size)
elif self.dim is not None:
self.block_sparse_size = [1] * len(self.dim)
if self.dim is not None:
assert all(i >= 0 for i in self.dim)
self.dim, self.block_sparse_size = (list(t) for t in zip(*sorted(zip(self.dim, self.block_sparse_size))))
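# Worked example of `dim` (follows the docstring above): for a Conv2d weight of size (32, 16, 3, 3)
# pruned at filter level, a calculator built with dim=0 reduces the collected weight data over
# dimensions (1, 2, 3) and yields a metric of size (32,); with activation data of size
# (batch, 32, h, w), dim=1 keeps the channel dimension instead.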
def calculate_metrics(self, data: Dict) -> Dict[str, Tensor]:
"""
Parameters
----------
data
A dict handle the data used to calculate metrics. Usually has format like {module_name: tensor_type_data}.
Returns
-------
Dict[str, Tensor]
The key is the layer_name, value is the metric.
Note that the metric has the same size with the data size on `dim`.
"""
raise NotImplementedError()
class SparsityAllocator:
"""
An abstract class for allocating masks based on metrics.
"""
def __init__(self, pruner: Compressor, dim: Optional[Union[int, List[int]]] = None,
block_sparse_size: Optional[Union[int, List[int]]] = None):
"""
Parameters
----------
pruner
The pruner bound to this `SparsityAllocator`.
dim
The dimensions of the weight under pruning; the metric size should equal the weight size on these dimensions.
None means a one-to-one correspondence between pruned dimensions and the metric, which is equal to setting `dim` to all weight dimensions.
The mask will be expanded to the weight size depending on `dim`.
Example:
The weight under pruning has size (2, 3, 4), and `dim=1` means dimension 1 of the weight is under pruning.
Then the metric should have size (3,), e.g., `metric=[0.9, 0.1, 0.8]`.
Assuming some kind of `SparsityAllocator` gets the mask `mask=[1, 0, 1]` on weight dimension 1,
then the dimension mask will be expanded to the final mask `[[[1, 1, 1, 1], [0, 0, 0, 0], [1, 1, 1, 1]], [[1, 1, 1, 1], [0, 0, 0, 0], [1, 1, 1, 1]]]`.
block_sparse_size
This describes the block size that one metric value represents. By default, None means the block size is ones(len(dim)).
Make sure len(dim) == len(block_sparse_size), and that each position in block_sparse_size corresponds to the same position in dim.
Example:
The metric size is (12,) and block_sparse_size=[64]; then the mask will first be expanded to (768,) before being expanded with `dim`.
"""
self.pruner = pruner
self.dim = dim if not isinstance(dim, int) else [dim]
self.block_sparse_size = block_sparse_size if not isinstance(block_sparse_size, int) else [block_sparse_size]
if self.block_sparse_size is not None:
assert all(i >= 1 for i in self.block_sparse_size)
elif self.dim is not None:
self.block_sparse_size = [1] * len(self.dim)
if self.dim is not None:
assert all(i >= 0 for i in self.dim)
self.dim, self.block_sparse_size = (list(t) for t in zip(*sorted(zip(self.dim, self.block_sparse_size))))
def generate_sparsity(self, metrics: Dict) -> Dict[str, Dict[str, Tensor]]:
"""
Parameters
----------
metrics
A metric dict. The key is the layer name, the value is its metric.
"""
raise NotImplementedError()
def _expand_mask(self, name: str, mask: Tensor) -> Dict[str, Tensor]:
"""
Parameters
----------
name
The masked module name.
mask
The reduced mask with `self.dim` and `self.block_sparse_size`.
Returns
-------
Dict[str, Tensor]
The key is `weight_mask` or `bias_mask`, value is the final mask.
"""
weight_mask = mask.clone()
if self.block_sparse_size is not None:
# expand mask with block_sparse_size
expand_size = list(weight_mask.size())
reshape_size = list(weight_mask.size())
for i, block_width in reversed(list(enumerate(self.block_sparse_size))):
weight_mask = weight_mask.unsqueeze(i + 1)
expand_size.insert(i + 1, block_width)
reshape_size[i] *= block_width
weight_mask = weight_mask.expand(expand_size).reshape(reshape_size)
wrapper = self.pruner.get_modules_wrapper()[name]
weight_size = wrapper.module.weight.data.size()
if self.dim is None:
assert weight_mask.size() == weight_size
expand_mask = {'weight_mask': weight_mask}
else:
# expand mask to weight size with dim
assert len(weight_mask.size()) == len(self.dim)
assert all(weight_size[j] == weight_mask.size(i) for i, j in enumerate(self.dim))
idxs = list(range(len(weight_size)))
[idxs.pop(i) for i in reversed(self.dim)]
for i in idxs:
weight_mask = weight_mask.unsqueeze(i)
expand_mask = {'weight_mask': weight_mask.expand(weight_size).clone()}
# NOTE: assume we only mask output, so the mask and bias have a one-to-one correspondence.
# If we support more kinds of masks, this place needs refactoring.
if wrapper.bias_mask is not None and weight_mask.size() == wrapper.bias_mask.size():
expand_mask['bias_mask'] = weight_mask.clone()
return expand_mask
def _compress_mask(self, mask: Tensor) -> Tensor:
"""
Parameters
----------
mask
The full mask, with the same size as the weight.
Returns
-------
Tensor
The mask reduced with `self.dim` and `self.block_sparse_size`.
"""
if self.dim is None or len(mask.size()) == 1:
mask = mask.clone()
else:
mask_dim = list(range(len(mask.size())))
for dim in self.dim:
mask_dim.remove(dim)
mask = torch.sum(mask, dim=mask_dim)
if self.block_sparse_size is not None:
# operation like pooling
lower_case_letters = 'abcdefghijklmnopqrstuvwxyz'
ein_expression = ''
for i, step in enumerate(self.block_sparse_size):
mask = mask.unfold(i, step, step)
ein_expression += lower_case_letters[i]
ein_expression = '...{},{}'.format(ein_expression, ein_expression)
mask = torch.einsum(ein_expression, mask, torch.ones(self.block_sparse_size, device=mask.device))
return (mask != 0).type_as(mask)
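# A minimal sketch (not part of the original PR) of the mask expansion that
# `_expand_mask` performs, using the docstring example: a mask on dimension 1 of a
# hypothetical (2, 3, 4) weight is broadcast back to the full weight shape.
def _example_expand_channel_mask():
    import torch
    weight_size = (2, 3, 4)
    channel_mask = torch.tensor([1., 0., 1.])            # reduced mask on weight dimension 1
    full_mask = channel_mask.view(1, 3, 1).expand(weight_size).clone()
    return full_mask                                      # the whole middle channel is zeroed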
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from typing import Dict, List
from torch import Tensor
from .base import DataCollector, TrainerBasedDataCollector
_logger = logging.getLogger(__name__)
__all__ = ['WeightDataCollector', 'WeightTrainerBasedDataCollector', 'SingleHookTrainerBasedDataCollector']
class WeightDataCollector(DataCollector):
"""
Collect all wrapper weights.
"""
def reset(self):
pass
def collect(self) -> Dict[str, Tensor]:
data = {}
for _, wrapper in self.compressor.get_modules_wrapper().items():
data[wrapper.name] = wrapper.module.weight.data.clone().detach()
return data
class WeightTrainerBasedDataCollector(TrainerBasedDataCollector):
"""
Collect all wrapper weights after training or inference.
"""
def collect(self) -> Dict[str, Tensor]:
for _ in range(self.training_epochs):
self.trainer(self.compressor.bound_model, self.optimizer, self.criterion)
data = {}
for _, wrapper in self.compressor.get_modules_wrapper().items():
data[wrapper.name] = wrapper.module.weight.data.clone().detach()
return data
class SingleHookTrainerBasedDataCollector(TrainerBasedDataCollector):
"""
Add hooks and collect data during training or inference.
Single means each wrapper only has one hook to collect data.
"""
def collect(self) -> Dict[str, List[Tensor]]:
for _ in range(self.training_epochs):
self.trainer(self.compressor.bound_model, self.optimizer, self.criterion)
data = {}
for buffer_dict in self._hook_buffer.values():
data.update(buffer_dict)
return data
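# A minimal sketch (plain PyTorch, not this module's hook plumbing) of the idea
# behind SingleHookTrainerBasedDataCollector: one forward hook per module buffers
# the data that the metrics calculators later consume. Names here are illustrative.
def _example_single_forward_hook():
    import torch
    from torch import nn
    buffer = {'conv': []}
    conv = nn.Conv2d(3, 8, 3)
    handle = conv.register_forward_hook(
        lambda module, inputs, output: buffer['conv'].append(output.detach()))
    conv(torch.rand(2, 3, 8, 8))                          # one forward pass fills the buffer
    handle.remove()
    return buffer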
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Dict, List, Optional, Union
import torch
from torch import Tensor
from .base import MetricsCalculator
__all__ = ['NormMetricsCalculator', 'MultiDataNormMetricsCalculator', 'DistMetricsCalculator',
'APoZRankMetricsCalculator', 'MeanRankMetricsCalculator']
class NormMetricsCalculator(MetricsCalculator):
"""
Calculate the specified norm for each tensor in data.
The L1, L2, Level, and Slim pruners use this to calculate metrics.
"""
def __init__(self, dim: Optional[Union[int, List[int]]] = None, p: Optional[Union[int, float]] = None):
"""
Parameters
----------
dim
The dimensions of the collected data that correspond to the dimensions of the weight under pruning.
None means a one-to-one correspondence between pruned dimensions and data, which equals setting `dim` to all data dimensions.
Only these `dim` will be kept; the other dimensions of the data will be reduced.
Example:
Suppose you want to prune a Conv2d weight at filter level, and the weight size is (32, 16, 3, 3) [out-channel, in-channel, kernel-size-1, kernel-size-2].
Then the pruning dimension is [0], which means you want to prune the filters (out-channels).
Case 1: Directly collect the conv module weight as data to calculate the metric.
Then the data has size (32, 16, 3, 3).
Note that dimension 0 of the data corresponds to dimension 0 of the weight under pruning.
So in this case, `dim=0` should be set in `__init__`.
Case 2: Use the output of the conv module as data to calculate the metric.
Then the data has size (batch_num, 32, feature_map_size_1, feature_map_size_2).
Note that dimension 1 of the data corresponds to dimension 0 of the weight under pruning.
So in this case, `dim=1` should be set in `__init__`.
In both cases, the metric of this module has size (32,).
p
The order of norm. None means Frobenius norm.
"""
super().__init__(dim=dim)
self.p = p if p is not None else 'fro'
def calculate_metrics(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
metrics = {}
for name, tensor in data.items():
keeped_dim = list(range(len(tensor.size()))) if self.dim is None else self.dim
across_dim = list(range(len(tensor.size())))
[across_dim.pop(i) for i in reversed(keeped_dim)]
if len(across_dim) == 0:
metrics[name] = tensor.abs()
else:
metrics[name] = tensor.norm(p=self.p, dim=across_dim)
return metrics
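# A minimal usage sketch (not part of the original PR) of filter-level metrics for
# a Conv2d weight, matching "Case 1" in the docstring above: dim=0 keeps the
# out-channel dimension and the L1 norm is taken over the remaining dimensions.
def _example_filter_norm_metric():
    import torch
    calculator = NormMetricsCalculator(dim=0, p=1)
    metrics = calculator.calculate_metrics({'conv1': torch.rand(32, 16, 3, 3)})
    return metrics['conv1'].size()                        # torch.Size([32])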
class MultiDataNormMetricsCalculator(NormMetricsCalculator):
"""
First sum each list of tensors in data, then calculate the specified norm for each summed tensor.
The TaylorFO pruner uses this to calculate metrics.
"""
def calculate_metrics(self, data: Dict[str, List[Tensor]]) -> Dict[str, Tensor]:
new_data = {name: sum(list_tensor) for name, list_tensor in data.items()}
return super().calculate_metrics(new_data)
class DistMetricsCalculator(MetricsCalculator):
"""
For each tensor in data, calculate the sum of the specified distances between each element on `dim` and all other elements on `dim`.
The FPGM pruner uses this to calculate metrics.
"""
def __init__(self, p: float, dim: Union[int, List[int]]):
"""
Parameters
----------
dim
The dimensions of the collected data that correspond to the dimensions of the weight under pruning.
None means a one-to-one correspondence between pruned dimensions and data, which equals setting `dim` to all data dimensions.
Only these `dim` will be kept; the other dimensions of the data will be reduced.
Example:
Suppose you want to prune a Conv2d weight at filter level, and the weight size is (32, 16, 3, 3) [out-channel, in-channel, kernel-size-1, kernel-size-2].
Then the pruning dimension is [0], which means you want to prune the filters (out-channels).
Case 1: Directly collect the conv module weight as data to calculate the metric.
Then the data has size (32, 16, 3, 3).
Note that dimension 0 of the data corresponds to dimension 0 of the weight under pruning.
So in this case, `dim=0` should be set in `__init__`.
Case 2: Use the output of the conv module as data to calculate the metric.
Then the data has size (batch_num, 32, feature_map_size_1, feature_map_size_2).
Note that dimension 1 of the data corresponds to dimension 0 of the weight under pruning.
So in this case, `dim=1` should be set in `__init__`.
In both cases, the metric of this module has size (32,).
p
The order of norm.
"""
super().__init__(dim=dim)
self.p = p
def calculate_metrics(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
metrics = {}
for name, tensor in data.items():
keeped_dim = list(range(len(tensor.size()))) if self.dim is None else self.dim
reorder_dim = list(keeped_dim)
reorder_dim.extend([i for i in range(len(tensor.size())) if i not in keeped_dim])
reorder_tensor = tensor.permute(*reorder_dim).clone()
metric = torch.ones(*reorder_tensor.size()[:len(keeped_dim)], device=reorder_tensor.device)
across_dim = list(range(len(keeped_dim), len(reorder_dim)))
idxs = metric.nonzero()
for idx in idxs:
other = reorder_tensor
for i in idx:
other = other[i]
other = other.clone()
if len(across_dim) == 0:
dist_sum = torch.abs(reorder_tensor - other).sum()
else:
dist_sum = torch.norm((reorder_tensor - other), p=self.p, dim=across_dim).sum()
# NOTE: this place needs refactoring when layer-level pruning is supported.
tmp_metric = metric
for i in idx[:-1]:
tmp_metric = tmp_metric[i]
tmp_metric[idx[-1]] = dist_sum
metrics[name] = metric
return metrics
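# A minimal usage sketch (not part of the original PR) of the FPGM-style metric:
# each filter's score is the sum of its L2 distances to every other filter, so a
# larger value means the filter is less replaceable by the others.
def _example_filter_distance_metric():
    import torch
    calculator = DistMetricsCalculator(p=2, dim=0)
    metrics = calculator.calculate_metrics({'conv1': torch.rand(4, 2, 3, 3)})
    return metrics['conv1']                               # size (4,), one score per filter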
class APoZRankMetricsCalculator(MetricsCalculator):
"""
This metric counts the zeros at each position across the tensor list in data,
then reduces the zero counts over the dimensions not in `dim` and calculates the non-zero rate.
Note that the returned metric is (1 - apoz), because we assume a higher metric value means higher importance.
The APoZRank pruner uses this to calculate metrics.
"""
def calculate_metrics(self, data: Dict[str, List[Tensor]]) -> Dict[str, Tensor]:
metrics = {}
for name, tensor_list in data.items():
# NOTE: dim=0 means the batch dim is 0
activations = torch.cat(tensor_list, dim=0)
_eq_zero = torch.eq(activations, torch.zeros_like(activations))
keeped_dim = list(range(len(_eq_zero.size()))) if self.dim is None else self.dim
across_dim = list(range(len(_eq_zero.size())))
[across_dim.pop(i) for i in reversed(keeped_dim)]
# The number of elements reduced over, i.e., the size of _eq_zero on the dimensions not in keeped_dim
total_size = 1
for dim, dim_size in enumerate(_eq_zero.size()):
if dim not in keeped_dim:
total_size *= dim_size
_apoz = torch.sum(_eq_zero, dim=across_dim, dtype=torch.float64) / total_size
# NOTE: the metric is (1 - apoz) because we assume a smaller metric value means the position should be pruned first.
metrics[name] = torch.ones_like(_apoz) - _apoz
return metrics
class MeanRankMetricsCalculator(MetricsCalculator):
"""
This metric simply concatenates the list of tensors on dim 0, then averages over the dimensions not in `dim`.
The MeanRank pruner uses this to calculate metrics.
"""
def calculate_metrics(self, data: Dict[str, List[Tensor]]) -> Dict[str, Tensor]:
metrics = {}
for name, tensor_list in data.items():
# NOTE: dim=0 means the batch dim is 0
activations = torch.cat(tensor_list, dim=0)
keeped_dim = list(range(len(activations.size()))) if self.dim is None else self.dim
across_dim = list(range(len(activations.size())))
[across_dim.pop(i) for i in reversed(keeped_dim)]
metrics[name] = torch.mean(activations, across_dim)
return metrics
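# A minimal usage sketch (not part of the original PR) of the activation-based
# metrics: both calculators consume a list of batched outputs per layer and return
# one value per output channel (dim=1 of the activations).
def _example_activation_metrics():
    import torch
    activations = [torch.relu(torch.randn(8, 32, 4, 4)) for _ in range(2)]
    apoz = APoZRankMetricsCalculator(dim=1).calculate_metrics({'conv1': activations})
    mean = MeanRankMetricsCalculator(dim=1).calculate_metrics({'conv1': activations})
    return apoz['conv1'].size(), mean['conv1'].size()     # both torch.Size([32])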
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Any, Dict, List, Tuple, Union
import numpy as np
import torch
from torch import Tensor
from nni.algorithms.compression.v2.pytorch.base import Pruner
from nni.compression.pytorch.utils.shape_dependency import ChannelDependency, GroupDependency
from .base import SparsityAllocator
class NormalSparsityAllocator(SparsityAllocator):
"""
This allocator simply prunes the weights with smaller metrics at layer level.
"""
def generate_sparsity(self, metrics: Dict[str, Tensor]) -> Dict[str, Dict[str, Tensor]]:
masks = {}
for name, wrapper in self.pruner.get_modules_wrapper().items():
sparsity_rate = wrapper.config['sparsity_per_layer']
assert name in metrics, f'Metric of {name} is not calculated.'
metric = metrics[name] * self._compress_mask(wrapper.weight_mask)
prune_num = int(sparsity_rate * metric.numel())
if prune_num == 0:
continue
threshold = torch.topk(metric.view(-1), prune_num, largest=False)[0].max()
mask = torch.gt(metric, threshold).type_as(metric)
masks[name] = self._expand_mask(name, mask)
return masks
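# A minimal sketch (not part of the original PR) of the layer-level thresholding
# that NormalSparsityAllocator performs: the smallest `sparsity * numel` metric
# values within one layer are masked out.
def _example_layer_level_mask():
    import torch
    metric = torch.tensor([0.9, 0.1, 0.8, 0.3])
    prune_num = int(0.5 * metric.numel())                 # prune 2 of the 4 channels
    threshold = torch.topk(metric, prune_num, largest=False)[0].max()
    return torch.gt(metric, threshold).float()            # tensor([1., 0., 1., 0.])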
class GlobalSparsityAllocator(SparsityAllocator):
"""
This allocator prunes the weights with smaller metrics at group level.
This means the metrics of all layers in a group are sorted together.
Layers that share the same config in config_list form a group.
"""
def generate_sparsity(self, metrics: Dict) -> Dict[str, Dict[str, Tensor]]:
masks = {}
# {group_index: {layer_name: metric}}
grouped_metrics = {idx: {name: metrics[name] for name in names}
for idx, names in self.pruner.generate_module_groups().items()}
for _, group_metric_dict in grouped_metrics.items():
threshold, sub_thresholds = self._calculate_threshold(group_metric_dict)
for name, metric in group_metric_dict.items():
mask = torch.gt(metric, min(threshold, sub_thresholds[name])).type_as(metric)
masks[name] = self._expand_mask(name, mask)
return masks
def _calculate_threshold(self, group_metric_dict: Dict[str, Tensor]) -> Tuple[float, Dict[str, float]]:
metric_list = []
sub_thresholds = {}
total_weight_num = 0
temp_wrapper_config = self.pruner.get_modules_wrapper()[list(group_metric_dict.keys())[0]].config
total_sparsity = temp_wrapper_config['total_sparsity']
max_sparsity_per_layer = temp_wrapper_config.get('max_sparsity_per_layer', 1.0)
for name, metric in group_metric_dict.items():
wrapper = self.pruner.get_modules_wrapper()[name]
metric = metric * self._compress_mask(wrapper.weight_mask)
layer_weight_num = wrapper.module.weight.data.numel()
stay_num = int(metric.numel() * max_sparsity_per_layer)
# Drop the largest metric values (the weights that must stay), so at most max_sparsity_per_layer of this layer can be pruned
stay_metric = torch.topk(metric.view(-1), stay_num, largest=False)[0]
sub_thresholds[name] = stay_metric.max()
expand_times = int(layer_weight_num / metric.numel())
if expand_times > 1:
# each metric value stands for `expand_times` weight elements, so repeat it accordingly
stay_metric = stay_metric.unsqueeze(1).expand(stay_num, expand_times).reshape(-1)
metric_list.append(stay_metric)
total_weight_num += layer_weight_num
assert total_sparsity <= max_sparsity_per_layer, 'total_sparsity should be no greater than max_sparsity_per_layer.'
total_prune_num = int(total_sparsity * total_weight_num)
threshold = torch.topk(torch.cat(metric_list).view(-1), total_prune_num, largest=False)[0].max().item()
return threshold, sub_thresholds
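# A minimal sketch (not part of the original PR) of the global-threshold idea used
# above: metrics of all layers in a group are pooled so that one threshold realizes
# the total sparsity (max_sparsity_per_layer and metric expansion are omitted here).
def _example_global_threshold():
    import torch
    metrics = {'fc1': torch.tensor([0.9, 0.2, 0.7]),
               'fc2': torch.tensor([0.1, 0.4, 0.8])}
    flat = torch.cat([metric.view(-1) for metric in metrics.values()])
    prune_num = int(0.5 * flat.numel())                   # prune half of all weights globally
    threshold = torch.topk(flat, prune_num, largest=False)[0].max()
    return {name: torch.gt(metric, threshold).float() for name, metric in metrics.items()}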
class Conv2dDependencyAwareAllocator(SparsityAllocator):
"""
A dedicated allocator for Conv2d layers with dependency awareness.
"""
def __init__(self, pruner: Pruner, dim: int, dummy_input: Any):
assert isinstance(dim, int), 'Only support single dim in Conv2dDependencyAwareAllocator.'
super().__init__(pruner, dim=dim)
self.dummy_input = dummy_input
def _get_dependency(self):
graph = self.pruner.generate_graph(dummy_input=self.dummy_input)
self.channel_depen = ChannelDependency(traced_model=graph.trace).dependency_sets
self.group_depen = GroupDependency(traced_model=graph.trace).dependency_sets
def generate_sparsity(self, metrics: Dict) -> Dict[str, Dict[str, Tensor]]:
self._get_dependency()
masks = {}
grouped_metrics = {}
for idx, names in enumerate(self.channel_depen):
grouped_metric = {name: metrics[name] * self._compress_mask(self.pruner.get_modules_wrapper()[name].weight_mask) for name in names if name in metrics}
if len(grouped_metric) > 0:
grouped_metrics[idx] = grouped_metric
for _, group_metric_dict in grouped_metrics.items():
group_metric = self._group_metric_calculate(group_metric_dict)
sparsities = {name: self.pruner.get_modules_wrapper()[name].config['sparsity_per_layer'] for name in group_metric_dict.keys()}
min_sparsity = min(sparsities.values())
conv2d_groups = [self.group_depen[name] for name in group_metric_dict.keys()]
max_conv2d_group = np.lcm.reduce(conv2d_groups)
pruned_per_conv2d_group = int(group_metric.numel() / max_conv2d_group * min_sparsity)
conv2d_group_step = int(group_metric.numel() / max_conv2d_group)
group_mask = []
for gid in range(max_conv2d_group):
_start = gid * conv2d_group_step
_end = (gid + 1) * conv2d_group_step
if pruned_per_conv2d_group > 0:
threshold = torch.topk(group_metric[_start: _end], pruned_per_conv2d_group, largest=False)[0].max()
conv2d_group_mask = torch.gt(group_metric[_start:_end], threshold).type_as(group_metric)
else:
conv2d_group_mask = torch.ones(conv2d_group_step, device=group_metric.device)
group_mask.append(conv2d_group_mask)
group_mask = torch.cat(group_mask, dim=0)
for name, metric in group_metric_dict.items():
metric = (metric - metric.min()) * group_mask
pruned_num = int(sparsities[name] * len(metric))
threshold = torch.topk(metric, pruned_num, largest=False)[0].max()
mask = torch.gt(metric, threshold).type_as(metric)
masks[name] = self._expand_mask(name, mask)
return masks
def _group_metric_calculate(self, group_metrics: Union[Dict[str, Tensor], List[Tensor]]) -> Tensor:
"""
Sum the metric values at the same position across all metrics in one group.
"""
group_metrics = list(group_metrics.values()) if isinstance(group_metrics, dict) else group_metrics
assert all(group_metrics[0].size() == group_metric.size() for group_metric in group_metrics), 'Metric sizes do not match.'
group_sum_metric = torch.zeros(group_metrics[0].size(), device=group_metrics[0].device)
for group_metric in group_metrics:
group_sum_metric += group_metric
return group_sum_metric
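# A minimal sketch (not part of the original PR) of the dependency-aware idea above:
# layers in one channel-dependency set share a summed metric, and pruning is done
# separately inside each conv group (the group count is the lcm across layers).
def _example_dependency_aware_mask():
    import numpy as np
    import torch
    metric_a = torch.tensor([0.2, 0.9, 0.4, 0.7])        # per-channel metrics of two layers
    metric_b = torch.tensor([0.6, 0.1, 0.8, 0.3])        # that must share one channel mask
    group_metric = metric_a + metric_b
    max_conv2d_group = int(np.lcm.reduce([2, 1]))         # conv group counts of the two layers
    step = group_metric.numel() // max_conv2d_group       # channels handled per conv group
    pruned_per_group = int(step * 0.5)                    # 50% sparsity inside each conv group
    mask = []
    for gid in range(max_conv2d_group):
        chunk = group_metric[gid * step:(gid + 1) * step]
        threshold = torch.topk(chunk, pruned_per_group, largest=False)[0].max()
        mask.append(torch.gt(chunk, threshold).float())
    return torch.cat(mask)                                # one channel mask shared by both layers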
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from schema import Schema, And, SchemaError
def validate_op_names(model, op_names, logger):
found_names = set(map(lambda x: x[0], model.named_modules()))
not_found_op_names = list(set(op_names) - found_names)
if not_found_op_names:
logger.warning('op_names %s not found in model', not_found_op_names)
return True
def validate_op_types(model, op_types, logger):
found_types = set(['default']) | set(map(lambda x: type(x[1]).__name__, model.named_modules()))
not_found_op_types = list(set(op_types) - found_types)
if not_found_op_types:
logger.warning('op_types %s not found in model', not_found_op_types)
return True
def validate_op_types_op_names(data):
if not ('op_types' in data or 'op_names' in data):
raise SchemaError('Either op_types or op_names must be specified.')
return True
class CompressorSchema:
def __init__(self, data_schema, model, logger):
assert isinstance(data_schema, list) and len(data_schema) <= 1
self.data_schema = data_schema
self.compressor_schema = Schema(self._modify_schema(data_schema, model, logger))
def _modify_schema(self, data_schema, model, logger):
if not data_schema:
return data_schema
for k in data_schema[0]:
old_schema = data_schema[0][k]
if k == 'op_types' or (isinstance(k, Schema) and k._schema == 'op_types'):
new_schema = And(old_schema, lambda n: validate_op_types(model, n, logger))
data_schema[0][k] = new_schema
if k == 'op_names' or (isinstance(k, Schema) and k._schema == 'op_names'):
new_schema = And(old_schema, lambda n: validate_op_names(model, n, logger))
data_schema[0][k] = new_schema
data_schema[0] = And(data_schema[0], lambda d: validate_op_types_op_names(d))
return data_schema
def validate(self, data):
self.compressor_schema.validate(data)
def validate_exclude_sparsity(data):
if not ('exclude' in data or 'sparsity_per_layer' in data or 'total_sparsity' in data):
raise SchemaError('One of [sparsity_per_layer, total_sparsity, exclude] should be specified.')
return True
def validate_exclude_quant_types_quant_bits(data):
if not ('exclude' in data or ('quant_types' in data and 'quant_bits' in data)):
raise SchemaError('Either (quant_types and quant_bits) or exclude must be specified.')
return True
class PrunerSchema(CompressorSchema):
def _modify_schema(self, data_schema, model, logger):
data_schema = super()._modify_schema(data_schema, model, logger)
data_schema[0] = And(data_schema[0], lambda d: validate_exclude_sparsity(d))
return data_schema
class QuantizerSchema(CompressorSchema):
def _modify_schema(self, data_schema, model, logger):
data_schema = super()._modify_schema(data_schema, model, logger)
data_schema[0] = And(data_schema[0], lambda d: validate_exclude_quant_types_quant_bits(d))
return data_schema
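# A minimal sketch (not part of the original PR) exercising the standalone
# validators above on a hypothetical pruning config entry; SchemaError is raised
# when neither an explicit sparsity nor `exclude` is given.
def _example_validate_config_entry():
    config = {'op_types': ['Conv2d'], 'sparsity_per_layer': 0.8}
    validate_op_types_op_names(config)                    # passes: op_types is present
    validate_exclude_sparsity(config)                     # passes: sparsity_per_layer is present
    try:
        validate_exclude_sparsity({'op_types': ['Conv2d']})
    except SchemaError as error:
        return str(error)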