Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
9e5e0e3c
"...resnet50_tensorflow.git" did not exist on "ec54319106c913c503517ac78882061e8b8a9607"
Unverified
Commit
9e5e0e3c
authored
Dec 25, 2020
by
J-shang
Committed by
GitHub
Dec 25, 2020
Browse files
fix pruner doc typo and example bug (#3223)
parent
0c13ea49
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
7 additions
and
7 deletions
+7
-7
docs/en_US/Compression/Pruner.rst
docs/en_US/Compression/Pruner.rst
+2
-2
examples/model_compress/auto_pruners_torch.py
examples/model_compress/auto_pruners_torch.py
+5
-5
No files found.
docs/en_US/Compression/Pruner.rst
View file @
9e5e0e3c
...
@@ -582,7 +582,7 @@ PyTorch code
...
@@ -582,7 +582,7 @@ PyTorch code
.. code-block:: python
.. code-block:: python
from nni.algorithms.compression.pytorch.pruning import ADMMPruner
from nni.algorithms.compression.pytorch.pruning import AutoCompressPruner
config_list = [{
config_list = [{
'sparsity': 0.5,
'sparsity': 0.5,
'op_types': ['Conv2d']
'op_types': ['Conv2d']
...
@@ -633,7 +633,7 @@ PyTorch code
...
@@ -633,7 +633,7 @@ PyTorch code
You can view :githublink:`example <examples/model_compress/amc/>` for more information.
You can view :githublink:`example <examples/model_compress/amc/>` for more information.
User configuration for AutoCompressPruner
User configuration for AMCPruner
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
**PyTorch**
**PyTorch**
...
...
examples/model_compress/auto_pruners_torch.py
View file @
9e5e0e3c
...
@@ -229,7 +229,7 @@ def main(args):
...
@@ -229,7 +229,7 @@ def main(args):
# used to save the performance of the original & pruned & finetuned models
# used to save the performance of the original & pruned & finetuned models
result = {'flops': {}, 'params': {}, 'performance':{}}
result = {'flops': {}, 'params': {}, 'performance':{}}
flops, params = count_flops_params(model, get_input_size(args.dataset))
flops, params, _ = count_flops_params(model, get_input_size(args.dataset))
result['flops']['original'] = flops
result['flops']['original'] = flops
result['params']['original'] = params
result['params']['original'] = params
...
@@ -238,7 +238,7 @@ def main(args):
...
@@ -238,7 +238,7 @@ def main(args):
result['performance']['original'] = evaluation_result
result['performance']['original'] = evaluation_result
# module types to prune, only "Conv2d" supported for channel pruning
# module types to prune, only "Conv2d" supported for channel pruning
if args.base_algo in ['l1', 'l2']:
if args.base_algo in ['l1', 'l2', 'fpgm']:
op_types = ['Conv2d']
op_types = ['Conv2d']
elif args.base_algo == 'level':
elif args.base_algo == 'level':
op_types = ['default']
op_types = ['default']
...
@@ -261,7 +261,7 @@ def main(args):
...
@@ -261,7 +261,7 @@ def main(args):
elif args.pruner == 'ADMMPruner':
elif args.pruner == 'ADMMPruner':
# users are free to change the config here
# users are free to change the config here
if args.model == 'LeNet':
if args.model == 'LeNet':
if args.base_algo in ['l1', 'l2']:
if args.base_algo in ['l1', 'l2', 'fpgm']:
config_list = [{
config_list = [{
'sparsity': 0.8,
'sparsity': 0.8,
'op_types': ['Conv2d'],
'op_types': ['Conv2d'],
...
@@ -337,7 +337,7 @@ def main(args):
...
@@ -337,7 +337,7 @@ def main(args):
torch.save(model.state_dict(), os.path.join(args.experiment_data_dir, 'model_speed_up.pth'))
torch.save(model.state_dict(), os.path.join(args.experiment_data_dir, 'model_speed_up.pth'))
print('Speed up model saved to %s', args.experiment_data_dir)
print('Speed up model saved to %s', args.experiment_data_dir)
flops, params = count_flops_params(model, get_input_size(args.dataset))
flops, params, _ = count_flops_params(model, get_input_size(args.dataset))
result['flops']['speedup'] = flops
result['flops']['speedup'] = flops
result['params']['speedup'] = params
result['params']['speedup'] = params
...
@@ -414,7 +414,7 @@ if __name__ == '__main__':
...
@@ -414,7 +414,7 @@ if __name__ == '__main__':
parser.add_argument('--pruner', type=str, default='SimulatedAnnealingPruner',
parser.add_argument('--pruner', type=str, default='SimulatedAnnealingPruner',
help='pruner to use')
help='pruner to use')
parser.add_argument('--base-algo', type=str, default='l1',
parser.add_argument('--base-algo', type=str, default='l1',
help='base pruning algorithm. level, l1 or l2')
help='base pruning algorithm. level, l1, l2, or fpgm')
parser.add_argument('--sparsity', type=float, default=0.1,
parser.add_argument('--sparsity', type=float, default=0.1,
help='target overall target sparsity')
help='target overall target sparsity')
# param for SimulatedAnnealingPruner
# param for SimulatedAnnealingPruner
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment