Unverified Commit 68644f59 authored by J-shang's avatar J-shang Committed by GitHub
Browse files

add exclude config validate in compressor (#3815)

parent fb3c596b
......@@ -243,6 +243,7 @@ def main(args):
# Reproduced result in paper 'PRUNING FILTERS FOR EFFICIENT CONVNETS',
# Conv_1, Conv_8, Conv_9, Conv_10, Conv_11, Conv_12 are pruned with 50% sparsity, as 'VGG-16-pruned-A'
# If you want to skip some layers, you can use 'exclude' as follows.
if args.pruner == 'slim':
config_list = [{
'sparsity': args.sparsity,
......@@ -252,7 +253,10 @@ def main(args):
config_list = [{
'sparsity': args.sparsity,
'op_types': ['Conv2d'],
'op_names': ['feature.0', 'feature.24', 'feature.27', 'feature.30', 'feature.34', 'feature.37']
'op_names': ['feature.0', 'feature.10', 'feature.24', 'feature.27', 'feature.30', 'feature.34', 'feature.37']
}, {
'exclude': True,
'op_names': ['feature.10']
}]
pruner = pruner_cls(model, config_list, **kw_args)
......
......@@ -11,7 +11,7 @@ from nni.utils import OptimizeMode
from nni.compression.pytorch import ModelSpeedup
from nni.compression.pytorch.compressor import Pruner
from nni.compression.pytorch.utils.config_validation import CompressorSchema
from nni.compression.pytorch.utils.config_validation import PrunerSchema
from .simulated_annealing_pruner import SimulatedAnnealingPruner
from .iterative_pruner import ADMMPruner
......@@ -130,16 +130,18 @@ class AutoCompressPruner(Pruner):
"""
if self._base_algo == 'level':
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
schema = PrunerSchema([{
Optional('sparsity'): And(float, lambda n: 0 < n < 1),
Optional('op_types'): [str],
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, _logger)
elif self._base_algo in ['l1', 'l2', 'fpgm']:
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
schema = PrunerSchema([{
Optional('sparsity'): And(float, lambda n: 0 < n < 1),
'op_types': ['Conv2d'],
Optional('op_names'): [str]
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, _logger)
schema.validate(config_list)
......
......@@ -2,10 +2,10 @@
# Licensed under the MIT license.
import logging
from schema import And, Optional, SchemaError
from schema import And, Optional
from nni.common.graph_utils import TorchModuleGraph
from nni.compression.pytorch.utils.shape_dependency import ChannelDependency, GroupDependency
from nni.compression.pytorch.utils.config_validation import CompressorSchema
from nni.compression.pytorch.utils.config_validation import PrunerSchema
from nni.compression.pytorch.compressor import Pruner
from .constants import MASKER_DICT
......@@ -82,7 +82,7 @@ class DependencyAwarePruner(Pruner):
self._dependency_update_mask()
def validate_config(self, model, config_list):
schema = CompressorSchema([{
schema = PrunerSchema([{
Optional('sparsity'): And(float, lambda n: 0 < n < 1),
Optional('op_types'): ['Conv2d'],
Optional('op_names'): [str],
......@@ -90,9 +90,6 @@ class DependencyAwarePruner(Pruner):
}], model, logger)
schema.validate(config_list)
for config in config_list:
if 'exclude' not in config and 'sparsity' not in config:
raise SchemaError('Either sparisty or exclude must be specified!')
def _supported_dependency_aware(self):
raise NotImplementedError
......
......@@ -5,7 +5,7 @@ import logging
import copy
import torch
from schema import And, Optional
from nni.compression.pytorch.utils.config_validation import CompressorSchema
from nni.compression.pytorch.utils.config_validation import PrunerSchema
from .constants import MASKER_DICT
from .dependency_aware_pruner import DependencyAwarePruner
......@@ -138,10 +138,11 @@ class AGPPruner(IterativePruner):
config_list : list
List on pruning configs
"""
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 <= n <= 1),
schema = PrunerSchema([{
Optional('sparsity'): And(float, lambda n: 0 <= n <= 1),
Optional('op_types'): [str],
Optional('op_names'): [str]
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, logger)
schema.validate(config_list)
......@@ -300,16 +301,18 @@ class ADMMPruner(IterativePruner):
"""
if self._base_algo == 'level':
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
schema = PrunerSchema([{
Optional('sparsity'): And(float, lambda n: 0 < n < 1),
Optional('op_types'): [str],
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, logger)
elif self._base_algo in ['l1', 'l2', 'fpgm']:
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
schema = PrunerSchema([{
Optional('sparsity'): And(float, lambda n: 0 < n < 1),
'op_types': ['Conv2d'],
Optional('op_names'): [str]
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, logger)
schema.validate(config_list)
......@@ -436,10 +439,11 @@ class SlimPruner(IterativePruner):
self.patch_optimizer_before(self._callback)
def validate_config(self, model, config_list):
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
schema = PrunerSchema([{
Optional('sparsity'): And(float, lambda n: 0 < n < 1),
'op_types': ['BatchNorm2d'],
Optional('op_names'): [str]
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, logger)
schema.validate(config_list)
......
......@@ -5,7 +5,7 @@ import copy
import logging
import torch
from schema import And, Optional
from nni.compression.pytorch.utils.config_validation import CompressorSchema
from nni.compression.pytorch.utils.config_validation import PrunerSchema
from nni.compression.pytorch.compressor import Pruner
from .finegrained_pruning_masker import LevelPrunerMasker
......@@ -56,11 +56,12 @@ class LotteryTicketPruner(Pruner):
- prune_iterations : The number of rounds for the iterative pruning.
- sparsity : The final sparsity when the compression is done.
"""
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
schema = PrunerSchema([{
Optional('sparsity'): And(float, lambda n: 0 < n < 1),
'prune_iterations': And(int, lambda n: n > 0),
Optional('op_types'): [str],
Optional('op_names'): [str]
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, logger)
schema.validate(config_list)
......
......@@ -11,7 +11,7 @@ from schema import And, Optional
from nni.utils import OptimizeMode
from nni.compression.pytorch.compressor import Pruner
from nni.compression.pytorch.utils.config_validation import CompressorSchema
from nni.compression.pytorch.utils.config_validation import PrunerSchema
from nni.compression.pytorch.utils.num_param_counter import get_total_num_weights
from .constants_pruner import PRUNER_DICT
......@@ -120,16 +120,18 @@ class NetAdaptPruner(Pruner):
"""
if self._base_algo == 'level':
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
schema = PrunerSchema([{
Optional('sparsity'): And(float, lambda n: 0 < n < 1),
Optional('op_types'): [str],
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, _logger)
elif self._base_algo in ['l1', 'l2', 'fpgm']:
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
schema = PrunerSchema([{
Optional('sparsity'): And(float, lambda n: 0 < n < 1),
'op_types': ['Conv2d'],
Optional('op_names'): [str]
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, _logger)
schema.validate(config_list)
......
......@@ -4,7 +4,7 @@
import logging
from schema import And, Optional
from nni.compression.pytorch.utils.config_validation import CompressorSchema
from nni.compression.pytorch.utils.config_validation import PrunerSchema
from .dependency_aware_pruner import DependencyAwarePruner
__all__ = ['LevelPruner', 'L1FilterPruner', 'L2FilterPruner', 'FPGMPruner']
......@@ -48,10 +48,11 @@ class OneshotPruner(DependencyAwarePruner):
config_list : list
List on pruning configs
"""
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
schema = PrunerSchema([{
Optional('sparsity'): And(float, lambda n: 0 < n < 1),
Optional('op_types'): [str],
Optional('op_names'): [str]
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, logger)
schema.validate(config_list)
......
......@@ -9,7 +9,7 @@ import torch
from schema import And, Optional
from nni.compression.pytorch.compressor import Pruner
from nni.compression.pytorch.utils.config_validation import CompressorSchema
from nni.compression.pytorch.utils.config_validation import PrunerSchema
from nni.compression.pytorch.utils.sensitivity_analysis import SensitivityAnalysis
from .constants_pruner import PRUNER_DICT
......@@ -146,16 +146,18 @@ class SensitivityPruner(Pruner):
"""
if self.base_algo == 'level':
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
schema = PrunerSchema([{
Optional('sparsity'): And(float, lambda n: 0 < n < 1),
Optional('op_types'): [str],
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, _logger)
elif self.base_algo in ['l1', 'l2', 'fpgm']:
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
schema = PrunerSchema([{
Optional('sparsity'): And(float, lambda n: 0 < n < 1),
'op_types': ['Conv2d'],
Optional('op_names'): [str]
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, _logger)
schema.validate(config_list)
......
......@@ -13,7 +13,7 @@ from schema import And, Optional
from nni.utils import OptimizeMode
from nni.compression.pytorch.compressor import Pruner
from nni.compression.pytorch.utils.config_validation import CompressorSchema
from nni.compression.pytorch.utils.config_validation import PrunerSchema
from .constants_pruner import PRUNER_DICT
......@@ -115,16 +115,18 @@ class SimulatedAnnealingPruner(Pruner):
"""
if self._base_algo == 'level':
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
schema = PrunerSchema([{
Optional('sparsity'): And(float, lambda n: 0 < n < 1),
Optional('op_types'): [str],
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, _logger)
elif self._base_algo in ['l1', 'l2', 'fpgm']:
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
schema = PrunerSchema([{
Optional('sparsity'): And(float, lambda n: 0 < n < 1),
'op_types': ['Conv2d'],
Optional('op_names'): [str]
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, _logger)
schema.validate(config_list)
......
......@@ -5,7 +5,7 @@ import logging
import copy
import torch
from schema import Schema, And, Or, Optional
from nni.compression.pytorch.utils.config_validation import CompressorSchema
from nni.compression.pytorch.utils.config_validation import QuantizerSchema
from nni.compression.pytorch.compressor import Quantizer, QuantForward, QuantGrad, QuantType
__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer', 'LsqQuantizer']
......@@ -22,11 +22,12 @@ class NaiveQuantizer(Quantizer):
self.layer_scale = {}
def validate_config(self, model, config_list):
schema = CompressorSchema([{
schema = QuantizerSchema([{
Optional('quant_types'): ['weight'],
Optional('quant_bits'): Or(8, {'weight': 8}),
Optional('op_types'): [str],
Optional('op_names'): [str]
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, logger)
schema.validate(config_list)
......@@ -183,7 +184,7 @@ class QAT_Quantizer(Quantizer):
config_list : list of dict
List of configurations
"""
schema = CompressorSchema([{
schema = QuantizerSchema([{
Optional('quant_types'): Schema([lambda x: x in ['weight', 'output']]),
Optional('quant_bits'): Or(And(int, lambda n: 0 < n < 32), Schema({
Optional('weight'): And(int, lambda n: 0 < n < 32),
......@@ -191,7 +192,8 @@ class QAT_Quantizer(Quantizer):
})),
Optional('quant_start_step'): And(int, lambda n: n >= 0),
Optional('op_types'): [str],
Optional('op_names'): [str]
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, logger)
schema.validate(config_list)
......@@ -386,13 +388,14 @@ class DoReFaQuantizer(Quantizer):
config_list : list of dict
List of configurations
"""
schema = CompressorSchema([{
schema = QuantizerSchema([{
Optional('quant_types'): Schema([lambda x: x in ['weight']]),
Optional('quant_bits'): Or(And(int, lambda n: 0 < n < 32), Schema({
Optional('weight'): And(int, lambda n: 0 < n < 32)
})),
Optional('op_types'): [str],
Optional('op_names'): [str]
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, logger)
schema.validate(config_list)
......@@ -493,14 +496,15 @@ class BNNQuantizer(Quantizer):
config_list : list of dict
List of configurations
"""
schema = CompressorSchema([{
schema = QuantizerSchema([{
Optional('quant_types'): Schema([lambda x: x in ['weight', 'output']]),
Optional('quant_bits'): Or(And(int, lambda n: 0 < n < 32), Schema({
Optional('weight'): And(int, lambda n: 0 < n < 32),
Optional('output'): And(int, lambda n: 0 < n < 32),
})),
Optional('op_types'): [str],
Optional('op_names'): [str]
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, logger)
schema.validate(config_list)
......
......@@ -51,3 +51,25 @@ class CompressorSchema:
def validate(self, data):
self.compressor_schema.validate(data)
def validate_exclude_sparsity(data):
    """Cross-field check for one entry of a pruner's ``config_list``.

    A pruner config entry must either mark layers as excluded
    (``'exclude': True``) or specify a ``'sparsity'`` to apply.

    Parameters
    ----------
    data : dict
        A single config dict from a pruner's config_list.

    Returns
    -------
    bool
        True when the entry is valid (schema's And() requires a truthy result).

    Raises
    ------
    SchemaError
        If the entry carries neither 'exclude' nor 'sparsity'.
    """
    if not ('exclude' in data or 'sparsity' in data):
        # Fixed typo in the user-facing message: 'sparisty' -> 'sparsity'.
        raise SchemaError('Either sparsity or exclude must be specified.')
    return True
def validate_exclude_quant_types_quant_bits(data):
    """Cross-field check for one entry of a quantizer's ``config_list``.

    An entry is valid when it either excludes layers (``'exclude'`` key
    present) or fully specifies quantization, i.e. carries both
    ``'quant_types'`` and ``'quant_bits'``.

    Returns True when valid; raises SchemaError otherwise.
    """
    fully_specified = 'quant_types' in data and 'quant_bits' in data
    if 'exclude' in data or fully_specified:
        return True
    raise SchemaError('Either (quant_types and quant_bits) or exclude must be specified.')
class PrunerSchema(CompressorSchema):
    """Config-list schema for pruners.

    Extends the base compressor schema so that every config entry must
    additionally pass the sparsity/exclude cross-field check (either
    'sparsity' or 'exclude' must be present).
    """

    def _modify_schema(self, data_schema, model, logger):
        # Build the base per-entry schema, then chain it with the
        # exclude/sparsity validator via schema's And combinator.
        base_schema = super()._modify_schema(data_schema, model, logger)
        base_schema[0] = And(base_schema[0], validate_exclude_sparsity)
        return base_schema
class QuantizerSchema(CompressorSchema):
    """Config-list schema for quantizers.

    Extends the base compressor schema so that every config entry must
    additionally either exclude layers or specify both 'quant_types'
    and 'quant_bits'.
    """

    def _modify_schema(self, data_schema, model, logger):
        # Build the base per-entry schema, then chain it with the
        # quant_types/quant_bits/exclude validator via schema's And combinator.
        base_schema = super()._modify_schema(data_schema, model, logger)
        base_schema[0] = And(base_schema[0], validate_exclude_quant_types_quant_bits)
        return base_schema
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment