Unverified Commit 3f64dbfd authored by Panacea's avatar Panacea Committed by GitHub
Browse files

Support 'op_partial_names' in config_list (#4184)

parent b6894c1e
......@@ -48,20 +48,23 @@ __all__ = ['LevelPruner', 'L1NormPruner', 'L2NormPruner', 'FPGMPruner', 'SlimPru
NORMAL_SCHEMA = {
Or('sparsity', 'sparsity_per_layer'): And(float, lambda n: 0 <= n < 1),
SchemaOptional('op_types'): [str],
SchemaOptional('op_names'): [str]
SchemaOptional('op_names'): [str],
SchemaOptional('op_partial_names'): [str]
}
GLOBAL_SCHEMA = {
'total_sparsity': And(float, lambda n: 0 <= n < 1),
SchemaOptional('max_sparsity_per_layer'): And(float, lambda n: 0 < n <= 1),
SchemaOptional('op_types'): [str],
SchemaOptional('op_names'): [str]
SchemaOptional('op_names'): [str],
SchemaOptional('op_partial_names'): [str]
}
EXCLUDE_SCHEMA = {
'exclude': bool,
SchemaOptional('op_types'): [str],
SchemaOptional('op_names'): [str]
SchemaOptional('op_names'): [str],
SchemaOptional('op_partial_names'): [str]
}
INTERNAL_SCHEMA = {
......
......@@ -56,6 +56,7 @@ def validate_op_types(model, op_types, logger):
def validate_op_types_op_names(data):
if not ('op_types' in data or 'op_names' in data):
raise SchemaError('Either op_types or op_names must be specified.')
if not ('op_types' in data or 'op_names' in data or 'op_partial_names' in data):
raise SchemaError('At least one of the followings must be specified: op_types, op_names or op_partial_names.')
return True
......@@ -21,6 +21,20 @@ def config_list_canonical(model: Module, config_list: List[Dict]) -> List[Dict]:
else:
config['sparsity_per_layer'] = config.pop('sparsity')
for config in config_list:
if 'op_partial_names' in config:
op_names = []
for partial_name in config['op_partial_names']:
for name, _ in model.named_modules():
if partial_name in name:
op_names.append(name)
if 'op_names' in config:
config['op_names'].extend(op_names)
config['op_names'] = list(set(config['op_names']))
else:
config['op_names'] = op_names
config.pop('op_partial_names')
config_list = dedupe_config_list(unfold_config_list(model, config_list))
new_config_list = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment