Unverified Commit 2b9f5f8c authored by J-shang, committed by GitHub

[Model Compression] update config list key (#4074)

parent 862c67df
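In short, this change replaces the per-pruner PrunerSchema checks with shared schema dicts validated through CompressorSchema, and canonicalizes user configs so that allocators read a per-module 'total_sparsity' key (with 'max_sparsity_per_layer' turned into a per-op dict). A rough sketch of the user-facing effect, using hypothetical module names:

# Hypothetical user config, accepted before and after this change
config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
# After canonicalization, each wrapped module sees a canonical sub-config, e.g. for a module 'conv1':
# wrapper.config == {'op_types': ['Conv2d'], 'op_names': ['conv1'], 'total_sparsity': 0.5}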
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from copy import deepcopy
import logging
from typing import List, Dict, Tuple, Callable, Optional
from schema import And, Optional as SchemaOptional
from schema import And, Or, Optional as SchemaOptional
import torch
from torch import Tensor
import torch.nn as nn
......@@ -12,7 +13,8 @@ from torch.nn import Module
from torch.optim import Optimizer
from nni.algorithms.compression.v2.pytorch.base.pruner import Pruner
from nni.algorithms.compression.v2.pytorch.utils.config_validation import PrunerSchema
from nni.algorithms.compression.v2.pytorch.utils.config_validation import CompressorSchema
from nni.algorithms.compression.v2.pytorch.utils.pruning import config_list_canonical
from .tools import (
DataCollector,
......@@ -43,26 +45,47 @@ _logger = logging.getLogger(__name__)
__all__ = ['LevelPruner', 'L1NormPruner', 'L2NormPruner', 'FPGMPruner', 'SlimPruner', 'ActivationPruner',
'ActivationAPoZRankPruner', 'ActivationMeanRankPruner', 'TaylorFOWeightPruner']
NORMAL_SCHEMA = {
Or('sparsity', 'sparsity_per_layer'): And(float, lambda n: 0 <= n < 1),
SchemaOptional('op_types'): [str],
SchemaOptional('op_names'): [str]
}
GLOBAL_SCHEMA = {
'total_sparsity': And(float, lambda n: 0 <= n < 1),
SchemaOptional('max_sparsity_per_layer'): And(float, lambda n: 0 < n <= 1),
SchemaOptional('op_types'): [str],
SchemaOptional('op_names'): [str]
}
EXCLUDE_SCHEMA = {
'exclude': bool,
SchemaOptional('op_types'): [str],
SchemaOptional('op_names'): [str]
}
INTERNAL_SCHEMA = {
'total_sparsity': And(float, lambda n: 0 <= n < 1),
SchemaOptional('max_sparsity_per_layer'): {str: float},
SchemaOptional('op_types'): [str],
SchemaOptional('op_names'): [str]
}
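# Illustrative configs (not part of this change) that satisfy each schema above;
# the module name 'fc1' is hypothetical.
# NORMAL_SCHEMA:   {'sparsity_per_layer': 0.5, 'op_types': ['Conv2d']}
# GLOBAL_SCHEMA:   {'total_sparsity': 0.5, 'max_sparsity_per_layer': 0.9, 'op_types': ['Conv2d']}
# EXCLUDE_SCHEMA:  {'exclude': True, 'op_names': ['fc1']}
# INTERNAL_SCHEMA: {'total_sparsity': 0.5, 'max_sparsity_per_layer': {'fc1': 0.9}, 'op_names': ['fc1']}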
class OneShotPruner(Pruner):
def __init__(self, model: Module, config_list: List[Dict]):
self.data_collector: DataCollector = None
self.metrics_calculator: MetricsCalculator = None
self.sparsity_allocator: SparsityAllocator = None
self._convert_config_list(config_list)
super().__init__(model, config_list)
def _convert_config_list(self, config_list: List[Dict]):
"""
Convert `sparsity` in config to `sparsity_per_layer`.
"""
for config in config_list:
if 'sparsity' in config:
if 'sparsity_per_layer' in config:
raise ValueError("'sparsity' and 'sparsity_per_layer' have the same semantics, can not set both in one config.")
else:
config['sparsity_per_layer'] = config.pop('sparsity')
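# For example (illustrative): {'sparsity': 0.5, 'op_types': ['Conv2d']}
# becomes {'sparsity_per_layer': 0.5, 'op_types': ['Conv2d']} in place.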
def validate_config(self, model: Module, config_list: List[Dict]):
self._validate_config_before_canonical(model, config_list)
self.config_list = config_list_canonical(model, config_list)
def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
pass
def reset(self, model: Optional[Module], config_list: Optional[List[Dict]]):
super().reset(model=model, config_list=config_list)
......@@ -115,14 +138,9 @@ class LevelPruner(OneShotPruner):
self.mode = 'normal'
super().__init__(model, config_list)
def validate_config(self, model: Module, config_list: List[Dict]):
schema = PrunerSchema([{
SchemaOptional('sparsity_per_layer'): And(float, lambda n: 0 < n < 1),
SchemaOptional('op_types'): [str],
SchemaOptional('op_names'): [str],
SchemaOptional('exclude'): bool
}], model, _logger)
def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
schema_list = [deepcopy(NORMAL_SCHEMA), deepcopy(EXCLUDE_SCHEMA), deepcopy(INTERNAL_SCHEMA)]
schema = CompressorSchema(schema_list, model, _logger)
schema.validate(config_list)
def reset_tools(self):
......@@ -171,13 +189,11 @@ class NormPruner(OneShotPruner):
self.dummy_input = dummy_input
super().__init__(model, config_list)
def validate_config(self, model: Module, config_list: List[Dict]):
schema = PrunerSchema([{
SchemaOptional('sparsity_per_layer'): And(float, lambda n: 0 < n < 1),
SchemaOptional('op_types'): ['Conv2d', 'Linear'],
SchemaOptional('op_names'): [str],
SchemaOptional('exclude'): bool
}], model, _logger)
def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
schema_list = [deepcopy(NORMAL_SCHEMA), deepcopy(EXCLUDE_SCHEMA), deepcopy(INTERNAL_SCHEMA)]
for sub_schema in schema_list:
sub_schema[SchemaOptional('op_types')] = ['Conv2d', 'Linear']
schema = CompressorSchema(schema_list, model, _logger)
schema.validate(config_list)
......@@ -291,13 +307,11 @@ class FPGMPruner(OneShotPruner):
self.dummy_input = dummy_input
super().__init__(model, config_list)
def validate_config(self, model: Module, config_list: List[Dict]):
schema = PrunerSchema([{
SchemaOptional('sparsity_per_layer'): And(float, lambda n: 0 < n < 1),
SchemaOptional('op_types'): ['Conv2d', 'Linear'],
SchemaOptional('op_names'): [str],
SchemaOptional('exclude'): bool
}], model, _logger)
def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
schema_list = [deepcopy(NORMAL_SCHEMA), deepcopy(EXCLUDE_SCHEMA), deepcopy(INTERNAL_SCHEMA)]
for sub_schema in schema_list:
sub_schema[SchemaOptional('op_types')] = ['Conv2d', 'Linear']
schema = CompressorSchema(schema_list, model, _logger)
schema.validate(config_list)
......@@ -376,15 +390,15 @@ class SlimPruner(OneShotPruner):
self._scale = scale
super().__init__(model, config_list)
def validate_config(self, model: Module, config_list: List[Dict]):
schema = PrunerSchema([{
SchemaOptional('sparsity_per_layer'): And(float, lambda n: 0 < n < 1),
SchemaOptional('total_sparsity'): And(float, lambda n: 0 < n < 1),
SchemaOptional('max_sparsity_per_layer'): And(float, lambda n: 0 < n < 1),
SchemaOptional('op_types'): ['BatchNorm2d'],
SchemaOptional('op_names'): [str],
SchemaOptional('exclude'): bool
}], model, _logger)
def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
schema_list = [deepcopy(EXCLUDE_SCHEMA), deepcopy(INTERNAL_SCHEMA)]
if self.mode == 'global':
schema_list.append(deepcopy(GLOBAL_SCHEMA))
else:
schema_list.append(deepcopy(NORMAL_SCHEMA))
for sub_schema in schema_list:
sub_schema[SchemaOptional('op_types')] = ['BatchNorm2d']
schema = CompressorSchema(schema_list, model, _logger)
schema.validate(config_list)
......@@ -477,13 +491,11 @@ class ActivationPruner(OneShotPruner):
self._activation = self._choose_activation(activation)
super().__init__(model, config_list)
def validate_config(self, model: Module, config_list: List[Dict]):
schema = PrunerSchema([{
SchemaOptional('sparsity_per_layer'): And(float, lambda n: 0 < n < 1),
SchemaOptional('op_types'): ['Conv2d', 'Linear'],
SchemaOptional('op_names'): [str],
SchemaOptional('exclude'): bool
}], model, _logger)
def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
schema_list = [deepcopy(NORMAL_SCHEMA), deepcopy(EXCLUDE_SCHEMA), deepcopy(INTERNAL_SCHEMA)]
for sub_schema in schema_list:
sub_schema[SchemaOptional('op_types')] = ['Conv2d', 'Linear']
schema = CompressorSchema(schema_list, model, _logger)
schema.validate(config_list)
......@@ -603,19 +615,19 @@ class TaylorFOWeightPruner(OneShotPruner):
self.training_batches = training_batches
super().__init__(model, config_list)
def validate_config(self, model: Module, config_list: List[Dict]):
schema = PrunerSchema([{
SchemaOptional('sparsity_per_layer'): And(float, lambda n: 0 < n < 1),
SchemaOptional('total_sparsity'): And(float, lambda n: 0 < n < 1),
SchemaOptional('max_sparsity_per_layer'): And(float, lambda n: 0 < n < 1),
SchemaOptional('op_types'): ['Conv2d', 'Linear'],
SchemaOptional('op_names'): [str],
SchemaOptional('exclude'): bool
}], model, _logger)
def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
schema_list = [deepcopy(EXCLUDE_SCHEMA), deepcopy(INTERNAL_SCHEMA)]
if self.mode == 'global':
schema_list.append(deepcopy(GLOBAL_SCHEMA))
else:
schema_list.append(deepcopy(NORMAL_SCHEMA))
for sub_schema in schema_list:
sub_schema[SchemaOptional('op_types')] = ['Conv2d', 'Linear']
schema = CompressorSchema(schema_list, model, _logger)
schema.validate(config_list)
def _collector(self, buffer: List, weight_tensor: Tensor) -> Callable[[Module, Tensor, Tensor], None]:
def _collector(self, buffer: List, weight_tensor: Tensor) -> Callable[[Tensor], None]:
def collect_taylor(grad: Tensor):
if len(buffer) < self.training_batches:
buffer.append(self._calculate_taylor_expansion(weight_tensor, grad))
......
......@@ -437,6 +437,6 @@ class SparsityAllocator:
mask = mask.unfold(i, step, step)
ein_expression += lower_case_letters[i]
ein_expression = '...{},{}'.format(ein_expression, ein_expression)
mask = torch.einsum(ein_expression, mask, torch.ones(self.block_sparse_size))
mask = torch.einsum(ein_expression, mask, torch.ones(self.block_sparse_size).to(mask.device))
return (mask != 0).type_as(mask)
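# Minimal sketch (not from this diff) of the unfold + einsum pattern above,
# assuming a 4x4 mask and self.block_sparse_size == (2, 2):
# fine = torch.tensor([[1., 1., 0., 0.],
#                      [1., 1., 0., 0.],
#                      [0., 0., 0., 0.],
#                      [0., 0., 0., 0.]])
# blocks = fine.unfold(0, 2, 2).unfold(1, 2, 2)              # shape (2, 2, 2, 2), one 2x2 block per entry
# sums = torch.einsum('...ab,ab', blocks, torch.ones(2, 2))  # per-block sums, shape (2, 2)
# coarse = (sums != 0).type_as(sums)                         # tensor([[1., 0.], [0., 0.]])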
......@@ -120,7 +120,7 @@ class DistMetricsCalculator(MetricsCalculator):
metric = torch.ones(*reorder_tensor.size()[:len(keeped_dim)], device=reorder_tensor.device)
across_dim = list(range(len(keeped_dim), len(reorder_dim)))
idxs = metric.nonzero()
idxs = metric.nonzero(as_tuple=False)
for idx in idxs:
other = reorder_tensor
for i in idx:
......@@ -161,7 +161,7 @@ class APoZRankMetricsCalculator(MetricsCalculator):
for dim, dim_size in enumerate(_eq_zero.size()):
if dim not in keeped_dim:
total_size *= dim_size
_apoz = torch.sum(_eq_zero, dim=across_dim, dtype=torch.float64) / total_size
_apoz = torch.sum(_eq_zero, dim=across_dim).type_as(activations) / total_size
# NOTE: the metric is (1 - apoz) because, by convention, a smaller metric value means the unit is more likely to be pruned.
metrics[name] = torch.ones_like(_apoz) - _apoz
return metrics
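# Worked sketch (illustrative, not from this diff): APoZ is the fraction of zero
# activations per kept dimension, and the metric is (1 - APoZ) so channels with
# many zeros rank low and are pruned first. For a toy post-ReLU tensor:
# acts = torch.relu(torch.randn(8, 4, 3, 3))          # (batch, channel, h, w), keep dim 1
# eq_zero = (acts == 0)
# total = acts.numel() // acts.size(1)                 # elements aggregated per channel
# apoz = eq_zero.sum(dim=(0, 2, 3)).float() / total    # shape (4,)
# metric = 1.0 - apoz                                  # smaller value -> pruned earlier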
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import math
from typing import Any, Dict, List, Tuple, Union
import numpy as np
......@@ -20,7 +21,7 @@ class NormalSparsityAllocator(SparsityAllocator):
def generate_sparsity(self, metrics: Dict[str, Tensor]) -> Dict[str, Dict[str, Tensor]]:
masks = {}
for name, wrapper in self.pruner.get_modules_wrapper().items():
sparsity_rate = wrapper.config['sparsity_per_layer']
sparsity_rate = wrapper.config['total_sparsity']
assert name in metrics, 'Metric of {} is not calculated.'.format(name)
metric = metrics[name] * self._compress_mask(wrapper.weight_mask)
......@@ -58,26 +59,30 @@ class GlobalSparsityAllocator(SparsityAllocator):
temp_wrapper_config = self.pruner.get_modules_wrapper()[list(group_metric_dict.keys())[0]].config
total_sparsity = temp_wrapper_config['total_sparsity']
max_sparsity_per_layer = temp_wrapper_config.get('max_sparsity_per_layer', 1.0)
max_sparsity_per_layer = temp_wrapper_config.get('max_sparsity_per_layer', {})
for name, metric in group_metric_dict.items():
wrapper = self.pruner.get_modules_wrapper()[name]
metric = metric * self._compress_mask(wrapper.weight_mask)
layer_weight_num = wrapper.module.weight.data.numel()
stay_num = int(metric.numel() * max_sparsity_per_layer)
retention_ratio = 1 - max_sparsity_per_layer.get(name, 1)
retention_numel = math.ceil(retention_ratio * layer_weight_num)
removed_metric_num = math.ceil(retention_numel / (wrapper.weight_mask.numel() / metric.numel()))
stay_metric_num = metric.numel() - removed_metric_num
# Exclude from the global ranking the weights that must be kept due to max_sparsity_per_layer
stay_metric = torch.topk(metric.view(-1), stay_num, largest=False)[0]
stay_metric = torch.topk(metric.view(-1), stay_metric_num, largest=False)[0]
sub_thresholds[name] = stay_metric.max()
expand_times = int(layer_weight_num / metric.numel())
if expand_times > 1:
stay_metric = stay_metric.expand(stay_num, int(layer_weight_num / metric.numel())).view(-1)
stay_metric = stay_metric.expand(stay_metric_num, int(layer_weight_num / metric.numel())).view(-1)
metric_list.append(stay_metric)
total_weight_num += layer_weight_num
assert total_sparsity <= max_sparsity_per_layer, 'total_sparsity should be no greater than max_sparsity_per_layer.'
total_prune_num = int(total_sparsity * total_weight_num)
if total_prune_num == 0:
threshold = torch.cat(metric_list).min().item() - 1
else:
threshold = torch.topk(torch.cat(metric_list).view(-1), total_prune_num, largest=False)[0].max().item()
return threshold, sub_thresholds
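# Worked example with hypothetical numbers (not from this diff): two layers of 100
# weights each, total_sparsity=0.5, max_sparsity_per_layer={'conv1': 0.8}.
# conv1 must retain ceil((1 - 0.8) * 100) = 20 weights, so only its 80 smallest
# metric values join the global ranking, while conv2 contributes all 100. The
# int(0.5 * 200) = 100 smallest of those 180 candidates set the global threshold,
# and sub_thresholds keeps conv1's own cutoff so its sparsity can be capped at 80%.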
......@@ -108,7 +113,7 @@ class Conv2dDependencyAwareAllocator(SparsityAllocator):
for _, group_metric_dict in grouped_metrics.items():
group_metric = self._group_metric_calculate(group_metric_dict)
sparsities = {name: self.pruner.get_modules_wrapper()[name].config['sparsity_per_layer'] for name in group_metric_dict.keys()}
sparsities = {name: self.pruner.get_modules_wrapper()[name].config['total_sparsity'] for name in group_metric_dict.keys()}
min_sparsity = min(sparsities.values())
conv2d_groups = [self.group_depen[name] for name in group_metric_dict.keys()]
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from logging import Logger
from typing import Dict, List
from schema import Schema, And, SchemaError
from torch.nn import Module
def validate_op_names(model, op_names, logger):
found_names = set(map(lambda x: x[0], model.named_modules()))
not_found_op_names = list(set(op_names) - found_names)
if not_found_op_names:
logger.warning('op_names %s not found in model', not_found_op_names)
return True
def validate_op_types(model, op_types, logger):
found_types = set(['default']) | set(map(lambda x: type(x[1]).__name__, model.named_modules()))
not_found_op_types = list(set(op_types) - found_types)
if not_found_op_types:
logger.warning('op_types %s not found in model', not_found_op_types)
return True
def validate_op_types_op_names(data):
if not ('op_types' in data or 'op_names' in data):
raise SchemaError('Either op_types or op_names must be specified.')
return True
class CompressorSchema:
def __init__(self, data_schema, model, logger):
assert isinstance(data_schema, list) and len(data_schema) <= 1
def __init__(self, data_schema: List[Dict], model: Module, logger: Logger):
assert isinstance(data_schema, list)
self.data_schema = data_schema
self.compressor_schema = Schema(self._modify_schema(data_schema, model, logger))
def _modify_schema(self, data_schema, model, logger):
def _modify_schema(self, data_schema: List[Dict], model: Module, logger: Logger) -> List[Dict]:
if not data_schema:
return data_schema
for k in data_schema[0]:
old_schema = data_schema[0][k]
for i, sub_schema in enumerate(data_schema):
for k, old_schema in sub_schema.items():
if k == 'op_types' or (isinstance(k, Schema) and k._schema == 'op_types'):
new_schema = And(old_schema, lambda n: validate_op_types(model, n, logger))
data_schema[0][k] = new_schema
sub_schema[k] = new_schema
if k == 'op_names' or (isinstance(k, Schema) and k._schema == 'op_names'):
new_schema = And(old_schema, lambda n: validate_op_names(model, n, logger))
data_schema[0][k] = new_schema
sub_schema[k] = new_schema
data_schema[0] = And(data_schema[0], lambda d: validate_op_types_op_names(d))
data_schema[i] = And(sub_schema, lambda d: validate_op_types_op_names(d))
return data_schema
def validate(self, data):
self.compressor_schema.validate(data)
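# Minimal usage sketch (illustrative): with the NORMAL_SCHEMA / EXCLUDE_SCHEMA dicts
# and _logger defined in the pruner module above, and `net` a hypothetical nn.Module:
# schema = CompressorSchema([deepcopy(NORMAL_SCHEMA), deepcopy(EXCLUDE_SCHEMA)], net, _logger)
# schema.validate([{'sparsity_per_layer': 0.5, 'op_types': ['Conv2d']}])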
def validate_exclude_sparsity(data):
if not ('exclude' in data or 'sparsity_per_layer' in data or 'total_sparsity' in data):
raise SchemaError('One of [sparsity_per_layer, total_sparsity, exclude] should be specified.')
return True
def validate_exclude_quant_types_quant_bits(data):
if not ('exclude' in data or ('quant_types' in data and 'quant_bits' in data)):
raise SchemaError('Either (quant_types and quant_bits) or exclude must be specified.')
return True
class PrunerSchema(CompressorSchema):
def _modify_schema(self, data_schema, model, logger):
data_schema = super()._modify_schema(data_schema, model, logger)
data_schema[0] = And(data_schema[0], lambda d: validate_exclude_sparsity(d))
return data_schema
class QuantizerSchema(CompressorSchema):
def _modify_schema(self, data_schema, model, logger):
data_schema = super()._modify_schema(data_schema, model, logger)
data_schema[0] = And(data_schema[0], lambda d: validate_exclude_quant_types_quant_bits(d))
return data_schema
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from copy import deepcopy
from typing import Dict, List
from torch.nn import Module
def config_list_canonical(model: Module, config_list: List[Dict]) -> List[Dict]:
'''
Canonicalize the config_list: rename 'sparsity' to 'sparsity_per_layer', unfold and dedupe configs by op_names,
split per-layer sparsity into one sub-config per module with sub_config['total_sparsity'] set to
config['sparsity_per_layer'], and convert a float 'max_sparsity_per_layer' into a per-op dict.
'''
for config in config_list:
if 'sparsity' in config:
if 'sparsity_per_layer' in config:
raise ValueError("'sparsity' and 'sparsity_per_layer' have the same semantics, can not set both in one config.")
else:
config['sparsity_per_layer'] = config.pop('sparsity')
config_list = dedupe_config_list(unfold_config_list(model, config_list))
new_config_list = []
for config in config_list:
if 'sparsity_per_layer' in config:
sparsity_per_layer = config.pop('sparsity_per_layer')
op_names = config.pop('op_names')
for op_name in op_names:
sub_config = deepcopy(config)
sub_config['op_names'] = [op_name]
sub_config['total_sparsity'] = sparsity_per_layer
new_config_list.append(sub_config)
elif 'max_sparsity_per_layer' in config and isinstance(config['max_sparsity_per_layer'], float):
op_names = config.get('op_names', [])
max_sparsity_per_layer = {}
max_sparsity = config['max_sparsity_per_layer']
for op_name in op_names:
max_sparsity_per_layer[op_name] = max_sparsity
config['max_sparsity_per_layer'] = max_sparsity_per_layer
new_config_list.append(config)
else:
new_config_list.append(config)
return new_config_list
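# Illustrative round trip for a hypothetical model whose Conv2d modules are named
# 'conv1' and 'conv2':
# config_list_canonical(model, [{'sparsity': 0.5, 'op_types': ['Conv2d']}])
# -> [{'op_types': ['Conv2d'], 'op_names': ['conv1'], 'total_sparsity': 0.5},
#     {'op_types': ['Conv2d'], 'op_names': ['conv2'], 'total_sparsity': 0.5}]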
def unfold_config_list(model: Module, config_list: List[Dict]) -> List[Dict]:
'''
Unfold the config_list so that each config carries an explicit 'op_names' list of the modules it matches.
'''
unfolded_config_list = []
for config in config_list:
op_names = []
for module_name, module in model.named_modules():
module_type = type(module).__name__
if 'op_types' in config and module_type not in config['op_types']:
continue
if 'op_names' in config and module_name not in config['op_names']:
continue
op_names.append(module_name)
unfolded_config = deepcopy(config)
unfolded_config['op_names'] = op_names
unfolded_config_list.append(unfolded_config)
return unfolded_config_list
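# For a hypothetical model with Conv2d modules 'conv1', 'conv2' and a Linear 'fc':
# unfold_config_list(model, [{'sparsity_per_layer': 0.5, 'op_types': ['Conv2d']}])
# -> [{'sparsity_per_layer': 0.5, 'op_types': ['Conv2d'], 'op_names': ['conv1', 'conv2']}]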
def dedupe_config_list(config_list: List[Dict]) -> List[Dict]:
'''
Dedupe the op_names in unfolded config_list.
'''
exclude = set()
exclude_idxes = []
config_list = deepcopy(config_list)
for idx, config in reversed(list(enumerate(config_list))):
if 'exclude' in config:
exclude.update(config['op_names'])
exclude_idxes.append(idx)
continue
config['op_names'] = sorted(list(set(config['op_names']).difference(exclude)))
exclude.update(config['op_names'])
for idx in sorted(exclude_idxes, reverse=True):
config_list.pop(idx)
return config_list
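# Dedupe example (hypothetical configs): later configs win for overlapping modules,
# and 'exclude' entries only strip names from configs that precede them:
# dedupe_config_list([{'total_sparsity': 0.5, 'op_names': ['conv1', 'conv2']},
#                     {'exclude': True, 'op_names': ['conv2']}])
# -> [{'total_sparsity': 0.5, 'op_names': ['conv1']}]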