Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
3f64dbfd
"docs/en_US/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "969f0d99d333f07dc1f7086214762224c7d5cb6a"
Unverified
Commit
3f64dbfd
authored
Sep 27, 2021
by
Panacea
Committed by
GitHub
Sep 27, 2021
Browse files
Support 'op_partial_names' in config_list (#4184)
parent
b6894c1e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
23 additions
and
5 deletions
+23
-5
nni/algorithms/compression/v2/pytorch/pruning/basic_pruner.py
...algorithms/compression/v2/pytorch/pruning/basic_pruner.py
+6
-3
nni/algorithms/compression/v2/pytorch/utils/config_validation.py
...orithms/compression/v2/pytorch/utils/config_validation.py
+3
-2
nni/algorithms/compression/v2/pytorch/utils/pruning.py
nni/algorithms/compression/v2/pytorch/utils/pruning.py
+14
-0
No files found.
nni/algorithms/compression/v2/pytorch/pruning/basic_pruner.py
View file @
3f64dbfd
...
@@ -48,20 +48,23 @@ __all__ = ['LevelPruner', 'L1NormPruner', 'L2NormPruner', 'FPGMPruner', 'SlimPru
...
@@ -48,20 +48,23 @@ __all__ = ['LevelPruner', 'L1NormPruner', 'L2NormPruner', 'FPGMPruner', 'SlimPru
NORMAL_SCHEMA
=
{
NORMAL_SCHEMA
=
{
Or
(
'sparsity'
,
'sparsity_per_layer'
):
And
(
float
,
lambda
n
:
0
<=
n
<
1
),
Or
(
'sparsity'
,
'sparsity_per_layer'
):
And
(
float
,
lambda
n
:
0
<=
n
<
1
),
SchemaOptional
(
'op_types'
):
[
str
],
SchemaOptional
(
'op_types'
):
[
str
],
SchemaOptional
(
'op_names'
):
[
str
]
SchemaOptional
(
'op_names'
):
[
str
],
SchemaOptional
(
'op_partial_names'
):
[
str
]
}
}
GLOBAL_SCHEMA
=
{
GLOBAL_SCHEMA
=
{
'total_sparsity'
:
And
(
float
,
lambda
n
:
0
<=
n
<
1
),
'total_sparsity'
:
And
(
float
,
lambda
n
:
0
<=
n
<
1
),
SchemaOptional
(
'max_sparsity_per_layer'
):
And
(
float
,
lambda
n
:
0
<
n
<=
1
),
SchemaOptional
(
'max_sparsity_per_layer'
):
And
(
float
,
lambda
n
:
0
<
n
<=
1
),
SchemaOptional
(
'op_types'
):
[
str
],
SchemaOptional
(
'op_types'
):
[
str
],
SchemaOptional
(
'op_names'
):
[
str
]
SchemaOptional
(
'op_names'
):
[
str
],
SchemaOptional
(
'op_partial_names'
):
[
str
]
}
}
EXCLUDE_SCHEMA
=
{
EXCLUDE_SCHEMA
=
{
'exclude'
:
bool
,
'exclude'
:
bool
,
SchemaOptional
(
'op_types'
):
[
str
],
SchemaOptional
(
'op_types'
):
[
str
],
SchemaOptional
(
'op_names'
):
[
str
]
SchemaOptional
(
'op_names'
):
[
str
],
SchemaOptional
(
'op_partial_names'
):
[
str
]
}
}
INTERNAL_SCHEMA
=
{
INTERNAL_SCHEMA
=
{
...
...
nni/algorithms/compression/v2/pytorch/utils/config_validation.py
View file @
3f64dbfd
...
@@ -56,6 +56,7 @@ def validate_op_types(model, op_types, logger):
...
@@ -56,6 +56,7 @@ def validate_op_types(model, op_types, logger):
def
validate_op_types_op_names
(
data
):
def
validate_op_types_op_names
(
data
):
if
not
(
'op_types'
in
data
or
'op_names'
in
data
):
if
not
(
'op_types'
in
data
or
'op_names'
in
data
or
'op_partial_names'
in
data
):
raise
SchemaError
(
'Either op_types or op_names must be specified.'
)
raise
SchemaError
(
'At least one of the followings must be specified: op_types, op_names or op_partial_names.'
)
return
True
return
True
nni/algorithms/compression/v2/pytorch/utils/pruning.py
View file @
3f64dbfd
...
@@ -21,6 +21,20 @@ def config_list_canonical(model: Module, config_list: List[Dict]) -> List[Dict]:
...
@@ -21,6 +21,20 @@ def config_list_canonical(model: Module, config_list: List[Dict]) -> List[Dict]:
else
:
else
:
config
[
'sparsity_per_layer'
]
=
config
.
pop
(
'sparsity'
)
config
[
'sparsity_per_layer'
]
=
config
.
pop
(
'sparsity'
)
for
config
in
config_list
:
if
'op_partial_names'
in
config
:
op_names
=
[]
for
partial_name
in
config
[
'op_partial_names'
]:
for
name
,
_
in
model
.
named_modules
():
if
partial_name
in
name
:
op_names
.
append
(
name
)
if
'op_names'
in
config
:
config
[
'op_names'
].
extend
(
op_names
)
config
[
'op_names'
]
=
list
(
set
(
config
[
'op_names'
]))
else
:
config
[
'op_names'
]
=
op_names
config
.
pop
(
'op_partial_names'
)
config_list
=
dedupe_config_list
(
unfold_config_list
(
model
,
config_list
))
config_list
=
dedupe_config_list
(
unfold_config_list
(
model
,
config_list
))
new_config_list
=
[]
new_config_list
=
[]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment