Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
abb4dfdb
"driver/driver.hip.cpp" did not exist on "1de6fd07535833877019634a95eafd329406be4c"
Unverified
Commit
abb4dfdb
authored
Oct 13, 2021
by
Panacea
Committed by
GitHub
Oct 13, 2021
Browse files
Fix v2 level pruner default config bug (#4245)
parent
c9cd53aa
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
11 deletions
+15
-11
nni/algorithms/compression/v2/pytorch/base/compressor.py
nni/algorithms/compression/v2/pytorch/base/compressor.py
+1
-9
nni/algorithms/compression/v2/pytorch/utils/pruning.py
nni/algorithms/compression/v2/pytorch/utils/pruning.py
+14
-2
No files found.
nni/algorithms/compression/v2/pytorch/base/compressor.py
View file @
abb4dfdb
...
...
@@ -9,7 +9,7 @@ import torch
from
torch.nn
import
Module
from
nni.common.graph_utils
import
TorchModuleGraph
from
nni.algorithms.compression.v2.pytorch.utils
import
get_module_by_name
from
nni.algorithms.compression.v2.pytorch.utils
.pruning
import
get_module_by_name
,
weighted_modules
_logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -32,14 +32,6 @@ def _setattr(model: Module, name: str, module: Module):
raise
'{} not exist.'
.
format
(
name
)
weighted_modules
=
[
'Conv1d'
,
'Conv2d'
,
'Conv3d'
,
'ConvTranspose1d'
,
'ConvTranspose2d'
,
'ConvTranspose3d'
,
'Linear'
,
'Bilinear'
,
'PReLU'
,
'Embedding'
,
'EmbeddingBag'
,
]
class
Compressor
:
"""
The abstract base pytorch compressor.
...
...
nni/algorithms/compression/v2/pytorch/utils/pruning.py
View file @
abb4dfdb
...
...
@@ -8,6 +8,13 @@ import torch
from
torch
import
Tensor
from
torch.nn
import
Module
weighted_modules
=
[
'Conv1d'
,
'Conv2d'
,
'Conv3d'
,
'ConvTranspose1d'
,
'ConvTranspose2d'
,
'ConvTranspose3d'
,
'Linear'
,
'Bilinear'
,
'PReLU'
,
'Embedding'
,
'EmbeddingBag'
,
]
def
config_list_canonical
(
model
:
Module
,
config_list
:
List
[
Dict
])
->
List
[
Dict
]:
'''
...
...
@@ -37,6 +44,12 @@ def config_list_canonical(model: Module, config_list: List[Dict]) -> List[Dict]:
else
:
config
[
'sparsity_per_layer'
]
=
config
.
pop
(
'sparsity'
)
for
config
in
config_list
:
if
'op_types'
in
config
:
if
'default'
in
config
[
'op_types'
]:
config
[
'op_types'
].
remove
(
'default'
)
config
[
'op_types'
].
extend
(
weighted_modules
)
for
config
in
config_list
:
if
'op_partial_names'
in
config
:
op_names
=
[]
...
...
@@ -225,18 +238,17 @@ def get_model_weights_numel(model: Module, config_list: List[Dict], masks: Dict[
model_weights_numel
[
module_name
]
=
module
.
weight
.
data
.
numel
()
return
model_weights_numel
,
masked_rate
# FIXME: to avoid circular import, copy this function in this place
def
get_module_by_name
(
model
,
module_name
):
"""
Get a module specified by its module name
Parameters
----------
model : pytorch model
the pytorch model from which to get its module
module_name : str
the name of the required module
Returns
-------
module, module
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment