OpenDAS / nni · Commits · e98ebcf0

Unverified commit e98ebcf0, authored Sep 10, 2021 by J-shang, committed by GitHub Sep 10, 2021.
Parent: 04f439a0

[Model Compression] Pruning Scheduler (#4089)

Showing 11 changed files with 673 additions and 47 deletions (+673 / -47).
nni/algorithms/compression/v2/pytorch/base/__init__.py                      +1    -0
nni/algorithms/compression/v2/pytorch/base/compressor.py                    +11   -0
nni/algorithms/compression/v2/pytorch/base/pruner.py                        +15   -25
nni/algorithms/compression/v2/pytorch/base/scheduler.py                     +184  -0
nni/algorithms/compression/v2/pytorch/pruning/basic_pruner.py               +7    -7
nni/algorithms/compression/v2/pytorch/pruning/basic_scheduler.py            +86   -0
nni/algorithms/compression/v2/pytorch/pruning/tools/__init__.py             +6    -1
nni/algorithms/compression/v2/pytorch/pruning/tools/base.py                 +142  -9
nni/algorithms/compression/v2/pytorch/pruning/tools/sparsity_allocator.py   +8    -4
nni/algorithms/compression/v2/pytorch/pruning/tools/task_generator.py       +100  -0
nni/algorithms/compression/v2/pytorch/utils/pruning.py                      +113  -1
nni/algorithms/compression/v2/pytorch/base/__init__.py

 from .compressor import Compressor, LayerInfo
 from .pruner import Pruner, PrunerModuleWrapper
+from .scheduler import BasePruningScheduler, Task, TaskResult
nni/algorithms/compression/v2/pytorch/base/compressor.py

@@ -84,6 +84,17 @@ class Compressor:
         self._wrap_model()

+    def clear_model_references(self):
+        """
+        Clear all references to the model in this compressor, just to free up memory.
+        `reset` is needed before the next call to any compressor function.
+        """
+        self._unwrap_model()
+        self.bound_model = None
+        self.config_list = None
+        self.modules_wrapper = None
+        self._modules_to_compress = None
+
     def _detect_modules_to_compress(self) -> List[Tuple[LayerInfo, Dict]]:
         """
         Detect all modules that should be compressed, and save the result in `self._modules_to_compress`.
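The new `clear_model_references` exists so a scheduler can run many iterations without keeping every bound model alive. A minimal lifecycle sketch, assuming a `Pruner` from this codebase whose `reset(model, config_list)` rebinds it (as the scheduler below does); `model` and `config_list` are placeholders:

    # Sketch of the reset / compress / clear cycle a scheduler drives each iteration.
    pruner.reset(model, config_list)          # bind a fresh model and config
    masked_model, masks = pruner.compress()   # generate masks for the bound model
    pruner.clear_model_references()           # drop bound_model & co. so they can be garbage-collected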
nni/algorithms/compression/v2/pytorch/base/pruner.py

@@ -87,14 +87,17 @@ class Pruner(Compressor):
         Parameters
         ----------
         masks
-            The masks dict with format {'op_name': {'weight_mask': mask, 'bias_mask': mask}}.
+            The masks dict with format {'op_name': {'weight': mask, 'bias': mask}}.
         """
         wrappers = self.get_modules_wrapper()
         for name, layer_mask in masks.items():
             assert name in wrappers, '{} is not in wrappers of this pruner, can not apply the mask.'.format(name)
-            for mask_type, mask in layer_mask.items():
-                assert hasattr(wrappers[name], mask_type), 'there is no attribute {} in wrapper'.format(mask_type)
-                setattr(wrappers[name], mask_type, mask)
+            if layer_mask.get('weight') is not None:
+                assert hasattr(wrappers[name], 'weight_mask'), 'There is no attribute weight_mask in wrapper.'
+                setattr(wrappers[name], 'weight_mask', layer_mask.get('weight'))
+            if layer_mask.get('bias') is not None:
+                assert hasattr(wrappers[name], 'bias_mask'), 'There is no attribute bias_mask in wrapper.'
+                setattr(wrappers[name], 'bias_mask', layer_mask.get('bias'))

     def compress(self) -> Tuple[Module, Dict[str, Dict[str, Tensor]]]:
         """
@@ -126,27 +129,21 @@ class Pruner(Compressor):
             index = torch.nonzero(weight_mask.abs().sum(sum_idx) != 0, as_tuple=False).tolist()
             _logger.info(f'simulated prune {wrapper.name} remain/total: {len(index)}/{weight_mask.size(dim)}')

-    def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None, device=None):
+    def export_model(self, model_path: str, mask_path: Optional[str] = None):
         """
         Export pruned model weights, masks and onnx model(optional)

         Parameters
         ----------
         model_path
             Path to save pruned model state_dict.
+            The weight and bias have already been multiplied by the masks.
         mask_path
-            (optional) path to save mask dict.
-        onnx_path
-            (optional) path to save onnx model.
-        input_shape
-            Input shape to onnx model.
-        device
-            Device of the model, used to place the dummy input tensor for exporting onnx file.
-            The tensor is placed on cpu if ```device``` is None.
+            Path to save mask dict.
         """
-        assert model_path is not None, 'model_path must be specified'
+        assert self.bound_model is not None, 'The bound model reference has been cleared.'
+        assert model_path is not None, 'model_path must be specified.'
         mask_dict = {}
-        self._unwrap_model()
+        # used for generating correct state_dict name without wrapper state
+        self._unwrap_model()

         for name, wrapper in self.get_modules_wrapper().items():
             weight_mask = wrapper.weight_mask
@@ -159,20 +156,13 @@ class Pruner(Compressor):
             if bias_mask is not None:
                 wrapper.module.bias.data = wrapper.module.bias.data.mul(bias_mask)
             # save mask to dict
-            mask_dict[name] = {"weight_mask": weight_mask, "bias_mask": bias_mask}
+            mask_dict[name] = {"weight": weight_mask, "bias": bias_mask}

         torch.save(self.bound_model.state_dict(), model_path)
         _logger.info('Model state_dict saved to %s', model_path)

         if mask_path is not None:
             torch.save(mask_dict, mask_path)
             _logger.info('Mask dict saved to %s', mask_path)
-        if onnx_path is not None:
-            assert input_shape is not None, 'input_shape must be specified to export onnx model'
-            # input info needed
-            if device is None:
-                device = torch.device('cpu')
-            input_data = torch.Tensor(*input_shape)
-            torch.onnx.export(self.bound_model, input_data.to(device), onnx_path)
-            _logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)

         self._wrap_model()
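The key renaming ('weight_mask'/'bias_mask' → 'weight'/'bias') is now consistent between `load_masks` and `export_model`. A small usage sketch of the new format; the layer name 'conv1' and the tensor shapes are made up:

    import torch

    # Hypothetical all-ones masks for a Conv2d named 'conv1' (16 out-channels, 3 in-channels, 3x3 kernel).
    masks = {'conv1': {'weight': torch.ones(16, 3, 3, 3), 'bias': torch.ones(16)}}
    pruner.load_masks(masks)   # fills wrapper.weight_mask / wrapper.bias_mask
    pruner.export_model(model_path='model.pth', mask_path='masks.pth')   # onnx export is gone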
nni/algorithms/compression/v2/pytorch/base/scheduler.py (new file, mode 100644)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import gc
import logging
import os
from pathlib import Path
from typing import List, Dict, Tuple, Literal, Optional

import json_tricks
import torch
from torch.nn import Module
from torch import Tensor

_logger = logging.getLogger(__name__)


class Task:
    # NOTE: If we want to support multi-thread, this part needs to be refactored, maybe use file and lock to sync.
    _reference_counter = {}

    def __init__(self, task_id: int, model_path: str, masks_path: str, config_list_path: str) -> None:
        """
        Parameters
        ----------
        task_id
            The unique id of the task.
        model_path
            The path of the unwrapped pytorch model that will be pruned in this task.
        masks_path
            The path of the masks applied on the model before pruning.
        config_list_path
            The path of the config list used in this task.
        """
        self.task_id = task_id
        self.model_path = model_path
        self.masks_path = masks_path
        self.config_list_path = config_list_path

        self.status: Literal['Pending', 'Running', 'Finished'] = 'Pending'
        self.score: Optional[float] = None

        self.state = {}

        for ref in self.referenced_paths():
            self._reference_counter.setdefault(ref, 0)
            self._reference_counter[ref] += 1

        self._cleaned = False

    def to_dict(self) -> Dict:
        return {
            'task_id': self.task_id,
            'model_path': str(self.model_path),
            'masks_path': str(self.masks_path),
            'config_list_path': str(self.config_list_path),
            'status': self.status,
            'score': self.score,
            'state': self.state
        }

    def load_data(self) -> Tuple[Module, Dict[str, Dict[str, Tensor]], List[Dict]]:
        """
        Returns
        -------
        Tuple[Module, Dict[str, Dict[str, Tensor]], List[Dict]]
            Return the model to prune in this task, the masks of the model before pruning,
            and the config list used in this task.
        """
        model = torch.load(self.model_path)
        masks = torch.load(self.masks_path)
        with Path(self.config_list_path).open('r') as f:
            config_list = json_tricks.load(f)
        return model, masks, config_list

    def referenced_paths(self) -> List[str]:
        """
        Return the list of paths whose references need to be counted in this task.
        """
        return [self.model_path, self.masks_path, self.config_list_path]

    def clean_up(self):
        """
        Subtract 1 from the counter of each referenced file path. If a counter reaches 0, delete the file.
        """
        if not self._cleaned:
            for ref in self.referenced_paths():
                self._reference_counter[ref] -= 1
                if self._reference_counter[ref] <= 0:
                    os.remove(ref)
                    if self._reference_counter[ref] < 0:
                        _logger.warning('Reference counter error, the number of %s is %d',
                                        ref, self._reference_counter[ref])
            self._cleaned = True
        else:
            _logger.warning('Already cleaned up task %d', self.task_id)


class TaskResult:
    def __init__(self, task_id: int, compact_model: Module, compact_model_masks: Dict[str, Dict[str, Tensor]],
                 pruner_generated_masks: Dict[str, Dict[str, Tensor]], score: Optional[float]) -> None:
        """
        Parameters
        ----------
        task_id
            The unique id of the task.
        compact_model
            The unwrapped compact pytorch model after pruning. If the compact model has been sped up during the
            pruning process, it will have a smaller structure compared with the model before pruning.
            If it has not been sped up, it will have the same structure as the model before pruning.
        compact_model_masks
            The masks on the compact model. If the compact model has been sped up during the pruning process,
            `compact_model_masks` is always an empty dict. If it has not been sped up,
            `compact_model_masks` is the same as `pruner_generated_masks`.
        pruner_generated_masks
            The masks that can be applied on the model before pruning. This is always the output of `pruner.compress()`.
            TODO: If the compact model has been sped up, the auto-inferred masks may also be needed.
        score
            The score of the pruning effect, i.e., the accuracy or latency after pruning.
        """
        self.task_id = task_id
        self.compact_model = compact_model
        self.compact_model_masks = compact_model_masks
        self.pruner_generated_masks = pruner_generated_masks
        self.score = score


class BasePruningScheduler:
    def generate_task(self) -> Optional[Task]:
        """
        Returns
        -------
        Optional[Task]
            Return the next pruning task.
        """
        raise NotImplementedError()

    def record_task_result(self, task_result: TaskResult):
        """
        Parameters
        ----------
        task_result
            The result of the task.
        """
        raise NotImplementedError()

    def pruning_one_step(self, task: Task) -> TaskResult:
        """
        Prune the model defined in the task.

        Parameters
        ----------
        task
            The pruning task in this step.

        Returns
        -------
        TaskResult
            Return the result of the task in this step.
        """
        raise NotImplementedError()

    def get_best_result(self) -> Tuple[int, Module, Dict[str, Dict[str, Tensor]], float, List[Dict]]:
        """
        Returns
        -------
        Tuple[int, Module, Dict[str, Dict[str, Tensor]], float, List[Dict]]
            Return the task result that has the best performance,
            including task id, the compact model, the masks on the compact model, score and config list used in this task.
        """
        raise NotImplementedError()

    def compress(self):
        """
        The pruning schedule main loop.
        """
        task = self.generate_task()

        while task is not None:
            task_result = self.pruning_one_step(task)
            self.record_task_result(task_result)
            del task_result
            gc.collect()
            task = self.generate_task()
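Several `Task` objects can reference the same files on disk, which is why `Task` keeps a class-level reference counter. A sketch of the intended semantics; the paths are placeholders and must point at real files for `os.remove` to succeed:

    t0 = Task(0, 'model.pth', 'masks.pth', 'config_list.json')
    t1 = Task(1, 'model.pth', 'masks.pth', 'config_list.json')   # same files, counters now at 2
    t0.clean_up()   # counters drop to 1, files are kept
    t1.clean_up()   # counters reach 0, files are deleted
    t1.clean_up()   # no-op, only logs that task 1 is already cleaned up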
nni/algorithms/compression/v2/pytorch/pruning/basic_pruner.py

@@ -72,7 +72,7 @@ INTERNAL_SCHEMA = {
 }

-class OneShotPruner(Pruner):
+class BasicPruner(Pruner):
     def __init__(self, model: Module, config_list: List[Dict]):
         self.data_collector: DataCollector = None
         self.metrics_calculator: MetricsCalculator = None
@@ -120,7 +120,7 @@ class OneShotPruner(Pruner):
         return self.bound_model, masks

-class LevelPruner(OneShotPruner):
+class LevelPruner(BasicPruner):
     def __init__(self, model: Module, config_list: List[Dict]):
         """
         Parameters
@@ -154,7 +154,7 @@ class LevelPruner(OneShotPruner):
         self.sparsity_allocator = NormalSparsityAllocator(self)

-class NormPruner(OneShotPruner):
+class NormPruner(BasicPruner):
     def __init__(self, model: Module, config_list: List[Dict], p: int,
                  mode: str = 'normal', dummy_input: Optional[Tensor] = None):
         """
@@ -275,7 +275,7 @@ class L2NormPruner(NormPruner):
         super().__init__(model, config_list, 2, mode, dummy_input)

-class FPGMPruner(OneShotPruner):
+class FPGMPruner(BasicPruner):
     def __init__(self, model: Module, config_list: List[Dict],
                  mode: str = 'normal', dummy_input: Optional[Tensor] = None):
         """
@@ -331,7 +331,7 @@ class FPGMPruner(OneShotPruner):
         raise NotImplementedError('Only support mode `normal` and `dependency_aware`')

-class SlimPruner(OneShotPruner):
+class SlimPruner(BasicPruner):
     def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
                  optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor],
                  training_epochs: int, scale: float = 0.0001, mode='global'):
@@ -427,7 +427,7 @@ class SlimPruner(OneShotPruner):
         raise NotImplementedError('Only support mode `normal` and `global`')

-class ActivationPruner(OneShotPruner):
+class ActivationPruner(BasicPruner):
     def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
                  optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor],
                  training_batches: int, activation: str = 'relu',
                  mode: str = 'normal', dummy_input: Optional[Tensor] = None):
@@ -544,7 +544,7 @@ class ActivationMeanRankPruner(ActivationPruner):
         return MeanRankMetricsCalculator(dim=1)

-class TaylorFOWeightPruner(OneShotPruner):
+class TaylorFOWeightPruner(BasicPruner):
     def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
                  optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor], training_batches: int,
                  mode: str = 'normal', dummy_input: Optional[Tensor] = None):
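The rename from OneShotPruner to BasicPruner is mechanical, but the hunks also restate the constructor signatures. A construction sketch under the `TaylorFOWeightPruner` signature shown above; `model`, `trainer`, `optimizer` and `criterion` are user-supplied placeholders:

    # Hypothetical wiring; trainer must match Callable[[Module, Optimizer, Callable], None].
    config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.5}]
    pruner = TaylorFOWeightPruner(model, config_list, trainer, optimizer, criterion,
                                  training_batches=20)   # mode='normal', dummy_input=None by default
    masked_model, masks = pruner.compress()              # returns (bound_model, masks)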
nni/algorithms/compression/v2/pytorch/pruning/basic_scheduler.py (new file, mode 100644)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from copy import deepcopy
from typing import Dict, List, Tuple, Callable, Optional

from torch import Tensor
from torch.nn import Module

from nni.algorithms.compression.v2.pytorch.base import Pruner, BasePruningScheduler, Task, TaskResult
from nni.compression.pytorch.speedup import ModelSpeedup

from .tools import TaskGenerator


class PruningScheduler(BasePruningScheduler):
    def __init__(self, pruner: Pruner, task_generator: TaskGenerator, finetuner: Callable[[Module], None] = None,
                 speed_up: bool = False, dummy_input: Tensor = None,
                 evaluator: Optional[Callable[[Module], float]] = None):
        """
        Parameters
        ----------
        pruner
            The pruner used in the pruning scheduler.
            The scheduler will use `Pruner.reset(model, config_list)` to reset it in each iteration.
        task_generator
            Used to generate the task for each iteration.
        finetuner
            The finetuner handles all finetuning logic; it takes a pytorch module as input.
        speed_up
            If set to True, speed up the model in each iteration.
        dummy_input
            If `speed_up` is True, `dummy_input` is required to trace the model during speedup.
        evaluator
            Evaluate the pruned model and give a score.
            If evaluator is None, the best result refers to the latest result.
        """
        self.pruner = pruner
        self.task_generator = task_generator
        self.finetuner = finetuner
        self.speed_up = speed_up
        self.dummy_input = dummy_input
        self.evaluator = evaluator

    def generate_task(self) -> Optional[Task]:
        return self.task_generator.next()

    def record_task_result(self, task_result: TaskResult):
        self.task_generator.receive_task_result(task_result)

    def pruning_one_step(self, task: Task) -> TaskResult:
        model, masks, config_list = task.load_data()

        # pruning model
        self.pruner.reset(model, config_list)
        self.pruner.load_masks(masks)
        compact_model, pruner_generated_masks = self.pruner.compress()
        compact_model_masks = deepcopy(pruner_generated_masks)

        # show the pruning effect
        self.pruner.show_pruned_weights()
        self.pruner._unwrap_model()

        # speed up
        if self.speed_up:
            ModelSpeedup(compact_model, self.dummy_input, pruner_generated_masks).speedup_model()
            compact_model_masks = {}

        # finetune
        if self.finetuner is not None:
            if self.speed_up:
                self.finetuner(compact_model)
            else:
                self.pruner._wrap_model()
                self.finetuner(compact_model)
                self.pruner._unwrap_model()

        # evaluate
        score = self.evaluator(compact_model) if self.evaluator is not None else None

        # clear model references
        self.pruner.clear_model_references()

        return TaskResult(task.task_id, compact_model, compact_model_masks, pruner_generated_masks, score)

    def get_best_result(self) -> Optional[Tuple[int, Module, Dict[str, Dict[str, Tensor]], float, List[Dict]]]:
        return self.task_generator.get_best_result()
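Putting the pieces together, an end-to-end sketch of iterative (AGP) pruning with the new scheduler. This assumes an `L1NormPruner` sibling of the `L2NormPruner` shown in basic_pruner.py, and user-supplied `model`, `finetune` and `evaluate`:

    import torch
    from nni.algorithms.compression.v2.pytorch.pruning.basic_pruner import L1NormPruner
    from nni.algorithms.compression.v2.pytorch.pruning.basic_scheduler import PruningScheduler
    from nni.algorithms.compression.v2.pytorch.pruning.tools import AGPTaskGenerator

    config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.8}]
    pruner = L1NormPruner(model, config_list)
    task_generator = AGPTaskGenerator(10, model, config_list, log_dir='.', keep_intermidiate_result=True)
    scheduler = PruningScheduler(pruner, task_generator, finetuner=finetune, speed_up=True,
                                 dummy_input=torch.rand(8, 3, 224, 224), evaluator=evaluate)

    scheduler.compress()   # runs the BasePruningScheduler main loop until no task is generated
    task_id, compact_model, masks, score, config = scheduler.get_best_result()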
nni/algorithms/compression/v2/pytorch/pruning/tools/__init__.py

@@ -2,7 +2,8 @@ from .base import (
     HookCollectorInfo,
     DataCollector,
     MetricsCalculator,
-    SparsityAllocator
+    SparsityAllocator,
+    TaskGenerator
 )
 from .data_collector import (
     WeightDataCollector,
@@ -21,3 +22,7 @@ from .sparsity_allocator import (
     GlobalSparsityAllocator,
     Conv2dDependencyAwareAllocator
 )
+from .task_generator import (
+    AGPTaskGenerator,
+    LinearTaskGenerator
+)
nni/algorithms/compression/v2/pytorch/pruning/tools/base.py

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+from datetime import datetime
 import logging
+from pathlib import Path
 import types
-from typing import List, Dict, Optional, Callable, Union
+from typing import List, Dict, Tuple, Optional, Callable, Union

+import json_tricks
 import torch
 from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer

-from nni.algorithms.compression.v2.pytorch.base import Compressor, LayerInfo
+from nni.algorithms.compression.v2.pytorch.base import Compressor, LayerInfo, Task, TaskResult

 _logger = logging.getLogger(__name__)

-__all__ = ['DataCollector', 'TrainerBasedDataCollector', 'HookCollectorInfo', 'MetricsCalculator', 'SparsityAllocator']

 class DataCollector:
     """
@@ -371,7 +372,7 @@ class SparsityAllocator:
         Returns
         -------
         Dict[str, Tensor]
-            The key is `weight_mask` or `bias_mask`, value is the final mask.
+            The key is `weight` or `bias`, value is the final mask.
         """
         weight_mask = mask.clone()
@@ -390,7 +391,7 @@ class SparsityAllocator:
         if self.dim is None:
             assert weight_mask.size() == weight_size
-            expand_mask = {'weight_mask': weight_mask}
+            expand_mask = {'weight': weight_mask}
         else:
             # expand mask to weight size with dim
             assert len(weight_mask.size()) == len(self.dim)
@@ -400,15 +401,19 @@ class SparsityAllocator:
             [idxs.pop(i) for i in reversed(self.dim)]
             for i in idxs:
                 weight_mask = weight_mask.unsqueeze(i)
-            expand_mask = {'weight_mask': weight_mask.expand(weight_size).clone()}
+            expand_mask = {'weight': weight_mask.expand(weight_size).clone()}

         # NOTE: assume we only mask output, so the mask and bias have a one-to-one correspondence.
         # If we support more kind of masks, this place need refactor.
         if wrapper.bias_mask is not None and weight_mask.size() == wrapper.bias_mask.size():
-            expand_mask['bias_mask'] = weight_mask.clone()
+            expand_mask['bias'] = weight_mask.clone()

         return expand_mask

     def _compress_mask(self, mask: Tensor) -> Tensor:
         """
+        This function will reduce the mask with `self.dim` and `self.block_sparse_size`.
+        e.g., for a mask tensor with size [50, 60, 70], if self.dim is (0, 1) and self.block_sparse_size is [10, 10],
+        then the reduced mask size is [50 / 10, 60 / 10] => [5, 6].

         Parameters
         ----------
         name
@@ -419,7 +424,7 @@ class SparsityAllocator:
         Returns
         -------
         Tensor
-            Reduce the mask with `self.dim` and `self.block_sparse_size`.
+            Reduced mask.
         """
         if self.dim is None or len(mask.size()) == 1:
             mask = mask.clone()
@@ -440,3 +445,131 @@ class SparsityAllocator:
             mask = torch.einsum(ein_expression, mask, torch.ones(self.block_sparse_size).to(mask.device))

         return (mask != 0).type_as(mask)

The remainder of the hunk appends the new TaskGenerator base class:

class TaskGenerator:
    """
    This class is used to generate the config list for the pruner in each iteration.
    """
    def __init__(self, origin_model: Module, origin_masks: Dict[str, Dict[str, Tensor]] = {},
                 origin_config_list: List[Dict] = [], log_dir: str = '.', keep_intermidiate_result: bool = False):
        """
        Parameters
        ----------
        origin_model
            The origin unwrapped pytorch model to be pruned.
        origin_masks
            The pre masks on the origin model. These masks may be user-defined or generated by previous pruning.
        origin_config_list
            The origin config list provided by the user. Note that this config_list directly configures the origin model.
            This means the sparsity provided by the origin_masks should also be recorded in the origin_config_list.
        log_dir
            The log directory used to save the task generator log.
        keep_intermidiate_result
            If True, keep the intermediate result, including the intermediate model and masks of each iteration.
        """
        assert isinstance(origin_model, Module), 'Only support pytorch module.'

        self._log_dir_root = Path(log_dir, datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f')).absolute()
        self._log_dir_root.mkdir(parents=True, exist_ok=True)

        self._keep_intermidiate_result = keep_intermidiate_result
        self._intermidiate_result_dir = Path(self._log_dir_root, 'intermidiate_result')
        self._intermidiate_result_dir.mkdir(parents=True, exist_ok=True)

        # save origin data in {log_dir}/origin
        self._origin_model_path = Path(self._log_dir_root, 'origin', 'model.pth')
        self._origin_masks_path = Path(self._log_dir_root, 'origin', 'masks.pth')
        self._origin_config_list_path = Path(self._log_dir_root, 'origin', 'config_list.json')
        self._save_data('origin', origin_model, origin_masks, origin_config_list)

        self._task_id_candidate = 0
        self._tasks: Dict[int, Task] = {}
        self._pending_tasks: List[Task] = self.init_pending_tasks()

        self._best_score = None
        self._best_task_id = None

        # dump self._tasks into {log_dir}/.tasks
        self._dump_tasks_info()

    def _dump_tasks_info(self):
        tasks = {task_id: task.to_dict() for task_id, task in self._tasks.items()}
        with Path(self._log_dir_root, '.tasks').open('w') as f:
            json_tricks.dump(tasks, f, indent=4)

    def _save_data(self, folder_name: str, model: Module, masks: Dict[str, Dict[str, Tensor]], config_list: List[Dict]):
        Path(self._log_dir_root, folder_name).mkdir(parents=True, exist_ok=True)
        torch.save(model, Path(self._log_dir_root, folder_name, 'model.pth'))
        torch.save(masks, Path(self._log_dir_root, folder_name, 'masks.pth'))
        with Path(self._log_dir_root, folder_name, 'config_list.json').open('w') as f:
            json_tricks.dump(config_list, f, indent=4)

    def update_best_result(self, task_result: TaskResult):
        score = task_result.score
        if score is not None:
            task_id = task_result.task_id
            task = self._tasks[task_id]
            task.score = score
            if self._best_score is None or score > self._best_score:
                self._best_score = score
                self._best_task_id = task_id
                with Path(task.config_list_path).open('r') as fr:
                    best_config_list = json_tricks.load(fr)
                self._save_data('best_result', task_result.compact_model, task_result.compact_model_masks, best_config_list)

    def init_pending_tasks(self) -> List[Task]:
        raise NotImplementedError()

    def generate_tasks(self, task_result: TaskResult) -> List[Task]:
        raise NotImplementedError()

    def receive_task_result(self, task_result: TaskResult):
        """
        Parameters
        ----------
        task_result
            The result of the task.
        """
        task_id = task_result.task_id
        assert task_id in self._tasks, 'Task {} does not exist.'.format(task_id)
        self.update_best_result(task_result)

        self._tasks[task_id].status = 'Finished'
        self._dump_tasks_info()

        self._pending_tasks.extend(self.generate_tasks(task_result))
        self._dump_tasks_info()

        if not self._keep_intermidiate_result:
            self._tasks[task_id].clean_up()

    def next(self) -> Optional[Task]:
        """
        Returns
        -------
        Optional[Task]
            Return the next task from pending tasks.
        """
        if len(self._pending_tasks) == 0:
            return None
        else:
            task = self._pending_tasks.pop(0)
            task.status = 'Running'
            self._dump_tasks_info()
            return task

    def get_best_result(self) -> Optional[Tuple[int, Module, Dict[str, Dict[str, Tensor]], float, List[Dict]]]:
        """
        Returns
        -------
        Optional[Tuple[int, Module, Dict[str, Dict[str, Tensor]], float, List[Dict]]]
            If self._best_task_id is not None,
            return best task id, best compact model, masks on the compact model, score, and the config list used in this task.
        """
        if self._best_task_id is not None:
            compact_model = torch.load(Path(self._log_dir_root, 'best_result', 'best_model.pth'))
            compact_model_masks = torch.load(Path(self._log_dir_root, 'best_result', 'best_masks.pth'))
            with Path(self._log_dir_root, 'best_result', 'best_config_list.json').open('r') as f:
                config_list = json_tricks.load(f)
            return self._best_task_id, compact_model, compact_model_masks, self._best_score, config_list
        return None
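Concrete generators only need to implement the two hooks that raise `NotImplementedError` above; logging, persistence, and best-result tracking are inherited. A hypothetical minimal subclass that schedules exactly one task over the origin data:

    class OneShotTaskGenerator(TaskGenerator):
        # Hypothetical: prune once on the origin model/masks/config, then stop.
        def init_pending_tasks(self) -> List[Task]:
            task = Task(self._task_id_candidate, self._origin_model_path,
                        self._origin_masks_path, self._origin_config_list_path)
            self._tasks[task.task_id] = task
            self._task_id_candidate += 1
            return [task]

        def generate_tasks(self, task_result: TaskResult) -> List[Task]:
            return []   # nothing follows the first result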
nni/algorithms/compression/v2/pytorch/pruning/tools/sparsity_allocator.py

@@ -27,8 +27,9 @@ class NormalSparsityAllocator(SparsityAllocator):
             metric = metrics[name] * self._compress_mask(wrapper.weight_mask)
             prune_num = int(sparsity_rate * metric.numel())
             if prune_num == 0:
-                continue
-            threshold = torch.topk(metric.view(-1), prune_num, largest=False)[0].max()
+                threshold = metric.min() - 1
+            else:
+                threshold = torch.topk(metric.view(-1), prune_num, largest=False)[0].max()
             mask = torch.gt(metric, threshold).type_as(metric)
             masks[name] = self._expand_mask(name, mask)
         return masks
@@ -65,19 +66,22 @@ class GlobalSparsityAllocator(SparsityAllocator):
             wrapper = self.pruner.get_modules_wrapper()[name]
             metric = metric * self._compress_mask(wrapper.weight_mask)
             layer_weight_num = wrapper.module.weight.data.numel()
-            total_weight_num += layer_weight_num
-            expend_times = int(layer_weight_num / metric.numel())
             retention_ratio = 1 - max_sparsity_per_layer.get(name, 1)
             retention_numel = math.ceil(retention_ratio * layer_weight_num)
             removed_metric_num = math.ceil(retention_numel / (wrapper.weight_mask.numel() / metric.numel()))
             stay_metric_num = metric.numel() - removed_metric_num
+            if stay_metric_num <= 0:
+                sub_thresholds[name] = metric.min().item() - 1
+                continue
             # Remove the weight parts that must be left
             stay_metric = torch.topk(metric.view(-1), stay_metric_num, largest=False)[0]
             sub_thresholds[name] = stay_metric.max()
+            expend_times = int(layer_weight_num / metric.numel())
             if expend_times > 1:
                 stay_metric = stay_metric.expand(stay_metric_num, int(layer_weight_num / metric.numel())).view(-1)
             metric_list.append(stay_metric)
+            total_weight_num += layer_weight_num
         total_prune_num = int(total_sparsity * total_weight_num)
         if total_prune_num == 0:
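In `NormalSparsityAllocator`, the old `continue` on `prune_num == 0` skipped the layer entirely, so it received no mask at all; the new branch instead sets a threshold strictly below every metric value, giving the layer an explicit all-ones (keep-everything) mask. A quick check of the arithmetic:

    import torch

    metric = torch.tensor([0.2, 0.5, 0.9])
    threshold = metric.min() - 1                        # -0.8, below every entry
    mask = torch.gt(metric, threshold).type_as(metric)
    print(mask)                                         # tensor([1., 1., 1.]) -> nothing pruned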
nni/algorithms/compression/v2/pytorch/pruning/tools/task_generator.py (new file, mode 100644)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from copy import deepcopy
import logging
from pathlib import Path
from typing import Dict, List

import json_tricks
from torch import Tensor
import torch
from torch.nn import Module

from nni.algorithms.compression.v2.pytorch.base import Task, TaskResult
from nni.algorithms.compression.v2.pytorch.utils.pruning import config_list_canonical, compute_sparsity

from .base import TaskGenerator

_logger = logging.getLogger(__name__)


class FunctionBasedTaskGenerator(TaskGenerator):
    def __init__(self, total_iteration: int, origin_model: Module, origin_config_list: List[Dict],
                 origin_masks: Dict[str, Dict[str, Tensor]] = {}, log_dir: str = '.',
                 keep_intermidiate_result: bool = False):
        self.current_iteration = 0
        self.target_sparsity = config_list_canonical(origin_model, origin_config_list)
        self.total_iteration = total_iteration

        super().__init__(origin_model, origin_config_list=self.target_sparsity, origin_masks=origin_masks,
                         log_dir=log_dir, keep_intermidiate_result=keep_intermidiate_result)

    def init_pending_tasks(self) -> List[Task]:
        origin_model = torch.load(self._origin_model_path)
        origin_masks = torch.load(self._origin_masks_path)

        task_result = TaskResult('origin', origin_model, origin_masks, origin_masks, None)
        return self.generate_tasks(task_result)

    def generate_tasks(self, task_result: TaskResult) -> List[Task]:
        compact_model = task_result.compact_model
        compact_model_masks = task_result.compact_model_masks

        # save intermidiate result
        model_path = Path(self._intermidiate_result_dir, '{}_compact_model.pth'.format(task_result.task_id))
        masks_path = Path(self._intermidiate_result_dir, '{}_compact_model_masks.pth'.format(task_result.task_id))
        torch.save(compact_model, model_path)
        torch.save(compact_model_masks, masks_path)

        # get current2origin_sparsity and compact2origin_sparsity
        origin_model = torch.load(self._origin_model_path)
        current2origin_sparsity, compact2origin_sparsity, _ = compute_sparsity(origin_model, compact_model,
                                                                               compact_model_masks, self.target_sparsity)
        _logger.info('\nTask %s total real sparsity compared with original model is:\n%s',
                     str(task_result.task_id), json_tricks.dumps(current2origin_sparsity, indent=4))
        if task_result.task_id != 'origin':
            self._tasks[task_result.task_id].state['current2origin_sparsity'] = current2origin_sparsity

        # if total_iteration is reached, no more tasks will be generated
        if self.current_iteration >= self.total_iteration:
            return []

        task_id = self._task_id_candidate
        new_config_list = self.generate_config_list(self.target_sparsity, self.current_iteration, compact2origin_sparsity)
        config_list_path = Path(self._intermidiate_result_dir, '{}_config_list.json'.format(task_id))
        with Path(config_list_path).open('w') as f:
            json_tricks.dump(new_config_list, f, indent=4)

        task = Task(task_id, model_path, masks_path, config_list_path)
        self._tasks[task_id] = task

        self._task_id_candidate += 1
        self.current_iteration += 1

        return [task]

    def generate_config_list(self, target_sparsity: List[Dict], iteration: int, compact2origin_sparsity: List[Dict]) -> List[Dict]:
        raise NotImplementedError()


class AGPTaskGenerator(FunctionBasedTaskGenerator):
    def generate_config_list(self, target_sparsity: List[Dict], iteration: int, model_based_sparsity: List[Dict]) -> List[Dict]:
        config_list = []
        for target, mo in zip(target_sparsity, model_based_sparsity):
            ori_sparsity = (1 - (1 - iteration / self.total_iteration) ** 3) * target['total_sparsity']
            sparsity = max(0.0, (ori_sparsity - mo['total_sparsity']) / (1 - mo['total_sparsity']))
            assert 0 <= sparsity <= 1, 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
            config_list.append(deepcopy(target))
            config_list[-1]['total_sparsity'] = sparsity
        return config_list


class LinearTaskGenerator(FunctionBasedTaskGenerator):
    def generate_config_list(self, target_sparsity: List[Dict], iteration: int, model_based_sparsity: List[Dict]) -> List[Dict]:
        config_list = []
        for target, mo in zip(target_sparsity, model_based_sparsity):
            ori_sparsity = iteration / self.total_iteration * target['total_sparsity']
            sparsity = max(0.0, (ori_sparsity - mo['total_sparsity']) / (1 - mo['total_sparsity']))
            assert 0 <= sparsity <= 1, 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
            config_list.append(deepcopy(target))
            config_list[-1]['total_sparsity'] = sparsity
        return config_list
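The two generators differ only in their schedule: AGP ramps the target sparsity cubically, Linear ramps it proportionally, and both then rescale the step target by the sparsity the compact model has already reached, so the config handed to the pruner is expressed over the remaining weights. A worked example of the arithmetic:

    # AGP schedule for total_iteration=10 and a final target of 0.8.
    total_iteration, target = 10, 0.8
    agp = [(1 - (1 - i / total_iteration) ** 3) * target for i in range(total_iteration + 1)]
    # i=0 -> 0.000, i=1 -> 0.217, i=2 -> 0.390, i=5 -> 0.700, i=10 -> 0.800

    # If the compact model is already 50% sparse and the step target is 0.7,
    # the pruner is asked for 40% of the *remaining* weights:
    sparsity = max(0.0, (0.7 - 0.5) / (1 - 0.5))   # 0.4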
nni/algorithms/compression/v2/pytorch/utils/pruning.py

@@ -2,8 +2,10 @@
 # Licensed under the MIT license.

 from copy import deepcopy
-from typing import Dict, List
+from typing import Dict, List, Tuple

+import torch
+from torch import Tensor
 from torch.nn import Module
@@ -82,3 +84,113 @@ def dedupe_config_list(config_list: List[Dict]) -> List[Dict]:
     for idx in sorted(exclude_idxes, reverse=True):
         config_list.pop(idx)
     return config_list

The remainder of the hunk adds four sparsity-accounting helpers:

def compute_sparsity_compact2origin(origin_model: Module, compact_model: Module, config_list: List[Dict]) -> List[Dict]:
    """
    Compare the origin model with the compact model and return the sparsity of each group mentioned in the config list.
    A group means all layers mentioned in one config.
    e.g., if a linear named 'linear1' has weight size [100, 100] in the origin model, but the layer with the same name
    has weight size [100, 50] in the compact model, then this function will return
    [{'op_names': 'linear1', 'total_sparsity': 0.5}].
    """
    compact2origin_sparsity = []
    for config in config_list:
        left_weight_num = 0
        total_weight_num = 0
        for module_name, module in origin_model.named_modules():
            module_type = type(module).__name__
            if 'op_types' in config and module_type not in config['op_types']:
                continue
            if 'op_names' in config and module_name not in config['op_names']:
                continue
            total_weight_num += module.weight.data.numel()
        for module_name, module in compact_model.named_modules():
            module_type = type(module).__name__
            if 'op_types' in config and module_type not in config['op_types']:
                continue
            if 'op_names' in config and module_name not in config['op_names']:
                continue
            left_weight_num += module.weight.data.numel()
        compact2origin_sparsity.append(deepcopy(config))
        compact2origin_sparsity[-1]['total_sparsity'] = 1 - left_weight_num / total_weight_num
    return compact2origin_sparsity


def compute_sparsity_mask2compact(compact_model: Module, compact_model_masks: Dict[str, Dict[str, Tensor]], config_list: List[Dict]):
    """
    Apply the masks on the compact model and return the sparsity of each group mentioned in the config list.
    A group means all layers mentioned in one config.
    This function counts all zero elements of the masks in one group,
    then divides by the number of weight elements in this group to compute sparsity.
    """
    mask2compact_sparsity = []
    for config in config_list:
        left_weight_num = 0
        total_weight_num = 0
        for module_name, module in compact_model.named_modules():
            module_type = type(module).__name__
            if 'op_types' in config and module_type not in config['op_types']:
                continue
            if 'op_names' in config and module_name not in config['op_names']:
                continue
            module_weight_num = module.weight.data.numel()
            total_weight_num += module_weight_num
            if module_name in compact_model_masks:
                weight_mask = compact_model_masks[module_name]['weight']
                left_weight_num += len(torch.nonzero(weight_mask, as_tuple=False))
            else:
                left_weight_num += module_weight_num
        mask2compact_sparsity.append(deepcopy(config))
        mask2compact_sparsity[-1]['total_sparsity'] = 1 - left_weight_num / total_weight_num
    return mask2compact_sparsity


def compute_sparsity(origin_model: Module, compact_model: Module, compact_model_masks: Dict[str, Dict[str, Tensor]],
                     config_list: List[Dict]) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """
    This function computes how much the origin model has been compressed in the current state.
    The current state means `compact_model` + `compact_model_masks`
    (i.e., `compact_model_masks` applied on `compact_model`).
    The compact model is the origin model after pruning,
    and it may have a different structure from origin_model because of speedup.

    Returns
    -------
    Tuple[List[Dict], List[Dict], List[Dict]]
        (current2origin_sparsity, compact2origin_sparsity, mask2compact_sparsity).
        current2origin_sparsity is how much the origin model has been compressed in the current state.
        compact2origin_sparsity is the sparsity obtained by comparing the structures of the origin model and the compact model.
        mask2compact_sparsity is the sparsity computed by counting the zero values in the mask.
    """
    compact2origin_sparsity = compute_sparsity_compact2origin(origin_model, compact_model, config_list)
    mask2compact_sparsity = compute_sparsity_mask2compact(compact_model, compact_model_masks, config_list)
    assert len(compact2origin_sparsity) == len(mask2compact_sparsity), 'Length mismatch.'
    current2origin_sparsity = []
    for c2o_sparsity, m2c_sparsity, config in zip(compact2origin_sparsity, mask2compact_sparsity, config_list):
        current2origin_sparsity.append(deepcopy(config))
        current2origin_sparsity[-1]['total_sparsity'] = \
            1 - (1 - c2o_sparsity['total_sparsity']) * (1 - m2c_sparsity['total_sparsity'])
    return current2origin_sparsity, compact2origin_sparsity, mask2compact_sparsity


def get_model_weights_numel(model: Module, config_list: List[Dict], masks: Dict[str, Dict[str, Tensor]] = {}) -> Dict:
    """
    Count the number of weight elements of each layer mentioned in the config_list.
    If masks is not empty, the masked weights are not counted.
    """
    model_weights_numel = {}
    masked_rate = {}
    for config in config_list:
        for module_name, module in model.named_modules():
            module_type = type(module).__name__
            if 'op_types' in config and module_type not in config['op_types']:
                continue
            if 'op_names' in config and module_name not in config['op_names']:
                continue
            if module_name in masks and isinstance(masks[module_name]['weight'], Tensor):
                weight_mask = masks[module_name]['weight']
                masked_rate[module_name] = 1 - (weight_mask.sum().item() / weight_mask.numel())
                model_weights_numel[module_name] = round(weight_mask.sum().item())
            else:
                model_weights_numel[module_name] = module.weight.data.numel()
    return model_weights_numel, masked_rate
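`compute_sparsity` combines the two partial views multiplicatively: if speedup already removed a fraction c of a group's weights (compact2origin) and the masks zero a fraction m of what remains (mask2compact), the overall sparsity relative to the origin model is 1 - (1 - c) * (1 - m). For instance:

    c2o, m2c = 0.5, 0.5                     # 50% removed structurally, 50% of the rest masked
    current = 1 - (1 - c2o) * (1 - m2c)     # 0.75 of the original weights are gone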