"docs/en_US/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "969f0d99d333f07dc1f7086214762224c7d5cb6a"
Unverified Commit 3f64dbfd authored by Panacea's avatar Panacea Committed by GitHub
Browse files

Support 'op_partial_names' in config_list (#4184)

parent b6894c1e
...@@ -48,20 +48,23 @@ __all__ = ['LevelPruner', 'L1NormPruner', 'L2NormPruner', 'FPGMPruner', 'SlimPru ...@@ -48,20 +48,23 @@ __all__ = ['LevelPruner', 'L1NormPruner', 'L2NormPruner', 'FPGMPruner', 'SlimPru
NORMAL_SCHEMA = { NORMAL_SCHEMA = {
Or('sparsity', 'sparsity_per_layer'): And(float, lambda n: 0 <= n < 1), Or('sparsity', 'sparsity_per_layer'): And(float, lambda n: 0 <= n < 1),
SchemaOptional('op_types'): [str], SchemaOptional('op_types'): [str],
SchemaOptional('op_names'): [str] SchemaOptional('op_names'): [str],
SchemaOptional('op_partial_names'): [str]
} }
GLOBAL_SCHEMA = { GLOBAL_SCHEMA = {
'total_sparsity': And(float, lambda n: 0 <= n < 1), 'total_sparsity': And(float, lambda n: 0 <= n < 1),
SchemaOptional('max_sparsity_per_layer'): And(float, lambda n: 0 < n <= 1), SchemaOptional('max_sparsity_per_layer'): And(float, lambda n: 0 < n <= 1),
SchemaOptional('op_types'): [str], SchemaOptional('op_types'): [str],
SchemaOptional('op_names'): [str] SchemaOptional('op_names'): [str],
SchemaOptional('op_partial_names'): [str]
} }
EXCLUDE_SCHEMA = { EXCLUDE_SCHEMA = {
'exclude': bool, 'exclude': bool,
SchemaOptional('op_types'): [str], SchemaOptional('op_types'): [str],
SchemaOptional('op_names'): [str] SchemaOptional('op_names'): [str],
SchemaOptional('op_partial_names'): [str]
} }
INTERNAL_SCHEMA = { INTERNAL_SCHEMA = {
......
...@@ -56,6 +56,7 @@ def validate_op_types(model, op_types, logger): ...@@ -56,6 +56,7 @@ def validate_op_types(model, op_types, logger):
def validate_op_types_op_names(data): def validate_op_types_op_names(data):
if not ('op_types' in data or 'op_names' in data): if not ('op_types' in data or 'op_names' in data or 'op_partial_names' in data):
raise SchemaError('Either op_types or op_names must be specified.') raise SchemaError('At least one of the followings must be specified: op_types, op_names or op_partial_names.')
return True return True
...@@ -21,6 +21,20 @@ def config_list_canonical(model: Module, config_list: List[Dict]) -> List[Dict]: ...@@ -21,6 +21,20 @@ def config_list_canonical(model: Module, config_list: List[Dict]) -> List[Dict]:
else: else:
config['sparsity_per_layer'] = config.pop('sparsity') config['sparsity_per_layer'] = config.pop('sparsity')
for config in config_list:
if 'op_partial_names' in config:
op_names = []
for partial_name in config['op_partial_names']:
for name, _ in model.named_modules():
if partial_name in name:
op_names.append(name)
if 'op_names' in config:
config['op_names'].extend(op_names)
config['op_names'] = list(set(config['op_names']))
else:
config['op_names'] = op_names
config.pop('op_partial_names')
config_list = dedupe_config_list(unfold_config_list(model, config_list)) config_list = dedupe_config_list(unfold_config_list(model, config_list))
new_config_list = [] new_config_list = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment