Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
3f64dbfd
Unverified
Commit
3f64dbfd
authored
Sep 27, 2021
by
Panacea
Committed by
GitHub
Sep 27, 2021
Browse files
Support 'op_partial_names' in config_list (#4184)
parent
b6894c1e
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
23 additions
and
5 deletions
+23
-5
nni/algorithms/compression/v2/pytorch/pruning/basic_pruner.py
...algorithms/compression/v2/pytorch/pruning/basic_pruner.py
+6
-3
nni/algorithms/compression/v2/pytorch/utils/config_validation.py
...orithms/compression/v2/pytorch/utils/config_validation.py
+3
-2
nni/algorithms/compression/v2/pytorch/utils/pruning.py
nni/algorithms/compression/v2/pytorch/utils/pruning.py
+14
-0
No files found.
nni/algorithms/compression/v2/pytorch/pruning/basic_pruner.py
View file @
3f64dbfd
...
...
@@ -48,20 +48,23 @@ __all__ = ['LevelPruner', 'L1NormPruner', 'L2NormPruner', 'FPGMPruner', 'SlimPru
NORMAL_SCHEMA
=
{
Or
(
'sparsity'
,
'sparsity_per_layer'
):
And
(
float
,
lambda
n
:
0
<=
n
<
1
),
SchemaOptional
(
'op_types'
):
[
str
],
SchemaOptional
(
'op_names'
):
[
str
]
SchemaOptional
(
'op_names'
):
[
str
],
SchemaOptional
(
'op_partial_names'
):
[
str
]
}
GLOBAL_SCHEMA
=
{
'total_sparsity'
:
And
(
float
,
lambda
n
:
0
<=
n
<
1
),
SchemaOptional
(
'max_sparsity_per_layer'
):
And
(
float
,
lambda
n
:
0
<
n
<=
1
),
SchemaOptional
(
'op_types'
):
[
str
],
SchemaOptional
(
'op_names'
):
[
str
]
SchemaOptional
(
'op_names'
):
[
str
],
SchemaOptional
(
'op_partial_names'
):
[
str
]
}
EXCLUDE_SCHEMA
=
{
'exclude'
:
bool
,
SchemaOptional
(
'op_types'
):
[
str
],
SchemaOptional
(
'op_names'
):
[
str
]
SchemaOptional
(
'op_names'
):
[
str
],
SchemaOptional
(
'op_partial_names'
):
[
str
]
}
INTERNAL_SCHEMA
=
{
...
...
nni/algorithms/compression/v2/pytorch/utils/config_validation.py
View file @
3f64dbfd
...
...
@@ -56,6 +56,7 @@ def validate_op_types(model, op_types, logger):
def
validate_op_types_op_names
(
data
):
if
not
(
'op_types'
in
data
or
'op_names'
in
data
):
raise
SchemaError
(
'Either op_types or op_names must be specified.'
)
if
not
(
'op_types'
in
data
or
'op_names'
in
data
or
'op_partial_names'
in
data
):
raise
SchemaError
(
'At least one of the followings must be specified: op_types, op_names or op_partial_names.'
)
return
True
nni/algorithms/compression/v2/pytorch/utils/pruning.py
View file @
3f64dbfd
...
...
@@ -21,6 +21,20 @@ def config_list_canonical(model: Module, config_list: List[Dict]) -> List[Dict]:
else
:
config
[
'sparsity_per_layer'
]
=
config
.
pop
(
'sparsity'
)
for
config
in
config_list
:
if
'op_partial_names'
in
config
:
op_names
=
[]
for
partial_name
in
config
[
'op_partial_names'
]:
for
name
,
_
in
model
.
named_modules
():
if
partial_name
in
name
:
op_names
.
append
(
name
)
if
'op_names'
in
config
:
config
[
'op_names'
].
extend
(
op_names
)
config
[
'op_names'
]
=
list
(
set
(
config
[
'op_names'
]))
else
:
config
[
'op_names'
]
=
op_names
config
.
pop
(
'op_partial_names'
)
config_list
=
dedupe_config_list
(
unfold_config_list
(
model
,
config_list
))
new_config_list
=
[]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment