"...git@developer.sourcefind.cn:OpenDAS/tilelang.git" did not exist on "86aaf3c11385a88826e2c28ff8edbf711750301d"
Unverified Commit d1bc0cfc authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

Add model compression config validation (#2219)

parent e0b692c9
......@@ -3,6 +3,8 @@
import logging
import torch
from schema import And, Optional
from .utils import CompressorSchema
from .compressor import Pruner
__all__ = ['ActivationAPoZRankFilterPruner', 'ActivationMeanRankFilterPruner']
......@@ -50,6 +52,24 @@ class ActivationRankFilterPruner(Pruner):
else:
self.activation = None
def validate_config(self, model, config_list):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
"""
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
Optional('op_types'): [str],
Optional('op_names'): [str]
}], model, logger)
schema.validate(config_list)
def get_mask(self, base_mask, activations, num_prune):
raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))
......
......@@ -40,6 +40,9 @@ class Compressor:
optimizer: pytorch optimizer
optimizer used to train the model
"""
assert isinstance(model, torch.nn.Module)
self.validate_config(model, config_list)
self.bound_model = model
self.config_list = config_list
self.optimizer = optimizer
......@@ -54,9 +57,17 @@ class Compressor:
for layer, config in self._detect_modules_to_compress():
wrapper = self._wrap_modules(layer, config)
self.modules_wrapper.append(wrapper)
if not self.modules_wrapper:
_logger.warning('Nothing is configured to compress, please check your model and config_list')
self._wrap_model()
def validate_config(self, model, config_list):
"""
subclass can optionally implement this method to check if config_list if valid
"""
pass
def _detect_modules_to_compress(self):
"""
detect all modules should be compressed, and save the result in `self.modules_to_compress`.
......@@ -65,6 +76,8 @@ class Compressor:
if self.modules_to_compress is None:
self.modules_to_compress = []
for name, module in self.bound_model.named_modules():
if module == self.bound_model:
continue
layer = LayerInfo(name, module)
config = self.select_config(layer)
if config is not None:
......
......@@ -4,7 +4,9 @@
import copy
import logging
import torch
from schema import And, Optional
from .compressor import Pruner
from .utils import CompressorSchema
__all__ = ['LevelPruner', 'AGP_Pruner', 'SlimPruner', 'LotteryTicketPruner']
......@@ -31,6 +33,23 @@ class LevelPruner(Pruner):
super().__init__(model, config_list, optimizer)
self.set_wrappers_attribute("if_calculated", False)
def validate_config(self, model, config_list):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
List on pruning configs
"""
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
Optional('op_types'): [str],
Optional('op_names'): [str]
}], model, logger)
schema.validate(config_list)
def calc_mask(self, wrapper, **kwargs):
"""
Calculate the mask of given layer
......@@ -90,6 +109,27 @@ class AGP_Pruner(Pruner):
self.now_epoch = 0
self.set_wrappers_attribute("if_calculated", False)
def validate_config(self, model, config_list):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
List on pruning configs
"""
schema = CompressorSchema([{
'initial_sparsity': And(float, lambda n: 0 <= n <= 1),
'final_sparsity': And(float, lambda n: 0 <= n <= 1),
'start_epoch': And(int, lambda n: n >= 0),
'end_epoch': And(int, lambda n: n >= 0),
'frequency': And(int, lambda n: n > 0),
Optional('op_types'): [str],
Optional('op_names'): [str]
}], model, logger)
schema.validate(config_list)
def calc_mask(self, wrapper, **kwargs):
"""
Calculate the mask of given layer.
......@@ -208,6 +248,24 @@ class SlimPruner(Pruner):
self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()
self.set_wrappers_attribute("if_calculated", False)
def validate_config(self, model, config_list):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
"""
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
'op_types': ['BatchNorm2d'],
Optional('op_names'): [str]
}], model, logger)
schema.validate(config_list)
def calc_mask(self, wrapper, **kwargs):
"""
Calculate the mask of given layer.
......@@ -273,7 +331,7 @@ class LotteryTicketPruner(Pruner):
"""
super().__init__(model, config_list, optimizer)
self.curr_prune_iteration = None
self.prune_iterations = self._validate_config(config_list)
self.prune_iterations = config_list[0]['prune_iterations']
# save init weights and optimizer
self.reset_weights = reset_weights
......@@ -286,16 +344,26 @@ class LotteryTicketPruner(Pruner):
if lr_scheduler is not None:
self._scheduler_state = copy.deepcopy(lr_scheduler.state_dict())
def _validate_config(self, config_list):
prune_iterations = None
for config in config_list:
assert 'prune_iterations' in config, 'prune_iterations must exist in your config'
assert 'sparsity' in config, 'sparsity must exist in your config'
if prune_iterations is not None:
assert prune_iterations == config[
'prune_iterations'], 'The values of prune_iterations must be equal in your config'
prune_iterations = config['prune_iterations']
return prune_iterations
def validate_config(self, model, config_list):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
Supported keys:
- prune_iterations : The number of rounds for the iterative pruning.
- sparsity : The final sparsity when the compression is done.
"""
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
'prune_iterations': And(int, lambda n: n > 0),
Optional('op_types'): [str],
Optional('op_names'): [str]
}], model, logger)
schema.validate(config_list)
assert len(set([x['prune_iterations'] for x in config_list])) == 1, 'The values of prune_iterations must be equal in your config'
def _calc_sparsity(self, sparsity):
keep_ratio_once = (1 - sparsity) ** (1 / self.prune_iterations)
......
......@@ -3,6 +3,8 @@
import logging
import torch
from schema import Schema, And, Or, Optional
from .utils import CompressorSchema
from .compressor import Quantizer, QuantGrad, QuantType
__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer']
......@@ -17,6 +19,16 @@ class NaiveQuantizer(Quantizer):
super().__init__(model, config_list, optimizer)
self.layer_scale = {}
def validate_config(self, model, config_list):
schema = CompressorSchema([{
Optional('quant_types'): ['weight'],
Optional('quant_bits'): Or(8, {'weight': 8}),
Optional('op_types'): [str],
Optional('op_names'): [str]
}], model, logger)
schema.validate(config_list)
def quantize_weight(self, weight, wrapper, **kwargs):
new_scale = weight.abs().max() / 127
scale = max(self.layer_scale.get(wrapper.name, 0), new_scale)
......@@ -137,6 +149,28 @@ class QAT_Quantizer(Quantizer):
layer.module.register_buffer('tracked_max_biased', torch.zeros(1))
layer.module.register_buffer('tracked_max', torch.zeros(1))
def validate_config(self, model, config_list):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list of dict
List of configurations
"""
schema = CompressorSchema([{
Optional('quant_types'): Schema([lambda x: x in ['weight', 'output']]),
Optional('quant_bits'): Or(And(int, lambda n: 0 < n < 32), Schema({
Optional('weight'): And(int, lambda n: 0 < n < 32),
Optional('output'): And(int, lambda n: 0 < n < 32),
})),
Optional('quant_start_step'): And(int, lambda n: n >= 0),
Optional('op_types'): [str],
Optional('op_names'): [str]
}], model, logger)
schema.validate(config_list)
def _quantize(self, bits, op, real_val):
"""
quantize real value.
......@@ -233,6 +267,26 @@ class DoReFaQuantizer(Quantizer):
def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, optimizer)
def validate_config(self, model, config_list):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list of dict
List of configurations
"""
schema = CompressorSchema([{
Optional('quant_types'): Schema([lambda x: x in ['weight']]),
Optional('quant_bits'): Or(And(int, lambda n: 0 < n < 32), Schema({
Optional('weight'): And(int, lambda n: 0 < n < 32)
})),
Optional('op_types'): [str],
Optional('op_names'): [str]
}], model, logger)
schema.validate(config_list)
def quantize_weight(self, weight, wrapper, **kwargs):
weight_bits = get_bits_length(wrapper.config, 'weight')
out = weight.tanh()
......@@ -264,6 +318,27 @@ class BNNQuantizer(Quantizer):
super().__init__(model, config_list, optimizer)
self.quant_grad = ClipGrad
def validate_config(self, model, config_list):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list of dict
List of configurations
"""
schema = CompressorSchema([{
Optional('quant_types'): Schema([lambda x: x in ['weight', 'output']]),
Optional('quant_bits'): Or(And(int, lambda n: 0 < n < 32), Schema({
Optional('weight'): And(int, lambda n: 0 < n < 32),
Optional('output'): And(int, lambda n: 0 < n < 32),
})),
Optional('op_types'): [str],
Optional('op_names'): [str]
}], model, logger)
schema.validate(config_list)
def quantize_weight(self, weight, wrapper, **kwargs):
out = torch.sign(weight)
# remove zeros
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from schema import Schema, And, SchemaError
def validate_op_names(model, op_names, logger):
found_names = set(map(lambda x: x[0], model.named_modules()))
not_found_op_names = list(set(op_names) - found_names)
if not_found_op_names:
logger.warning('op_names %s not found in model', not_found_op_names)
return True
def validate_op_types(model, op_types, logger):
found_types = set(['default']) | set(map(lambda x: type(x[1]).__name__, model.named_modules()))
not_found_op_types = list(set(op_types) - found_types)
if not_found_op_types:
logger.warning('op_types %s not found in model', not_found_op_types)
return True
def validate_op_types_op_names(data):
if not ('op_types' in data or 'op_names' in data):
raise SchemaError('Either op_types or op_names must be specified.')
return True
class CompressorSchema:
def __init__(self, data_schema, model, logger):
assert isinstance(data_schema, list) and len(data_schema) <= 1
self.data_schema = data_schema
self.compressor_schema = Schema(self._modify_schema(data_schema, model, logger))
def _modify_schema(self, data_schema, model, logger):
if not data_schema:
return data_schema
for k in data_schema[0]:
old_schema = data_schema[0][k]
if k == 'op_types' or (isinstance(k, Schema) and k._schema == 'op_types'):
new_schema = And(old_schema, lambda n: validate_op_types(model, n, logger))
data_schema[0][k] = new_schema
if k == 'op_names' or (isinstance(k, Schema) and k._schema == 'op_names'):
new_schema = And(old_schema, lambda n: validate_op_names(model, n, logger))
data_schema[0][k] = new_schema
data_schema[0] = And(data_schema[0], lambda d: validate_op_types_op_names(d))
return data_schema
def validate(self, data):
self.compressor_schema.validate(data)
......@@ -3,6 +3,8 @@
import logging
import torch
from schema import And, Optional
from .utils import CompressorSchema
from .compressor import Pruner
__all__ = ['L1FilterPruner', 'L2FilterPruner', 'FPGMPruner']
......@@ -31,6 +33,24 @@ class WeightRankFilterPruner(Pruner):
super().__init__(model, config_list, optimizer)
self.set_wrappers_attribute("if_calculated", False)
def validate_config(self, model, config_list):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
"""
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
Optional('op_types'): ['Conv2d'],
Optional('op_names'): [str]
}], model, logger)
schema.validate(config_list)
def get_mask(self, base_mask, weight, num_prune):
raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))
......
......@@ -6,6 +6,7 @@ import numpy as np
import tensorflow as tf
import torch
import torch.nn.functional as F
import schema
import nni.compression.torch as torch_compressor
import math
......@@ -267,6 +268,79 @@ class CompressorTestCase(TestCase):
assert math.isclose(model.relu.module.tracked_min_biased, 0.002, abs_tol=eps)
assert math.isclose(model.relu.module.tracked_max_biased, 0.00998, abs_tol=eps)
def test_torch_pruner_validation(self):
# test bad configuraiton
pruner_classes = [torch_compressor.__dict__[x] for x in \
['LevelPruner', 'SlimPruner', 'FPGMPruner', 'L1FilterPruner', 'L2FilterPruner', 'AGP_Pruner', \
'ActivationMeanRankFilterPruner', 'ActivationAPoZRankFilterPruner']]
bad_configs = [
[
{'sparsity': '0.2'},
{'sparsity': 0.6 }
],
[
{'sparsity': 0.2},
{'sparsity': 1.6 }
],
[
{'sparsity': 0.2, 'op_types': 'default'},
{'sparsity': 0.6 }
],
[
{'sparsity': 0.2 },
{'sparsity': 0.6, 'op_names': 'abc' }
]
]
model = TorchModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for pruner_class in pruner_classes:
for config_list in bad_configs:
try:
pruner_class(model, config_list, optimizer)
print(config_list)
assert False, 'Validation error should be raised for bad configuration'
except schema.SchemaError:
pass
except:
print('FAILED:', pruner_class, config_list)
raise
def test_torch_quantizer_validation(self):
# test bad configuraiton
quantizer_classes = [torch_compressor.__dict__[x] for x in \
['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer']]
bad_configs = [
[
{'bad_key': 'abc'}
],
[
{'quant_types': 'abc'}
],
[
{'quant_bits': 34}
],
[
{'op_types': 'default'}
],
[
{'quant_bits': {'abc': 123}}
]
]
model = TorchModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for quantizer_class in quantizer_classes:
for config_list in bad_configs:
try:
quantizer_class(model, config_list, optimizer)
print(config_list)
assert False, 'Validation error should be raised for bad configuration'
except schema.SchemaError:
pass
except:
print('FAILED:', quantizer_class, config_list)
raise
if __name__ == '__main__':
main()
......@@ -34,7 +34,7 @@ prune_config = {
'agp': {
'pruner_class': AGP_Pruner,
'config_list': [{
'initial_sparsity': 0,
'initial_sparsity': 0.,
'final_sparsity': 0.8,
'start_epoch': 0,
'end_epoch': 10,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment