Unverified commit d1bc0cfc, authored by chicm-ms, committed by GitHub

Add model compression config validation (#2219)

parent e0b692c9
@@ -3,6 +3,8 @@
import logging
import torch
from schema import And, Optional
from .utils import CompressorSchema
from .compressor import Pruner

__all__ = ['ActivationAPoZRankFilterPruner', 'ActivationMeanRankFilterPruner']
@@ -50,6 +52,24 @@ class ActivationRankFilterPruner(Pruner):
        else:
            self.activation = None

    def validate_config(self, model, config_list):
        """
        Parameters
        ----------
        model : torch.nn.Module
            Model to be pruned
        config_list : list
            Supported keys for each list item:
                - sparsity : percentage of convolutional filters to be pruned.
        """
        schema = CompressorSchema([{
            'sparsity': And(float, lambda n: 0 < n < 1),
            Optional('op_types'): [str],
            Optional('op_names'): [str]
        }], model, logger)
        schema.validate(config_list)
    def get_mask(self, base_mask, activations, num_prune):
        raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))
......
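For reference, a config_list like the following would satisfy the schema above; this is a minimal sketch, not part of this commit, and the model, pruner choice, and sparsity value are all illustrative:

# Hypothetical usage sketch (not part of this commit).
import torch
from nni.compression.torch import ActivationMeanRankFilterPruner

model = torch.nn.Sequential(torch.nn.Conv2d(1, 8, 3), torch.nn.ReLU())
config_list = [{
    'sparsity': 0.5,         # must be a float strictly between 0 and 1
    'op_types': ['Conv2d']   # at least one of op_types / op_names is required
}]
pruner = ActivationMeanRankFilterPruner(model, config_list)  # validate_config runs in __init__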
@@ -40,6 +40,9 @@ class Compressor:
        optimizer: pytorch optimizer
            optimizer used to train the model
        """
        assert isinstance(model, torch.nn.Module)
        self.validate_config(model, config_list)

        self.bound_model = model
        self.config_list = config_list
        self.optimizer = optimizer
@@ -54,9 +57,17 @@ class Compressor:
        for layer, config in self._detect_modules_to_compress():
            wrapper = self._wrap_modules(layer, config)
            self.modules_wrapper.append(wrapper)
        if not self.modules_wrapper:
            _logger.warning('Nothing is configured to compress, please check your model and config_list')
        self._wrap_model()

    def validate_config(self, model, config_list):
        """
        Subclasses can optionally implement this method to check whether config_list is valid.
        """
        pass
    def _detect_modules_to_compress(self):
        """
        Detect all modules that should be compressed, and save the result in `self.modules_to_compress`.
@@ -65,6 +76,8 @@ class Compressor:
        if self.modules_to_compress is None:
            self.modules_to_compress = []
            for name, module in self.bound_model.named_modules():
                if module == self.bound_model:
                    continue
                layer = LayerInfo(name, module)
                config = self.select_config(layer)
                if config is not None:
......
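With this hook in place, a subclass fails fast on a bad config_list simply by overriding validate_config. A minimal sketch of the pattern the pruners below follow (the class name is hypothetical; the relative import paths assume this diff's package layout):

# Hypothetical subclass, shown only to illustrate the validate_config hook.
import logging
from schema import And, Optional
from .utils import CompressorSchema
from .compressor import Pruner

logger = logging.getLogger(__name__)

class MyPruner(Pruner):
    def validate_config(self, model, config_list):
        # Called from Compressor.__init__ before the model is wrapped.
        schema = CompressorSchema([{
            'sparsity': And(float, lambda n: 0 < n < 1),
            Optional('op_types'): [str],
            Optional('op_names'): [str]
        }], model, logger)
        schema.validate(config_list)  # raises schema.SchemaError on bad input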
@@ -4,7 +4,9 @@
import copy
import logging
import torch
from schema import And, Optional
from .compressor import Pruner
from .utils import CompressorSchema

__all__ = ['LevelPruner', 'AGP_Pruner', 'SlimPruner', 'LotteryTicketPruner']
@@ -31,6 +33,23 @@ class LevelPruner(Pruner):
        super().__init__(model, config_list, optimizer)
        self.set_wrappers_attribute("if_calculated", False)

    def validate_config(self, model, config_list):
        """
        Parameters
        ----------
        model : torch.nn.Module
            Model to be pruned
        config_list : list
            List of pruning configs
        """
        schema = CompressorSchema([{
            'sparsity': And(float, lambda n: 0 < n < 1),
            Optional('op_types'): [str],
            Optional('op_names'): [str]
        }], model, logger)
        schema.validate(config_list)
    def calc_mask(self, wrapper, **kwargs):
        """
        Calculate the mask of the given layer
@@ -90,6 +109,27 @@ class AGP_Pruner(Pruner):
        self.now_epoch = 0
        self.set_wrappers_attribute("if_calculated", False)

    def validate_config(self, model, config_list):
        """
        Parameters
        ----------
        model : torch.nn.Module
            Model to be pruned
        config_list : list
            List of pruning configs
        """
        schema = CompressorSchema([{
            'initial_sparsity': And(float, lambda n: 0 <= n <= 1),
            'final_sparsity': And(float, lambda n: 0 <= n <= 1),
            'start_epoch': And(int, lambda n: n >= 0),
            'end_epoch': And(int, lambda n: n >= 0),
            'frequency': And(int, lambda n: n > 0),
            Optional('op_types'): [str],
            Optional('op_names'): [str]
        }], model, logger)
        schema.validate(config_list)
    def calc_mask(self, wrapper, **kwargs):
        """
        Calculate the mask of the given layer.
@@ -208,6 +248,24 @@ class SlimPruner(Pruner):
        self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()
        self.set_wrappers_attribute("if_calculated", False)
    def validate_config(self, model, config_list):
        """
        Parameters
        ----------
        model : torch.nn.Module
            Model to be pruned
        config_list : list
            Supported keys for each list item:
                - sparsity : percentage of convolutional filters to be pruned.
        """
        schema = CompressorSchema([{
            'sparsity': And(float, lambda n: 0 < n < 1),
            'op_types': ['BatchNorm2d'],
            Optional('op_names'): [str]
        }], model, logger)
        schema.validate(config_list)
    def calc_mask(self, wrapper, **kwargs):
        """
        Calculate the mask of the given layer.
@@ -273,7 +331,7 @@ class LotteryTicketPruner(Pruner):
        """
        super().__init__(model, config_list, optimizer)
        self.curr_prune_iteration = None
-        self.prune_iterations = self._validate_config(config_list)
+        self.prune_iterations = config_list[0]['prune_iterations']
        # save init weights and optimizer
        self.reset_weights = reset_weights
@@ -286,16 +344,26 @@ class LotteryTicketPruner(Pruner):
        if lr_scheduler is not None:
            self._scheduler_state = copy.deepcopy(lr_scheduler.state_dict())
-    def _validate_config(self, config_list):
-        prune_iterations = None
-        for config in config_list:
-            assert 'prune_iterations' in config, 'prune_iterations must exist in your config'
-            assert 'sparsity' in config, 'sparsity must exist in your config'
-            if prune_iterations is not None:
-                assert prune_iterations == config['prune_iterations'], \
-                    'The values of prune_iterations must be equal in your config'
-            prune_iterations = config['prune_iterations']
-        return prune_iterations
+    def validate_config(self, model, config_list):
+        """
+        Parameters
+        ----------
+        model : torch.nn.Module
+            Model to be pruned
+        config_list : list
+            Supported keys:
+                - prune_iterations : The number of rounds for the iterative pruning.
+                - sparsity : The final sparsity when the compression is done.
+        """
+        schema = CompressorSchema([{
+            'sparsity': And(float, lambda n: 0 < n < 1),
+            'prune_iterations': And(int, lambda n: n > 0),
+            Optional('op_types'): [str],
+            Optional('op_names'): [str]
+        }], model, logger)
+        schema.validate(config_list)
+
+        assert len(set([x['prune_iterations'] for x in config_list])) == 1, 'The values of prune_iterations must be equal in your config'
    def _calc_sparsity(self, sparsity):
        keep_ratio_once = (1 - sparsity) ** (1 / self.prune_iterations)
......
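For contrast with the schemas above, configs along these lines would validate (a sketch with illustrative values, not part of this commit):

# Illustrative configs; keys and bounds follow the schemas above.
agp_config = [{
    'initial_sparsity': 0.,   # floats required: a bare int such as 0 fails And(float, ...)
    'final_sparsity': 0.8,
    'start_epoch': 0,
    'end_epoch': 10,
    'frequency': 1,
    'op_types': ['default']   # 'default' is always accepted by validate_op_types
}]
lottery_config = [{
    'sparsity': 0.8,
    'prune_iterations': 5,    # must be the same value in every list item
    'op_types': ['default']
}]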
@@ -3,6 +3,8 @@
import logging
import torch
from schema import Schema, And, Or, Optional
from .utils import CompressorSchema
from .compressor import Quantizer, QuantGrad, QuantType

__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer']
@@ -17,6 +19,16 @@ class NaiveQuantizer(Quantizer):
        super().__init__(model, config_list, optimizer)
        self.layer_scale = {}

    def validate_config(self, model, config_list):
        schema = CompressorSchema([{
            Optional('quant_types'): ['weight'],
            Optional('quant_bits'): Or(8, {'weight': 8}),
            Optional('op_types'): [str],
            Optional('op_names'): [str]
        }], model, logger)
        schema.validate(config_list)
    def quantize_weight(self, weight, wrapper, **kwargs):
        new_scale = weight.abs().max() / 127
        scale = max(self.layer_scale.get(wrapper.name, 0), new_scale)
@@ -137,6 +149,28 @@ class QAT_Quantizer(Quantizer):
        layer.module.register_buffer('tracked_max_biased', torch.zeros(1))
        layer.module.register_buffer('tracked_max', torch.zeros(1))

    def validate_config(self, model, config_list):
        """
        Parameters
        ----------
        model : torch.nn.Module
            Model to be quantized
        config_list : list of dict
            List of configurations
        """
        schema = CompressorSchema([{
            Optional('quant_types'): Schema([lambda x: x in ['weight', 'output']]),
            Optional('quant_bits'): Or(And(int, lambda n: 0 < n < 32), Schema({
                Optional('weight'): And(int, lambda n: 0 < n < 32),
                Optional('output'): And(int, lambda n: 0 < n < 32),
            })),
            Optional('quant_start_step'): And(int, lambda n: n >= 0),
            Optional('op_types'): [str],
            Optional('op_names'): [str]
        }], model, logger)
        schema.validate(config_list)
    def _quantize(self, bits, op, real_val):
        """
        Quantize a real value.
@@ -233,6 +267,26 @@ class DoReFaQuantizer(Quantizer):
    def __init__(self, model, config_list, optimizer=None):
        super().__init__(model, config_list, optimizer)

    def validate_config(self, model, config_list):
        """
        Parameters
        ----------
        model : torch.nn.Module
            Model to be quantized
        config_list : list of dict
            List of configurations
        """
        schema = CompressorSchema([{
            Optional('quant_types'): Schema([lambda x: x in ['weight']]),
            Optional('quant_bits'): Or(And(int, lambda n: 0 < n < 32), Schema({
                Optional('weight'): And(int, lambda n: 0 < n < 32)
            })),
            Optional('op_types'): [str],
            Optional('op_names'): [str]
        }], model, logger)
        schema.validate(config_list)
    def quantize_weight(self, weight, wrapper, **kwargs):
        weight_bits = get_bits_length(wrapper.config, 'weight')
        out = weight.tanh()
@@ -264,6 +318,27 @@ class BNNQuantizer(Quantizer):
        super().__init__(model, config_list, optimizer)
        self.quant_grad = ClipGrad

    def validate_config(self, model, config_list):
        """
        Parameters
        ----------
        model : torch.nn.Module
            Model to be quantized
        config_list : list of dict
            List of configurations
        """
        schema = CompressorSchema([{
            Optional('quant_types'): Schema([lambda x: x in ['weight', 'output']]),
            Optional('quant_bits'): Or(And(int, lambda n: 0 < n < 32), Schema({
                Optional('weight'): And(int, lambda n: 0 < n < 32),
                Optional('output'): And(int, lambda n: 0 < n < 32),
            })),
            Optional('op_types'): [str],
            Optional('op_names'): [str]
        }], model, logger)
        schema.validate(config_list)
    def quantize_weight(self, weight, wrapper, **kwargs):
        out = torch.sign(weight)
        # remove zeros
......
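Putting the quantizer schemas together, a QAT_Quantizer config along these lines would pass validation (values are illustrative):

# Illustrative config satisfying the QAT_Quantizer schema above.
qat_config = [{
    'quant_types': ['weight', 'output'],
    'quant_bits': {'weight': 8, 'output': 8},  # alternatively a single int in (0, 32)
    'quant_start_step': 10,
    'op_types': ['default']
}]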
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from schema import Schema, And, SchemaError

def validate_op_names(model, op_names, logger):
    found_names = set(map(lambda x: x[0], model.named_modules()))
    not_found_op_names = list(set(op_names) - found_names)
    if not_found_op_names:
        logger.warning('op_names %s not found in model', not_found_op_names)
    return True

def validate_op_types(model, op_types, logger):
    found_types = set(['default']) | set(map(lambda x: type(x[1]).__name__, model.named_modules()))
    not_found_op_types = list(set(op_types) - found_types)
    if not_found_op_types:
        logger.warning('op_types %s not found in model', not_found_op_types)
    return True

def validate_op_types_op_names(data):
    if not ('op_types' in data or 'op_names' in data):
        raise SchemaError('Either op_types or op_names must be specified.')
    return True

class CompressorSchema:
    def __init__(self, data_schema, model, logger):
        assert isinstance(data_schema, list) and len(data_schema) <= 1
        self.data_schema = data_schema
        self.compressor_schema = Schema(self._modify_schema(data_schema, model, logger))

    def _modify_schema(self, data_schema, model, logger):
        # Wrap the op_types/op_names sub-schemas so that, besides the type check,
        # the given types/names are looked up in the model (unknown ones only warn).
        if not data_schema:
            return data_schema
        for k in data_schema[0]:
            old_schema = data_schema[0][k]
            if k == 'op_types' or (isinstance(k, Schema) and k._schema == 'op_types'):
                new_schema = And(old_schema, lambda n: validate_op_types(model, n, logger))
                data_schema[0][k] = new_schema
            if k == 'op_names' or (isinstance(k, Schema) and k._schema == 'op_names'):
                new_schema = And(old_schema, lambda n: validate_op_names(model, n, logger))
                data_schema[0][k] = new_schema
        data_schema[0] = And(data_schema[0], lambda d: validate_op_types_op_names(d))
        return data_schema

    def validate(self, data):
        self.compressor_schema.validate(data)
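CompressorSchema can also be exercised on its own; a minimal sketch, assuming a toy model and a standard logger (both illustrative):

# Hypothetical stand-alone usage of CompressorSchema.
import logging
import torch
from schema import And, Optional, SchemaError

logger = logging.getLogger(__name__)
model = torch.nn.Sequential(torch.nn.Conv2d(1, 8, 3))

compressor_schema = CompressorSchema([{
    'sparsity': And(float, lambda n: 0 < n < 1),
    Optional('op_types'): [str],
    Optional('op_names'): [str]
}], model, logger)

compressor_schema.validate([{'sparsity': 0.5, 'op_types': ['Conv2d']}])  # passes
try:
    compressor_schema.validate([{'sparsity': 0.5}])  # neither op_types nor op_names given
except SchemaError:
    pass  # 'Either op_types or op_names must be specified.'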
@@ -3,6 +3,8 @@
import logging
import torch
from schema import And, Optional
from .utils import CompressorSchema
from .compressor import Pruner

__all__ = ['L1FilterPruner', 'L2FilterPruner', 'FPGMPruner']
@@ -31,6 +33,24 @@ class WeightRankFilterPruner(Pruner):
        super().__init__(model, config_list, optimizer)
        self.set_wrappers_attribute("if_calculated", False)

    def validate_config(self, model, config_list):
        """
        Parameters
        ----------
        model : torch.nn.Module
            Model to be pruned
        config_list : list
            Supported keys for each list item:
                - sparsity : percentage of convolutional filters to be pruned.
        """
        schema = CompressorSchema([{
            'sparsity': And(float, lambda n: 0 < n < 1),
            Optional('op_types'): ['Conv2d'],
            Optional('op_names'): [str]
        }], model, logger)
        schema.validate(config_list)
    def get_mask(self, base_mask, weight, num_prune):
        raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))
......
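Note that op_types here is constrained to the literal list ['Conv2d']: a config may omit it (and select layers via op_names instead), but it may not name any other type. For example (illustrative):

# Illustrative configs for the weight-rank filter pruners.
ok_config = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]   # passes
bad_config = [{'sparsity': 0.5, 'op_types': ['Linear']}]  # fails: the schema only allows 'Conv2d'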
@@ -6,6 +6,7 @@ import numpy as np
import tensorflow as tf
import torch
import torch.nn.functional as F
import schema
import nni.compression.torch as torch_compressor
import math
@@ -267,6 +268,79 @@ class CompressorTestCase(TestCase):
        assert math.isclose(model.relu.module.tracked_min_biased, 0.002, abs_tol=eps)
        assert math.isclose(model.relu.module.tracked_max_biased, 0.00998, abs_tol=eps)
    def test_torch_pruner_validation(self):
        # test bad configuration
        pruner_classes = [torch_compressor.__dict__[x] for x in \
            ['LevelPruner', 'SlimPruner', 'FPGMPruner', 'L1FilterPruner', 'L2FilterPruner', 'AGP_Pruner', \
             'ActivationMeanRankFilterPruner', 'ActivationAPoZRankFilterPruner']]

        bad_configs = [
            [
                {'sparsity': '0.2'},
                {'sparsity': 0.6}
            ],
            [
                {'sparsity': 0.2},
                {'sparsity': 1.6}
            ],
            [
                {'sparsity': 0.2, 'op_types': 'default'},
                {'sparsity': 0.6}
            ],
            [
                {'sparsity': 0.2},
                {'sparsity': 0.6, 'op_names': 'abc'}
            ]
        ]
        model = TorchModel()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        for pruner_class in pruner_classes:
            for config_list in bad_configs:
                try:
                    pruner_class(model, config_list, optimizer)
                    print(config_list)
                    assert False, 'Validation error should be raised for bad configuration'
                except schema.SchemaError:
                    pass
                except:
                    print('FAILED:', pruner_class, config_list)
                    raise
    def test_torch_quantizer_validation(self):
        # test bad configuration
        quantizer_classes = [torch_compressor.__dict__[x] for x in \
            ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer']]

        bad_configs = [
            [
                {'bad_key': 'abc'}
            ],
            [
                {'quant_types': 'abc'}
            ],
            [
                {'quant_bits': 34}
            ],
            [
                {'op_types': 'default'}
            ],
            [
                {'quant_bits': {'abc': 123}}
            ]
        ]
        model = TorchModel()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        for quantizer_class in quantizer_classes:
            for config_list in bad_configs:
                try:
                    quantizer_class(model, config_list, optimizer)
                    print(config_list)
                    assert False, 'Validation error should be raised for bad configuration'
                except schema.SchemaError:
                    pass
                except:
                    print('FAILED:', quantizer_class, config_list)
                    raise
if __name__ == '__main__':
    main()
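For comparison with the bad configs above, a config like the following is accepted by LevelPruner; this is a sketch, not part of the test file, and the op name 'conv1' is illustrative:

# Illustrative counter-example: a schema-valid LevelPruner config.
good_config = [
    {'sparsity': 0.2, 'op_types': ['default']},
    {'sparsity': 0.6, 'op_names': ['conv1']}  # unknown names only trigger a warning, not an error
]
# torch_compressor.LevelPruner(TorchModel(), good_config) constructs without raising.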
@@ -34,7 +34,7 @@ prune_config = {
    'agp': {
        'pruner_class': AGP_Pruner,
        'config_list': [{
-            'initial_sparsity': 0,
+            'initial_sparsity': 0.,
            'final_sparsity': 0.8,
            'start_epoch': 0,
            'end_epoch': 10,
......
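This last hunk follows from the AGP schema above: And(float, ...) performs an isinstance check, and a bare Python int is not a float, so 0 now fails validation while 0. passes. A quick check (illustrative):

# Why 0 must become 0. under the new schema.
from schema import And, SchemaError

s = And(float, lambda n: 0 <= n <= 1)
s.validate(0.)     # ok: 0. is a float
try:
    s.validate(0)  # raises SchemaError: 0 is an int, not a float
except SchemaError:
    pass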