Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
abb4dfdb
Unverified
Commit
abb4dfdb
authored
Oct 13, 2021
by
Panacea
Committed by
GitHub
Oct 13, 2021
Browse files
Fix v2 level pruner default config bug (#4245)
parent
c9cd53aa
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
11 deletions
+15
-11
nni/algorithms/compression/v2/pytorch/base/compressor.py
nni/algorithms/compression/v2/pytorch/base/compressor.py
+1
-9
nni/algorithms/compression/v2/pytorch/utils/pruning.py
nni/algorithms/compression/v2/pytorch/utils/pruning.py
+14
-2
No files found.
nni/algorithms/compression/v2/pytorch/base/compressor.py
View file @
abb4dfdb
...
@@ -9,7 +9,7 @@ import torch
...
@@ -9,7 +9,7 @@ import torch
from
torch.nn
import
Module
from
torch.nn
import
Module
from
nni.common.graph_utils
import
TorchModuleGraph
from
nni.common.graph_utils
import
TorchModuleGraph
from
nni.algorithms.compression.v2.pytorch.utils
import
get_module_by_name
from
nni.algorithms.compression.v2.pytorch.utils
.pruning
import
get_module_by_name
,
weighted_modules
_logger
=
logging
.
getLogger
(
__name__
)
_logger
=
logging
.
getLogger
(
__name__
)
...
@@ -32,14 +32,6 @@ def _setattr(model: Module, name: str, module: Module):
...
@@ -32,14 +32,6 @@ def _setattr(model: Module, name: str, module: Module):
raise
'{} not exist.'
.
format
(
name
)
raise
'{} not exist.'
.
format
(
name
)
weighted_modules
=
[
'Conv1d'
,
'Conv2d'
,
'Conv3d'
,
'ConvTranspose1d'
,
'ConvTranspose2d'
,
'ConvTranspose3d'
,
'Linear'
,
'Bilinear'
,
'PReLU'
,
'Embedding'
,
'EmbeddingBag'
,
]
class
Compressor
:
class
Compressor
:
"""
"""
The abstract base pytorch compressor.
The abstract base pytorch compressor.
...
...
nni/algorithms/compression/v2/pytorch/utils/pruning.py
View file @
abb4dfdb
...
@@ -8,6 +8,13 @@ import torch
...
@@ -8,6 +8,13 @@ import torch
from
torch
import
Tensor
from
torch
import
Tensor
from
torch.nn
import
Module
from
torch.nn
import
Module
weighted_modules
=
[
'Conv1d'
,
'Conv2d'
,
'Conv3d'
,
'ConvTranspose1d'
,
'ConvTranspose2d'
,
'ConvTranspose3d'
,
'Linear'
,
'Bilinear'
,
'PReLU'
,
'Embedding'
,
'EmbeddingBag'
,
]
def
config_list_canonical
(
model
:
Module
,
config_list
:
List
[
Dict
])
->
List
[
Dict
]:
def
config_list_canonical
(
model
:
Module
,
config_list
:
List
[
Dict
])
->
List
[
Dict
]:
'''
'''
...
@@ -37,6 +44,12 @@ def config_list_canonical(model: Module, config_list: List[Dict]) -> List[Dict]:
...
@@ -37,6 +44,12 @@ def config_list_canonical(model: Module, config_list: List[Dict]) -> List[Dict]:
else
:
else
:
config
[
'sparsity_per_layer'
]
=
config
.
pop
(
'sparsity'
)
config
[
'sparsity_per_layer'
]
=
config
.
pop
(
'sparsity'
)
for
config
in
config_list
:
if
'op_types'
in
config
:
if
'default'
in
config
[
'op_types'
]:
config
[
'op_types'
].
remove
(
'default'
)
config
[
'op_types'
].
extend
(
weighted_modules
)
for
config
in
config_list
:
for
config
in
config_list
:
if
'op_partial_names'
in
config
:
if
'op_partial_names'
in
config
:
op_names
=
[]
op_names
=
[]
...
@@ -225,18 +238,17 @@ def get_model_weights_numel(model: Module, config_list: List[Dict], masks: Dict[
...
@@ -225,18 +238,17 @@ def get_model_weights_numel(model: Module, config_list: List[Dict], masks: Dict[
model_weights_numel
[
module_name
]
=
module
.
weight
.
data
.
numel
()
model_weights_numel
[
module_name
]
=
module
.
weight
.
data
.
numel
()
return
model_weights_numel
,
masked_rate
return
model_weights_numel
,
masked_rate
# FIXME: to avoid circular import, copy this function in this place
# FIXME: to avoid circular import, copy this function in this place
def
get_module_by_name
(
model
,
module_name
):
def
get_module_by_name
(
model
,
module_name
):
"""
"""
Get a module specified by its module name
Get a module specified by its module name
Parameters
Parameters
----------
----------
model : pytorch model
model : pytorch model
the pytorch model from which to get its module
the pytorch model from which to get its module
module_name : str
module_name : str
the name of the required module
the name of the required module
Returns
Returns
-------
-------
module, module
module, module
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment