Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
bbcb1677
"vscode:/vscode.git/clone" did not exist on "9790e97979be90eafdccc8345611d9c8cc71eb97"
Unverified
Commit
bbcb1677
authored
Jul 31, 2020
by
Guoxin
Committed by
GitHub
Jul 31, 2020
Browse files
fix SimulatedAnnealingPruner export mask issue (#2736)
parent
143c6615
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
4 deletions
+8
-4
src/sdk/pynni/nni/compression/torch/pruning/simulated_annealing_pruner.py
...i/compression/torch/pruning/simulated_annealing_pruner.py
+8
-4
No files found.
src/sdk/pynni/nni/compression/torch/pruning/simulated_annealing_pruner.py
View file @
bbcb1677
...
...
@@ -243,13 +243,11 @@ class SimulatedAnnealingPruner(Pruner):
_logger
.
info
(
'current perturation magnitude:%s'
,
magnitude
)
while
True
:
perturbation
=
np
.
random
.
uniform
(
-
magnitude
,
magnitude
,
len
(
self
.
get_modules_wrapper
()))
perturbation
=
np
.
random
.
uniform
(
-
magnitude
,
magnitude
,
len
(
self
.
get_modules_wrapper
()))
sparsities
=
np
.
clip
(
0
,
self
.
_sparsities
+
perturbation
,
None
)
_logger
.
debug
(
"sparsities before rescalling:%s"
,
sparsities
)
sparsities
=
self
.
_rescale_sparsities
(
sparsities
,
target_sparsity
=
self
.
_sparsity
)
sparsities
=
self
.
_rescale_sparsities
(
sparsities
,
target_sparsity
=
self
.
_sparsity
)
_logger
.
debug
(
"sparsities after rescalling:%s"
,
sparsities
)
if
sparsities
is
not
None
and
sparsities
[
0
]
>=
0
and
sparsities
[
-
1
]
<
1
:
...
...
@@ -312,6 +310,8 @@ class SimulatedAnnealingPruner(Pruner):
# save the overall best masked model
self
.
bound_model
=
model_masked
# the ops with sparsity 0 are not included in this modules_wrapper
modules_wrapper_final
=
pruner
.
get_modules_wrapper
()
break
# if not, accept with probability e^(-deltaE/current_temperature)
else
:
...
...
@@ -356,4 +356,8 @@ class SimulatedAnnealingPruner(Pruner):
if
return_config_list
:
return
self
.
_best_config_list
# This should be done only at the final stage,
# because the modules_wrapper with all the ops are used during the annealing process
self
.
modules_wrapper
=
modules_wrapper_final
return
self
.
bound_model
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment