Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
2772751d
"ts/webui/src/static/style/App.scss" did not exist on "1bfc7acf3ed4f910b0db4e436a216bc6663545aa"
Unverified
Commit
2772751d
authored
Jan 12, 2022
by
J-shang
Committed by
GitHub
Jan 12, 2022
Browse files
fix SA task generator bug (#4457)
parent
31f11f51
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
4 deletions
+11
-4
nni/algorithms/compression/v2/pytorch/pruning/tools/task_generator.py
...ms/compression/v2/pytorch/pruning/tools/task_generator.py
+11
-4
No files found.
nni/algorithms/compression/v2/pytorch/pruning/tools/task_generator.py
View file @
2772751d
...
@@ -217,8 +217,8 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
...
@@ -217,8 +217,8 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
self
.
_temp_config_list
=
[]
self
.
_temp_config_list
=
[]
self
.
_temp_sparsity_list
=
[]
self
.
_temp_sparsity_list
=
[]
for
config
in
self
.
target_sparsity_list
:
for
config
in
self
.
target_sparsity_list
:
sparsity_config
,
sparsity
=
self
.
_init_config_sparsity
(
config
)
sparsity_config
_list
,
sparsity
=
self
.
_init_config_sparsity
(
config
)
self
.
_temp_config_list
.
extend
(
sparsity_config
)
self
.
_temp_config_list
.
extend
(
sparsity_config
_list
)
self
.
_temp_sparsity_list
.
append
(
sparsity
)
self
.
_temp_sparsity_list
.
append
(
sparsity
)
def
_init_config_sparsity
(
self
,
config
:
Dict
)
->
Tuple
[
List
[
Dict
],
List
]:
def
_init_config_sparsity
(
self
,
config
:
Dict
)
->
Tuple
[
List
[
Dict
],
List
]:
...
@@ -227,7 +227,10 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
...
@@ -227,7 +227,10 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
op_names
=
config
[
'op_names'
]
op_names
=
config
[
'op_names'
]
if
target_sparsity
==
0
:
if
target_sparsity
==
0
:
return
[],
[]
sparsity_config_list
=
[
deepcopy
(
config
)
for
i
in
range
(
len
(
op_names
))]
for
sparsity_config
,
op_name
in
zip
(
sparsity_config_list
,
op_names
):
sparsity_config
.
update
({
'total_sparsity'
:
0
,
'op_names'
:
[
op_name
]})
return
sparsity_config_list
,
[]
low_limit
=
0
low_limit
=
0
while
True
:
while
True
:
...
@@ -266,7 +269,10 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
...
@@ -266,7 +269,10 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
sparsity
=
sorted
(
sparsity
)
sparsity
=
sorted
(
sparsity
)
op_names
=
[
k
for
k
,
_
in
sorted
(
self
.
weights_numel
.
items
(),
key
=
lambda
item
:
item
[
1
])
if
k
in
config
[
'op_names'
]]
op_names
=
[
k
for
k
,
_
in
sorted
(
self
.
weights_numel
.
items
(),
key
=
lambda
item
:
item
[
1
])
if
k
in
config
[
'op_names'
]]
assert
len
(
sparsity
)
==
len
(
op_names
)
assert
len
(
sparsity
)
==
len
(
op_names
)
return
[{
'total_sparsity'
:
sparsity
,
'op_names'
:
[
op_name
]}
for
sparsity
,
op_name
in
zip
(
sparsity
,
op_names
)]
sub_temp_config_list
=
[
deepcopy
(
config
)
for
i
in
range
(
len
(
op_names
))]
for
temp_config
,
sp
,
op_name
in
zip
(
sub_temp_config_list
,
sparsity
,
op_names
):
temp_config
.
update
({
'total_sparsity'
:
sp
,
'op_names'
:
[
op_name
]})
return
sub_temp_config_list
def
_update_with_perturbations
(
self
):
def
_update_with_perturbations
(
self
):
self
.
_temp_config_list
=
[]
self
.
_temp_config_list
=
[]
...
@@ -275,6 +281,7 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
...
@@ -275,6 +281,7 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
magnitude
=
self
.
current_temperature
/
self
.
start_temperature
*
self
.
perturbation_magnitude
magnitude
=
self
.
current_temperature
/
self
.
start_temperature
*
self
.
perturbation_magnitude
for
config
,
current_sparsity
in
zip
(
self
.
target_sparsity_list
,
self
.
_current_sparsity_list
):
for
config
,
current_sparsity
in
zip
(
self
.
target_sparsity_list
,
self
.
_current_sparsity_list
):
if
len
(
current_sparsity
)
==
0
:
if
len
(
current_sparsity
)
==
0
:
self
.
_temp_config_list
.
extend
(
deepcopy
(
config
))
self
.
_temp_sparsity_list
.
append
([])
self
.
_temp_sparsity_list
.
append
([])
continue
continue
while
True
:
while
True
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment