OpenDAS / nni · Commits · a16e570d

Unverified commit a16e570d, authored Sep 24, 2021 by J-shang, committed by GitHub on Sep 24, 2021.

[Model Compression] Add more Task Generator (#4178)

Parent: 7a50c96d
Showing 3 changed files with 314 additions and 13 deletions (+314 / -13).
nni/algorithms/compression/v2/pytorch/pruning/basic_scheduler.py        +71 / -5
nni/algorithms/compression/v2/pytorch/pruning/tools/__init__.py          +3 / -1
nni/algorithms/compression/v2/pytorch/pruning/tools/task_generator.py  +240 / -7
nni/algorithms/compression/v2/pytorch/pruning/basic_scheduler.py (view file @ a16e570d)

@@ -15,7 +15,8 @@ from .tools import TaskGenerator

 class PruningScheduler(BasePruningScheduler):
     def __init__(self, pruner: Pruner, task_generator: TaskGenerator, finetuner: Callable[[Module], None] = None,
-                 speed_up: bool = False, dummy_input: Tensor = None, evaluator: Optional[Callable[[Module], float]] = None):
+                 speed_up: bool = False, dummy_input: Tensor = None, evaluator: Optional[Callable[[Module], float]] = None,
+                 reset_weight: bool = False):
         """
         Parameters
         ----------
@@ -33,6 +34,8 @@ class PruningScheduler(BasePruningScheduler):
         evaluator
             Evaluate the pruned model and give a score.
             If evaluator is None, the best result refers to the latest result.
+        reset_weight
+            If set to True, the model weights will be reset to the original model weights at the end of each iteration step.
         """
         self.pruner = pruner
         self.task_generator = task_generator
@@ -40,6 +43,7 @@ class PruningScheduler(BasePruningScheduler):
         self.speed_up = speed_up
         self.dummy_input = dummy_input
         self.evaluator = evaluator
+        self.reset_weight = reset_weight

     def generate_task(self) -> Optional[Task]:
         return self.task_generator.next()
@@ -47,12 +51,15 @@ class PruningScheduler(BasePruningScheduler):
     def record_task_result(self, task_result: TaskResult):
         self.task_generator.receive_task_result(task_result)

-    def pruning_one_step(self, task: Task) -> TaskResult:
+    def pruning_one_step_normal(self, task: Task) -> TaskResult:
         """
         generate masks -> speed up -> finetune -> evaluate
         """
         model, masks, config_list = task.load_data()

-        # pruning model
         self.pruner.reset(model, config_list)
         self.pruner.load_masks(masks)
+
+        # pruning model
         compact_model, pruner_generated_masks = self.pruner.compress()
         compact_model_masks = deepcopy(pruner_generated_masks)
@@ -75,12 +82,71 @@ class PruningScheduler(BasePruningScheduler):
         self.pruner._unwrap_model()

         # evaluate
-        score = self.evaluator(compact_model) if self.evaluator is not None else None
+        if self.evaluator is not None:
+            if self.speed_up:
+                score = self.evaluator(compact_model)
+            else:
+                self.pruner._wrap_model()
+                score = self.evaluator(compact_model)
+                self.pruner._unwrap_model()
+        else:
+            score = None

         # clear model references
         self.pruner.clear_model_references()

         return TaskResult(task.task_id, compact_model, compact_model_masks, pruner_generated_masks, score)

+    def pruning_one_step_reset_weight(self, task: Task) -> TaskResult:
+        """
+        finetune -> generate masks -> reset weight -> speed up -> evaluate
+        """
+        model, masks, config_list = task.load_data()
+        checkpoint = deepcopy(model.state_dict())
+        self.pruner.reset(model, config_list)
+        self.pruner.load_masks(masks)
+
+        # finetune
+        if self.finetuner is not None:
+            self.finetuner(model)
+
+        # pruning model
+        compact_model, pruner_generated_masks = self.pruner.compress()
+        compact_model_masks = deepcopy(pruner_generated_masks)
+
+        # show the pruning effect
+        self.pruner.show_pruned_weights()
+        self.pruner._unwrap_model()
+
+        # reset model weight
+        compact_model.load_state_dict(checkpoint)
+
+        # speed up
+        if self.speed_up:
+            ModelSpeedup(compact_model, self.dummy_input, pruner_generated_masks).speedup_model()
+            compact_model_masks = {}
+
+        # evaluate
+        if self.evaluator is not None:
+            if self.speed_up:
+                score = self.evaluator(compact_model)
+            else:
+                self.pruner._wrap_model()
+                score = self.evaluator(compact_model)
+                self.pruner._unwrap_model()
+        else:
+            score = None
+
+        # clear model references
+        self.pruner.clear_model_references()
+
+        return TaskResult(task.task_id, compact_model, compact_model_masks, pruner_generated_masks, score)
+
+    def pruning_one_step(self, task: Task) -> TaskResult:
+        if self.reset_weight:
+            return self.pruning_one_step_reset_weight(task)
+        else:
+            return self.pruning_one_step_normal(task)
+
     def get_best_result(self) -> Optional[Tuple[int, Module, Dict[str, Dict[str, Tensor]], float, List[Dict]]]:
         return self.task_generator.get_best_result()
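The new reset_weight path rewinds the surviving weights to their original values after each pruning iteration, which is the lottery-ticket style workflow. Below is a hedged, construction-level sketch of how the pieces might be wired together; only the PruningScheduler and LotteryTicketTaskGenerator signatures come from this commit, while the toy model, the config_list values, and the pruner placeholder are illustrative assumptions.

    from torch import nn

    from nni.algorithms.compression.v2.pytorch.pruning.basic_scheduler import PruningScheduler
    from nni.algorithms.compression.v2.pytorch.pruning.tools import LotteryTicketTaskGenerator

    # Toy model and config list, for illustration only.
    model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Conv2d(16, 32, 3))
    config_list = [{'total_sparsity': 0.8, 'op_types': ['Conv2d']}]

    # Placeholder: any NNI v2 pruner instance (the scheduler calls pruner.reset(...) per task).
    pruner = ...

    task_generator = LotteryTicketTaskGenerator(
        total_iteration=5,
        origin_model=model,
        origin_config_list=config_list,
    )
    scheduler = PruningScheduler(
        pruner,
        task_generator,
        finetuner=None,        # optional Callable[[Module], None]
        speed_up=False,
        dummy_input=None,
        evaluator=None,        # optional Callable[[Module], float]
        reset_weight=True,     # new in this commit: rewind weights each iteration
    )
    # Driving the iteration loop is left to the scheduler API; per the diff above,
    # scheduler.get_best_result() returns the best (task_id, model, masks, score, config_list).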
nni/algorithms/compression/v2/pytorch/pruning/tools/__init__.py (view file @ a16e570d)

@@ -24,5 +24,7 @@ from .sparsity_allocator import (
 )
 from .task_generator import (
     AGPTaskGenerator,
-    LinearTaskGenerator
+    LinearTaskGenerator,
+    LotteryTicketTaskGenerator,
+    SimulatedAnnealingTaskGenerator
 )
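With these exports in place, the two new generators are importable from the tools package alongside the existing ones (names taken directly from the diff above):

    from nni.algorithms.compression.v2.pytorch.pruning.tools import (
        AGPTaskGenerator,
        LinearTaskGenerator,
        LotteryTicketTaskGenerator,
        SimulatedAnnealingTaskGenerator,
    )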
nni/algorithms/compression/v2/pytorch/pruning/tools/task_generator.py (view file @ a16e570d)

@@ -4,15 +4,20 @@
 from copy import deepcopy
 import logging
 from pathlib import Path
-from typing import Dict, List
+from typing import Dict, List, Tuple

 import json_tricks
 import numpy as np
 from torch import Tensor
+import torch
 from torch.nn import Module

 from nni.algorithms.compression.v2.pytorch.base import Task, TaskResult
-from nni.algorithms.compression.v2.pytorch.utils.pruning import config_list_canonical, compute_sparsity
+from nni.algorithms.compression.v2.pytorch.utils.pruning import (
+    config_list_canonical,
+    compute_sparsity,
+    get_model_weights_numel
+)
 from .base import TaskGenerator

 _logger = logging.getLogger(__name__)
@@ -21,6 +26,23 @@ _logger = logging.getLogger(__name__)

 class FunctionBasedTaskGenerator(TaskGenerator):
     def __init__(self, total_iteration: int, origin_model: Module, origin_config_list: List[Dict],
                  origin_masks: Dict[str, Dict[str, Tensor]] = {}, log_dir: str = '.', keep_intermidiate_result: bool = False):
+        """
+        Parameters
+        ----------
+        total_iteration
+            The total number of iterations.
+        origin_model
+            The original, unwrapped pytorch model to be pruned.
+        origin_config_list
+            The original config list provided by the user. Note that this config_list directly configures the origin model,
+            so the sparsity already provided by origin_masks should also be recorded in origin_config_list.
+        origin_masks
+            The pre-applied masks on the origin model. These masks may be user-defined or generated by a previous pruning.
+        log_dir
+            The directory used to save the task generator log.
+        keep_intermidiate_result
+            Whether to keep the intermediate results, including the intermediate model and masks from each iteration.
+        """
         self.current_iteration = 0
         self.target_sparsity = config_list_canonical(origin_model, origin_config_list)
         self.total_iteration = total_iteration
@@ -54,7 +76,7 @@ class FunctionBasedTaskGenerator(TaskGenerator):
             self._tasks[task_result.task_id].state['current2origin_sparsity'] = current2origin_sparsity

         # if the total_iteration has been reached, no more tasks will be generated
-        if self.current_iteration >= self.total_iteration:
+        if self.current_iteration > self.total_iteration:
             return []

         task_id = self._task_id_candidate
@@ -77,9 +99,9 @@ class FunctionBasedTaskGenerator(TaskGenerator):

 class AGPTaskGenerator(FunctionBasedTaskGenerator):
-    def generate_config_list(self, target_sparsity: List[Dict], iteration: int, model_based_sparsity: List[Dict]) -> List[Dict]:
+    def generate_config_list(self, target_sparsity: List[Dict], iteration: int, compact2origin_sparsity: List[Dict]) -> List[Dict]:
         config_list = []
-        for target, mo in zip(target_sparsity, model_based_sparsity):
+        for target, mo in zip(target_sparsity, compact2origin_sparsity):
             ori_sparsity = (1 - (1 - iteration / self.total_iteration) ** 3) * target['total_sparsity']
             sparsity = max(0.0, (ori_sparsity - mo['total_sparsity']) / (1 - mo['total_sparsity']))
             assert 0 <= sparsity <= 1, 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
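For reference, AGPTaskGenerator keeps the cubic schedule of automated gradual pruning: at iteration t of T the cumulative target is scaled by 1 - (1 - t/T)^3, so most of the pruning happens in the early iterations. A minimal standalone sketch of that schedule (plain Python, independent of the NNI classes above):

    # Minimal sketch of the AGP sparsity schedule used above (not NNI API code).
    def agp_sparsity(target_sparsity: float, iteration: int, total_iteration: int) -> float:
        """Cubic ramp from 0 to target_sparsity over total_iteration steps."""
        return (1 - (1 - iteration / total_iteration) ** 3) * target_sparsity

    # A 0.8 target over 5 iterations ramps quickly at first, then flattens.
    print([round(agp_sparsity(0.8, t, 5), 3) for t in range(1, 6)])
    # -> [0.39, 0.627, 0.749, 0.794, 0.8]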
@@ -89,12 +111,223 @@ class AGPTaskGenerator(FunctionBasedTaskGenerator):

 class LinearTaskGenerator(FunctionBasedTaskGenerator):
-    def generate_config_list(self, target_sparsity: List[Dict], iteration: int, model_based_sparsity: List[Dict]) -> List[Dict]:
+    def generate_config_list(self, target_sparsity: List[Dict], iteration: int, compact2origin_sparsity: List[Dict]) -> List[Dict]:
         config_list = []
-        for target, mo in zip(target_sparsity, model_based_sparsity):
+        for target, mo in zip(target_sparsity, compact2origin_sparsity):
             ori_sparsity = iteration / self.total_iteration * target['total_sparsity']
             sparsity = max(0.0, (ori_sparsity - mo['total_sparsity']) / (1 - mo['total_sparsity']))
             assert 0 <= sparsity <= 1, 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
             config_list.append(deepcopy(target))
             config_list[-1]['total_sparsity'] = sparsity
         return config_list
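Each generator first computes ori_sparsity, the cumulative sparsity wanted with respect to the origin model, and then converts it into a per-iteration sparsity measured against the already-pruned compact model via (ori_sparsity - mo) / (1 - mo). A worked example with illustrative numbers:

    # Worked example of the compact-to-origin rescaling shared by the generators above.
    # The numbers are illustrative, not taken from the commit.
    ori_sparsity = 0.6     # cumulative sparsity wanted w.r.t. the origin model
    compact2origin = 0.5   # sparsity the compact model already has w.r.t. the origin model
    sparsity = max(0.0, (ori_sparsity - compact2origin) / (1 - compact2origin))
    print(round(sparsity, 3))  # 0.2: pruning 20% of the remaining weights reaches 60% overall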
+class LotteryTicketTaskGenerator(FunctionBasedTaskGenerator):
+    def __init__(self, total_iteration: int, origin_model: Module, origin_config_list: List[Dict],
+                 origin_masks: Dict[str, Dict[str, Tensor]] = {}, log_dir: str = '.', keep_intermidiate_result: bool = False):
+        super().__init__(total_iteration, origin_model, origin_config_list, origin_masks=origin_masks,
+                         log_dir=log_dir, keep_intermidiate_result=keep_intermidiate_result)
+        self.current_iteration = 1
+
+    def generate_config_list(self, target_sparsity: List[Dict], iteration: int, compact2origin_sparsity: List[Dict]) -> List[Dict]:
+        config_list = []
+        for target, mo in zip(target_sparsity, compact2origin_sparsity):
+            # NOTE: The ori_sparsity formula below follows compression v1 and differs from the paper.
+            # The formula in the paper causes numerical problems, so the compression v1 formula is kept.
+            ori_sparsity = 1 - (1 - target['total_sparsity']) ** (iteration / self.total_iteration)
+            # The following is the formula in the paper.
+            # ori_sparsity = (target['total_sparsity'] * 100) ** (iteration / self.total_iteration) / 100
+            sparsity = max(0.0, (ori_sparsity - mo['total_sparsity']) / (1 - mo['total_sparsity']))
+            assert 0 <= sparsity <= 1, 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
+            config_list.append(deepcopy(target))
+            config_list[-1]['total_sparsity'] = sparsity
+        return config_list
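As the NOTE above explains, the lottery-ticket schedule keeps the compression v1 formula: the cumulative sparsity after iteration t of T is 1 - (1 - target)^(t/T), which prunes the same fraction of the remaining weights at every iteration. A minimal standalone sketch (plain Python, not NNI API code):

    # Sketch of the lottery-ticket style schedule used above.
    def lottery_ticket_sparsity(target: float, t: int, total: int) -> float:
        """Prune an equal fraction of the *remaining* weights at every iteration."""
        return 1 - (1 - target) ** (t / total)

    print([round(lottery_ticket_sparsity(0.8, t, 5), 3) for t in range(1, 6)])
    # -> [0.275, 0.475, 0.619, 0.724, 0.8]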
+class SimulatedAnnealingTaskGenerator(TaskGenerator):
+    def __init__(self, origin_model: Module, origin_config_list: List[Dict], origin_masks: Dict[str, Dict[str, Tensor]] = {},
+                 start_temperature: float = 100, stop_temperature: float = 20, cool_down_rate: float = 0.9,
+                 perturbation_magnitude: float = 0.35, log_dir: str = '.', keep_intermidiate_result: bool = False):
+        """
+        Parameters
+        ----------
+        origin_model
+            The original, unwrapped pytorch model to be pruned.
+        origin_config_list
+            The original config list provided by the user. Note that this config_list directly configures the origin model,
+            so the sparsity already provided by origin_masks should also be recorded in origin_config_list.
+        origin_masks
+            The pre-applied masks on the origin model. These masks may be user-defined or generated by a previous pruning.
+        start_temperature
+            Start temperature of the simulated annealing process.
+        stop_temperature
+            Stop temperature of the simulated annealing process.
+        cool_down_rate
+            Cool down rate of the temperature.
+        perturbation_magnitude
+            Initial perturbation magnitude applied to the sparsities. The magnitude decreases with the current temperature.
+        log_dir
+            The directory used to save the task generator log.
+        keep_intermidiate_result
+            Whether to keep the intermediate results, including the intermediate model and masks from each iteration.
+        """
+        self.start_temperature = start_temperature
+        self.current_temperature = start_temperature
+        self.stop_temperature = stop_temperature
+        self.cool_down_rate = cool_down_rate
+        self.perturbation_magnitude = perturbation_magnitude
+
+        self.weights_numel, self.masked_rate = get_model_weights_numel(origin_model, origin_config_list, origin_masks)
+        self.target_sparsity_list = config_list_canonical(origin_model, origin_config_list)
+        self._adjust_target_sparsity()
+
+        self._temp_config_list = None
+        self._current_sparsity_list = None
+        self._current_score = None
+
+        super().__init__(origin_model, origin_masks=origin_masks, origin_config_list=origin_config_list,
+                         log_dir=log_dir, keep_intermidiate_result=keep_intermidiate_result)
+    def _adjust_target_sparsity(self):
+        """
+        If origin_masks is not empty, then re-scale the target sparsity.
+        """
+        if len(self.masked_rate) > 0:
+            for config in self.target_sparsity_list:
+                sparsity, op_names = config['total_sparsity'], config['op_names']
+                remaining_weight_numel = 0
+                pruned_weight_numel = 0
+                for name in op_names:
+                    remaining_weight_numel += self.weights_numel[name]
+                    if name in self.masked_rate:
+                        pruned_weight_numel += 1 / (1 / self.masked_rate[name] - 1) * self.weights_numel[name]
+                config['total_sparsity'] = max(0, sparsity - pruned_weight_numel / (pruned_weight_numel + remaining_weight_numel))
+
+    def _init_temp_config_list(self):
+        self._temp_config_list = []
+        self._temp_sparsity_list = []
+        for config in self.target_sparsity_list:
+            sparsity_config, sparsity = self._init_config_sparsity(config)
+            self._temp_config_list.extend(sparsity_config)
+            self._temp_sparsity_list.append(sparsity)
+
+    def _init_config_sparsity(self, config: Dict) -> Tuple[List[Dict], List]:
+        assert 'total_sparsity' in config, 'Sparsity must be set in config: {}'.format(config)
+
+        target_sparsity = config['total_sparsity']
+        op_names = config['op_names']
+
+        if target_sparsity == 0:
+            return [], []
+
+        while True:
+            random_sparsity = sorted(np.random.uniform(0, 1, len(op_names)))
+            rescaled_sparsity = self._rescale_sparsity(random_sparsity, target_sparsity, op_names)
+            if rescaled_sparsity is not None and rescaled_sparsity[0] >= 0 and rescaled_sparsity[-1] < 1:
+                break
+
+        return self._sparsity_to_config_list(rescaled_sparsity, config), rescaled_sparsity
+
+    def _rescale_sparsity(self, random_sparsity: List, target_sparsity: float, op_names: List) -> List:
+        assert len(random_sparsity) == len(op_names)
+
+        num_weights = sorted([self.weights_numel[op_name] for op_name in op_names])
+        sparsity = sorted(random_sparsity)
+
+        total_weights = 0
+        total_weights_pruned = 0
+
+        # calculate the scale
+        for idx, num_weight in enumerate(num_weights):
+            total_weights += num_weight
+            total_weights_pruned += int(num_weight * sparsity[idx])
+        if total_weights_pruned == 0:
+            return None
+
+        scale = target_sparsity / (total_weights_pruned / total_weights)
+
+        # rescale the sparsity
+        sparsity = np.asarray(sparsity) * scale
+        return sparsity
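The helper above draws one random sparsity per layer and then solves for a single scale factor so that the total number of pruned weights matches the target. A rough standalone sketch of the same idea (illustrative only; it mirrors the int() truncation in the method above only approximately):

    import numpy as np

    # Rough standalone sketch of the rescaling idea used by _rescale_sparsity above.
    def rescale(random_sparsity, weight_counts, target_sparsity):
        sparsity = np.asarray(sorted(random_sparsity))
        counts = np.asarray(sorted(weight_counts))
        pruned = float(np.sum(counts * sparsity))   # weights pruned under the random proposal
        total = float(np.sum(counts))
        if pruned == 0:
            return None
        scale = target_sparsity / (pruned / total)  # single factor that hits the target overall
        return sparsity * scale

    # Three layers with 1000/2000/4000 weights, overall target sparsity 0.5.
    print(rescale([0.1, 0.5, 0.9], [1000, 2000, 4000], 0.5))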
+    def _sparsity_to_config_list(self, sparsity: List, config: Dict) -> List[Dict]:
+        sparsity = sorted(sparsity)
+        op_names = [k for k, _ in sorted(self.weights_numel.items(), key=lambda item: item[1]) if k in config['op_names']]
+        assert len(sparsity) == len(op_names)
+        return [{'total_sparsity': sparsity, 'op_names': [op_name]} for sparsity, op_name in zip(sparsity, op_names)]
+
+    def _update_with_perturbations(self):
+        self._temp_config_list = []
+        self._temp_sparsity_list = []
+        # decrease magnitude with current temperature
+        magnitude = self.current_temperature / self.start_temperature * self.perturbation_magnitude
+        for config, current_sparsity in zip(self.target_sparsity_list, self._current_sparsity_list):
+            if len(current_sparsity) == 0:
+                self._temp_sparsity_list.append([])
+                continue
+            while True:
+                perturbation = np.random.uniform(-magnitude, magnitude, len(current_sparsity))
+                temp_sparsity = np.clip(0, current_sparsity + perturbation, None)
+                temp_sparsity = self._rescale_sparsity(temp_sparsity, config['total_sparsity'], config['op_names'])
+                if temp_sparsity is not None and temp_sparsity[0] >= 0 and temp_sparsity[-1] < 1:
+                    self._temp_config_list.extend(self._sparsity_to_config_list(temp_sparsity, config))
+                    self._temp_sparsity_list.append(temp_sparsity)
+                    break
+    def _recover_real_sparsity(self, config_list: List[Dict]) -> List[Dict]:
+        """
+        If the origin masks are not empty, the sparsity in the newly generated config_list needs to be rescaled.
+        """
+        for config in config_list:
+            assert len(config['op_names']) == 1
+            op_name = config['op_names'][0]
+            if op_name in self.masked_rate:
+                config['total_sparsity'] = self.masked_rate[op_name] + config['total_sparsity'] * (1 - self.masked_rate[op_name])
+        return config_list
+    def init_pending_tasks(self) -> List[Task]:
+        origin_model = torch.load(self._origin_model_path)
+        origin_masks = torch.load(self._origin_masks_path)
+
+        self.temp_model_path = Path(self._intermidiate_result_dir, 'origin_compact_model.pth')
+        self.temp_masks_path = Path(self._intermidiate_result_dir, 'origin_compact_model_masks.pth')
+        torch.save(origin_model, self.temp_model_path)
+        torch.save(origin_masks, self.temp_masks_path)
+
+        task_result = TaskResult('origin', origin_model, origin_masks, origin_masks, None)
+        return self.generate_tasks(task_result)
+
+    def generate_tasks(self, task_result: TaskResult) -> List[Task]:
+        # initial/update temp config list
+        if self._temp_config_list is None:
+            self._init_temp_config_list()
+        else:
+            score = self._tasks[task_result.task_id].score
+            if self._current_sparsity_list is None:
+                self._current_sparsity_list = deepcopy(self._temp_sparsity_list)
+                self._current_score = score
+            else:
+                delta_E = np.abs(score - self._current_score)
+                probability = np.exp(-1 * delta_E / self.current_temperature)
+                if self._current_score < score or np.random.uniform(0, 1) < probability:
+                    self._current_score = score
+                    self._current_sparsity_list = deepcopy(self._temp_sparsity_list)
+                    self.current_temperature *= self.cool_down_rate
+            if self.current_temperature < self.stop_temperature:
+                return []
+            self._update_with_perturbations()
+
+        task_id = self._task_id_candidate
+
+        new_config_list = self._recover_real_sparsity(deepcopy(self._temp_config_list))
+
+        config_list_path = Path(self._intermidiate_result_dir, '{}_config_list.json'.format(task_id))
+        with Path(config_list_path).open('w') as f:
+            json_tricks.dump(new_config_list, f, indent=4)
+
+        task = Task(task_id, self.temp_model_path, self.temp_masks_path, config_list_path)
+
+        self._tasks[task_id] = task
+        self._task_id_candidate += 1
+        return [task]
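generate_tasks above implements a standard simulated annealing acceptance rule: a worse-scoring sparsity vector is still accepted with probability exp(-|ΔE| / T), the temperature is multiplied by cool_down_rate after an accepted move, and task generation stops once the temperature falls below stop_temperature. A minimal standalone sketch of that accept/cool loop (all names below are illustrative; evaluate stands in for the real prune-and-evaluate step):

    import numpy as np

    # Minimal sketch of the accept/cool-down rule used by SimulatedAnnealingTaskGenerator above.
    def anneal(evaluate, propose, init_state, start_t=100.0, stop_t=20.0, cool_down_rate=0.9):
        current_state, current_score = init_state, evaluate(init_state)
        t = start_t
        while t > stop_t:
            candidate = propose(current_state, magnitude=t / start_t)
            score = evaluate(candidate)
            # accept better candidates always, worse ones with probability exp(-|delta| / t);
            # the temperature only cools after an accepted move
            if score > current_score or np.random.uniform(0, 1) < np.exp(-abs(score - current_score) / t):
                current_state, current_score = candidate, score
                t *= cool_down_rate
        return current_state, current_score

    # Toy usage: maximize a 1-D score by random perturbation.
    best, best_score = anneal(
        evaluate=lambda x: -(x - 3.0) ** 2,
        propose=lambda x, magnitude: x + np.random.uniform(-magnitude, magnitude),
        init_state=0.0,
    )
    print(round(best, 2), round(best_score, 4))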