"testing/vscode:/vscode.git/clone" did not exist on "8cdc185bb4ca34fcfda70d7e329ddc30c44aadae"
Unverified commit 2b9f5f8c authored by J-shang, committed by GitHub

[Model Compression] update config list key (#4074)

parent 862c67df
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from copy import deepcopy
import logging
from typing import List, Dict, Tuple, Callable, Optional

from schema import And, Or, Optional as SchemaOptional
import torch
from torch import Tensor
import torch.nn as nn

...@@ -12,7 +13,8 @@ from torch.nn import Module
from torch.optim import Optimizer

from nni.algorithms.compression.v2.pytorch.base.pruner import Pruner
from nni.algorithms.compression.v2.pytorch.utils.config_validation import CompressorSchema
from nni.algorithms.compression.v2.pytorch.utils.pruning import config_list_canonical

from .tools import (
    DataCollector,

...@@ -43,26 +45,47 @@ _logger = logging.getLogger(__name__)

__all__ = ['LevelPruner', 'L1NormPruner', 'L2NormPruner', 'FPGMPruner', 'SlimPruner', 'ActivationPruner',
           'ActivationAPoZRankPruner', 'ActivationMeanRankPruner', 'TaylorFOWeightPruner']
NORMAL_SCHEMA = {
    Or('sparsity', 'sparsity_per_layer'): And(float, lambda n: 0 <= n < 1),
    SchemaOptional('op_types'): [str],
    SchemaOptional('op_names'): [str]
}

GLOBAL_SCHEMA = {
    'total_sparsity': And(float, lambda n: 0 <= n < 1),
    SchemaOptional('max_sparsity_per_layer'): And(float, lambda n: 0 < n <= 1),
    SchemaOptional('op_types'): [str],
    SchemaOptional('op_names'): [str]
}

EXCLUDE_SCHEMA = {
    'exclude': bool,
    SchemaOptional('op_types'): [str],
    SchemaOptional('op_names'): [str]
}

INTERNAL_SCHEMA = {
    'total_sparsity': And(float, lambda n: 0 <= n < 1),
    SchemaOptional('max_sparsity_per_layer'): {str: float},
    SchemaOptional('op_types'): [str],
    SchemaOptional('op_names'): [str]
}
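For orientation (this example is editorial, not part of the commit): a config entry matching NORMAL_SCHEMA carries a per-layer sparsity, while one matching GLOBAL_SCHEMA carries a total sparsity budget plus an optional per-layer cap.

# Illustrative config entries only; op_types and sparsity values are made up.
normal_config_list = [{'sparsity_per_layer': 0.5, 'op_types': ['Conv2d']}]
global_config_list = [{'total_sparsity': 0.5, 'max_sparsity_per_layer': 0.8, 'op_types': ['Conv2d', 'Linear']}]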
class OneShotPruner(Pruner):
    def __init__(self, model: Module, config_list: List[Dict]):
        self.data_collector: DataCollector = None
        self.metrics_calculator: MetricsCalculator = None
        self.sparsity_allocator: SparsityAllocator = None
        super().__init__(model, config_list)

    def validate_config(self, model: Module, config_list: List[Dict]):
        self._validate_config_before_canonical(model, config_list)
        self.config_list = config_list_canonical(model, config_list)

    def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
        pass

    def reset(self, model: Optional[Module], config_list: Optional[List[Dict]]):
        super().reset(model=model, config_list=config_list)
...@@ -115,14 +138,9 @@ class LevelPruner(OneShotPruner):
        self.mode = 'normal'
        super().__init__(model, config_list)

    def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
        schema_list = [deepcopy(NORMAL_SCHEMA), deepcopy(EXCLUDE_SCHEMA), deepcopy(INTERNAL_SCHEMA)]
        schema = CompressorSchema(schema_list, model, _logger)
        schema.validate(config_list)

    def reset_tools(self):
...@@ -171,13 +189,11 @@ class NormPruner(OneShotPruner):
        self.dummy_input = dummy_input
        super().__init__(model, config_list)

    def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
        schema_list = [deepcopy(NORMAL_SCHEMA), deepcopy(EXCLUDE_SCHEMA), deepcopy(INTERNAL_SCHEMA)]
        for sub_schema in schema_list:
            sub_schema[SchemaOptional('op_types')] = ['Conv2d', 'Linear']
        schema = CompressorSchema(schema_list, model, _logger)
        schema.validate(config_list)
...@@ -291,13 +307,11 @@ class FPGMPruner(OneShotPruner):
        self.dummy_input = dummy_input
        super().__init__(model, config_list)

    def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
        schema_list = [deepcopy(NORMAL_SCHEMA), deepcopy(EXCLUDE_SCHEMA), deepcopy(INTERNAL_SCHEMA)]
        for sub_schema in schema_list:
            sub_schema[SchemaOptional('op_types')] = ['Conv2d', 'Linear']
        schema = CompressorSchema(schema_list, model, _logger)
        schema.validate(config_list)
...@@ -376,15 +390,15 @@ class SlimPruner(OneShotPruner):
        self._scale = scale
        super().__init__(model, config_list)

    def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
        schema_list = [deepcopy(EXCLUDE_SCHEMA), deepcopy(INTERNAL_SCHEMA)]
        if self.mode == 'global':
            schema_list.append(deepcopy(GLOBAL_SCHEMA))
        else:
            schema_list.append(deepcopy(NORMAL_SCHEMA))
        for sub_schema in schema_list:
            sub_schema[SchemaOptional('op_types')] = ['BatchNorm2d']
        schema = CompressorSchema(schema_list, model, _logger)
        schema.validate(config_list)
...@@ -477,13 +491,11 @@ class ActivationPruner(OneShotPruner):
        self._activation = self._choose_activation(activation)
        super().__init__(model, config_list)

    def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
        schema_list = [deepcopy(NORMAL_SCHEMA), deepcopy(EXCLUDE_SCHEMA), deepcopy(INTERNAL_SCHEMA)]
        for sub_schema in schema_list:
            sub_schema[SchemaOptional('op_types')] = ['Conv2d', 'Linear']
        schema = CompressorSchema(schema_list, model, _logger)
        schema.validate(config_list)
...@@ -603,19 +615,19 @@ class TaylorFOWeightPruner(OneShotPruner):
        self.training_batches = training_batches
        super().__init__(model, config_list)

    def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
        schema_list = [deepcopy(EXCLUDE_SCHEMA), deepcopy(INTERNAL_SCHEMA)]
        if self.mode == 'global':
            schema_list.append(deepcopy(GLOBAL_SCHEMA))
        else:
            schema_list.append(deepcopy(NORMAL_SCHEMA))
        for sub_schema in schema_list:
            sub_schema[SchemaOptional('op_types')] = ['Conv2d', 'Linear']
        schema = CompressorSchema(schema_list, model, _logger)
        schema.validate(config_list)

    def _collector(self, buffer: List, weight_tensor: Tensor) -> Callable[[Tensor], None]:
        def collect_taylor(grad: Tensor):
            if len(buffer) < self.training_batches:
                buffer.append(self._calculate_taylor_expansion(weight_tensor, grad))
...
...@@ -437,6 +437,6 @@ class SparsityAllocator:
            mask = mask.unfold(i, step, step)
            ein_expression += lower_case_letters[i]
        ein_expression = '...{},{}'.format(ein_expression, ein_expression)
        mask = torch.einsum(ein_expression, mask, torch.ones(self.block_sparse_size).to(mask.device))
        return (mask != 0).type_as(mask)
...@@ -120,7 +120,7 @@ class DistMetricsCalculator(MetricsCalculator):
        metric = torch.ones(*reorder_tensor.size()[:len(keeped_dim)], device=reorder_tensor.device)
        across_dim = list(range(len(keeped_dim), len(reorder_dim)))
        idxs = metric.nonzero(as_tuple=False)
        for idx in idxs:
            other = reorder_tensor
            for i in idx:
...@@ -161,7 +161,7 @@ class APoZRankMetricsCalculator(MetricsCalculator):
        for dim, dim_size in enumerate(_eq_zero.size()):
            if dim not in keeped_dim:
                total_size *= dim_size
        _apoz = torch.sum(_eq_zero, dim=across_dim).type_as(activations) / total_size
        # NOTE: the metric is (1 - apoz) because we assume a smaller metric value means the element should be pruned first.
        metrics[name] = torch.ones_like(_apoz) - _apoz
        return metrics
...
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import math
from typing import Any, Dict, List, Tuple, Union

import numpy as np

...@@ -20,7 +21,7 @@ class NormalSparsityAllocator(SparsityAllocator):
    def generate_sparsity(self, metrics: Dict[str, Tensor]) -> Dict[str, Dict[str, Tensor]]:
        masks = {}
        for name, wrapper in self.pruner.get_modules_wrapper().items():
            sparsity_rate = wrapper.config['total_sparsity']
            assert name in metrics, 'Metric of %s is not calculated.'
            metric = metrics[name] * self._compress_mask(wrapper.weight_mask)
...@@ -58,27 +59,31 @@ class GlobalSparsityAllocator(SparsityAllocator):
        temp_wrapper_config = self.pruner.get_modules_wrapper()[list(group_metric_dict.keys())[0]].config
        total_sparsity = temp_wrapper_config['total_sparsity']
        max_sparsity_per_layer = temp_wrapper_config.get('max_sparsity_per_layer', {})

        for name, metric in group_metric_dict.items():
            wrapper = self.pruner.get_modules_wrapper()[name]
            metric = metric * self._compress_mask(wrapper.weight_mask)
            layer_weight_num = wrapper.module.weight.data.numel()
            retention_ratio = 1 - max_sparsity_per_layer.get(name, 1)
            retention_numel = math.ceil(retention_ratio * layer_weight_num)
            removed_metric_num = math.ceil(retention_numel / (wrapper.weight_mask.numel() / metric.numel()))
            stay_metric_num = metric.numel() - removed_metric_num
            # Exclude the metric values that must be kept (to respect max_sparsity_per_layer) from the global ranking
            stay_metric = torch.topk(metric.view(-1), stay_metric_num, largest=False)[0]
            sub_thresholds[name] = stay_metric.max()
            expand_times = int(layer_weight_num / metric.numel())
            if expand_times > 1:
                stay_metric = stay_metric.expand(stay_metric_num, int(layer_weight_num / metric.numel())).view(-1)
            metric_list.append(stay_metric)
            total_weight_num += layer_weight_num

        total_prune_num = int(total_sparsity * total_weight_num)
        if total_prune_num == 0:
            threshold = torch.cat(metric_list).min().item() - 1
        else:
            threshold = torch.topk(torch.cat(metric_list).view(-1), total_prune_num, largest=False)[0].max().item()
        return threshold, sub_thresholds
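To make the global-threshold step above concrete (an editorial sketch, not code from this commit), the threshold is simply the largest of the total_prune_num smallest metric values gathered across all layers:

# Illustrative sketch only; the metric values below are made up.
import torch
layer_metrics = [torch.tensor([0.2, 0.9, 0.5]), torch.tensor([0.1, 0.8])]
total_sparsity = 0.4
flat = torch.cat(layer_metrics)
total_prune_num = int(total_sparsity * flat.numel())  # 2 of 5 entries get pruned
if total_prune_num == 0:
    threshold = flat.min().item() - 1  # below every metric, so nothing is pruned
else:
    threshold = torch.topk(flat, total_prune_num, largest=False)[0].max().item()
# threshold == 0.2 here; entries whose metric is at or below it become pruning candidates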
...@@ -108,7 +113,7 @@ class Conv2dDependencyAwareAllocator(SparsityAllocator):
        for _, group_metric_dict in grouped_metrics.items():
            group_metric = self._group_metric_calculate(group_metric_dict)
            sparsities = {name: self.pruner.get_modules_wrapper()[name].config['total_sparsity'] for name in group_metric_dict.keys()}
            min_sparsity = min(sparsities.values())
            conv2d_groups = [self.group_depen[name] for name in group_metric_dict.keys()]
...
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from logging import Logger
from typing import Dict, List

from schema import Schema, And, SchemaError
from torch.nn import Module
class CompressorSchema:
    def __init__(self, data_schema: List[Dict], model: Module, logger: Logger):
        assert isinstance(data_schema, list)
        self.data_schema = data_schema
        self.compressor_schema = Schema(self._modify_schema(data_schema, model, logger))

    def _modify_schema(self, data_schema: List[Dict], model: Module, logger: Logger) -> List[Dict]:
        if not data_schema:
            return data_schema

        for i, sub_schema in enumerate(data_schema):
            for k, old_schema in sub_schema.items():
                if k == 'op_types' or (isinstance(k, Schema) and k._schema == 'op_types'):
                    new_schema = And(old_schema, lambda n: validate_op_types(model, n, logger))
                    sub_schema[k] = new_schema
                if k == 'op_names' or (isinstance(k, Schema) and k._schema == 'op_names'):
                    new_schema = And(old_schema, lambda n: validate_op_names(model, n, logger))
                    sub_schema[k] = new_schema
            data_schema[i] = And(sub_schema, lambda d: validate_op_types_op_names(d))
        return data_schema

    def validate(self, data):
        self.compressor_schema.validate(data)
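As a rough usage sketch (editorial, not part of the commit; the schema dict simply mirrors NORMAL_SCHEMA from the pruner module above), the new list-based CompressorSchema validates each config entry against any of the supplied sub-schemas:

# Illustrative sketch only; the model, logger, and values are made up.
import logging
import torch.nn as nn
from schema import And, Or, Optional as SchemaOptional

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Linear(8, 2))
schema = CompressorSchema([{
    Or('sparsity', 'sparsity_per_layer'): And(float, lambda n: 0 <= n < 1),
    SchemaOptional('op_types'): [str],
    SchemaOptional('op_names'): [str]
}], model, logging.getLogger(__name__))
schema.validate([{'sparsity_per_layer': 0.5, 'op_types': ['Conv2d']}])  # passes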
def validate_op_names(model, op_names, logger):
    found_names = set(map(lambda x: x[0], model.named_modules()))

...@@ -12,6 +44,7 @@ def validate_op_names(model, op_names, logger):
    return True

def validate_op_types(model, op_types, logger):
    found_types = set(['default']) | set(map(lambda x: type(x[1]).__name__, model.named_modules()))

...@@ -21,55 +54,8 @@ def validate_op_types(model, op_types, logger):
    return True

def validate_op_types_op_names(data):
    if not ('op_types' in data or 'op_names' in data):
        raise SchemaError('Either op_types or op_names must be specified.')
    return True
class CompressorSchema:
    def __init__(self, data_schema, model, logger):
        assert isinstance(data_schema, list) and len(data_schema) <= 1
        self.data_schema = data_schema
        self.compressor_schema = Schema(self._modify_schema(data_schema, model, logger))

    def _modify_schema(self, data_schema, model, logger):
        if not data_schema:
            return data_schema

        for k in data_schema[0]:
            old_schema = data_schema[0][k]
            if k == 'op_types' or (isinstance(k, Schema) and k._schema == 'op_types'):
                new_schema = And(old_schema, lambda n: validate_op_types(model, n, logger))
                data_schema[0][k] = new_schema
            if k == 'op_names' or (isinstance(k, Schema) and k._schema == 'op_names'):
                new_schema = And(old_schema, lambda n: validate_op_names(model, n, logger))
                data_schema[0][k] = new_schema

        data_schema[0] = And(data_schema[0], lambda d: validate_op_types_op_names(d))
        return data_schema

    def validate(self, data):
        self.compressor_schema.validate(data)

def validate_exclude_sparsity(data):
    if not ('exclude' in data or 'sparsity_per_layer' in data or 'total_sparsity' in data):
        raise SchemaError('One of [sparsity_per_layer, total_sparsity, exclude] should be specified.')
    return True

def validate_exclude_quant_types_quant_bits(data):
    if not ('exclude' in data or ('quant_types' in data and 'quant_bits' in data)):
        raise SchemaError('Either (quant_types and quant_bits) or exclude must be specified.')
    return True

class PrunerSchema(CompressorSchema):
    def _modify_schema(self, data_schema, model, logger):
        data_schema = super()._modify_schema(data_schema, model, logger)
        data_schema[0] = And(data_schema[0], lambda d: validate_exclude_sparsity(d))
        return data_schema

class QuantizerSchema(CompressorSchema):
    def _modify_schema(self, data_schema, model, logger):
        data_schema = super()._modify_schema(data_schema, model, logger)
        data_schema[0] = And(data_schema[0], lambda d: validate_exclude_quant_types_quant_bits(d))
        return data_schema
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from copy import deepcopy
from typing import Dict, List

from torch.nn import Module


def config_list_canonical(model: Module, config_list: List[Dict]) -> List[Dict]:
    '''
    Split a config by op_names when 'sparsity' or 'sparsity_per_layer' is set in the config,
    and set sub_config['total_sparsity'] = config['sparsity_per_layer'] for each sub-config.
    '''
    for config in config_list:
        if 'sparsity' in config:
            if 'sparsity_per_layer' in config:
                raise ValueError("'sparsity' and 'sparsity_per_layer' have the same semantics, can not set both in one config.")
            else:
                config['sparsity_per_layer'] = config.pop('sparsity')

    config_list = dedupe_config_list(unfold_config_list(model, config_list))
    new_config_list = []

    for config in config_list:
        if 'sparsity_per_layer' in config:
            sparsity_per_layer = config.pop('sparsity_per_layer')
            op_names = config.pop('op_names')
            for op_name in op_names:
                sub_config = deepcopy(config)
                sub_config['op_names'] = [op_name]
                sub_config['total_sparsity'] = sparsity_per_layer
                new_config_list.append(sub_config)
        elif 'max_sparsity_per_layer' in config and isinstance(config['max_sparsity_per_layer'], float):
            op_names = config.get('op_names', [])
            max_sparsity_per_layer = {}
            max_sparsity = config['max_sparsity_per_layer']
            for op_name in op_names:
                max_sparsity_per_layer[op_name] = max_sparsity
            config['max_sparsity_per_layer'] = max_sparsity_per_layer
            new_config_list.append(config)
        else:
            new_config_list.append(config)

    return new_config_list
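For illustration (an editorial example, not from the commit), canonicalizing a per-layer sparsity config on a toy model splits it into one sub-config per matched module and rewrites the key to 'total_sparsity':

# Illustrative only; '0' and '2' are the Conv2d names produced by the nn.Sequential below.
import torch.nn as nn
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))
config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
print(config_list_canonical(model, config_list))
# Expected (illustrative):
# [{'op_types': ['Conv2d'], 'op_names': ['0'], 'total_sparsity': 0.5},
#  {'op_types': ['Conv2d'], 'op_names': ['2'], 'total_sparsity': 0.5}]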
def unfold_config_list(model: Module, config_list: List[Dict]) -> List[Dict]:
    '''
    Unfold config_list to op_names level.
    '''
    unfolded_config_list = []
    for config in config_list:
        op_names = []
        for module_name, module in model.named_modules():
            module_type = type(module).__name__
            if 'op_types' in config and module_type not in config['op_types']:
                continue
            if 'op_names' in config and module_name not in config['op_names']:
                continue
            op_names.append(module_name)
        unfolded_config = deepcopy(config)
        unfolded_config['op_names'] = op_names
        unfolded_config_list.append(unfolded_config)
    return unfolded_config_list
def dedupe_config_list(config_list: List[Dict]) -> List[Dict]:
    '''
    Dedupe the op_names in unfolded config_list.
    '''
    exclude = set()
    exclude_idxes = []
    config_list = deepcopy(config_list)
    for idx, config in reversed(list(enumerate(config_list))):
        if 'exclude' in config:
            exclude.update(config['op_names'])
            exclude_idxes.append(idx)
            continue
        config['op_names'] = sorted(list(set(config['op_names']).difference(exclude)))
        exclude.update(config['op_names'])
    for idx in sorted(exclude_idxes, reverse=True):
        config_list.pop(idx)
    return config_list
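Because the list is walked in reverse, later configs take precedence when op_names overlap, and 'exclude' entries remove their op_names from all earlier configs. A small editorial example (layer names are made up, not from the commit):

# Illustrative only.
unfolded = [
    {'total_sparsity': 0.3, 'op_names': ['conv1', 'conv2']},
    {'total_sparsity': 0.6, 'op_names': ['conv2']},
    {'exclude': True, 'op_names': ['conv1']},
]
print(dedupe_config_list(unfolded))
# Expected (illustrative):
# [{'total_sparsity': 0.3, 'op_names': []},
#  {'total_sparsity': 0.6, 'op_names': ['conv2']}]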