Commit 8b61e774 (unverified) in OpenDAS / nni
Authored Sep 22, 2021 by J-shang; committed by GitHub on Sep 22, 2021

[Model Compression] admm pruner (#4116)

Parent: cdbc0b94
Showing 1 changed file with 119 additions and 1 deletion.
nni/algorithms/compression/v2/pytorch/pruning/basic_pruner.py  (+119, -1)
@@ -135,7 +135,6 @@ class LevelPruner(BasicPruner):
             - op_names : Operation names to prune.
             - exclude : Set True then the layers setting by op_types and op_names will be excluded from pruning.
         """
-        self.mode = 'normal'
         super().__init__(model, config_list)
 
     def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
@@ -655,3 +654,122 @@ class TaylorFOWeightPruner(BasicPruner):
             self.sparsity_allocator = Conv2dDependencyAwareAllocator(self, 0, self.dummy_input)
         else:
             raise NotImplementedError('Only support mode `normal`, `global` and `dependency_aware`')
+
+
+class ADMMPruner(BasicPruner):
+    """
+    ADMM (Alternating Direction Method of Multipliers) Pruner is based on a mathematical optimization technique.
+    The metric used in this pruner is the absolute value of the weight.
+    In each iteration, weights with small magnitudes are set to zero.
+    Only in the final iteration is the mask generated and applied to the model wrapper.
+
+    The original paper is available at: https://arxiv.org/abs/1804.03294.
+    """
+
+    def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
+                 optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor],
+                 iterations: int, training_epochs: int):
"""
Parameters
----------
model
Model to be pruned.
config_list
Supported keys:
- sparsity : This is to specify the sparsity for each layer in this config to be compressed.
- sparsity_per_layer : Equals to sparsity.
- rho : Penalty parameters in ADMM algorithm.
- op_types : Operation types to prune.
- op_names : Operation names to prune.
- exclude : Set True then the layers setting by op_types and op_names will be excluded from pruning.
trainer
A callable function used to train model or just inference. Take model, optimizer, criterion as input.
The model will be trained or inferenced `training_epochs` epochs.
Example::
def trainer(model: Module, optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor]):
training = model.training
model.train(mode=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
# If you don't want to update the model, you can skip `optimizer.step()`, and set train mode False.
optimizer.step()
model.train(mode=training)
optimizer
The optimizer instance used in trainer. Note that this optimizer might be patched during collect data,
so do not use this optimizer in other places.
criterion
The criterion function used in trainer. Take model output and target value as input, and return the loss.
iterations
The total iteration number in admm pruning algorithm.
training_epochs
The epoch number for training model in each iteration.
"""
+        self.trainer = trainer
+        self.optimizer = optimizer
+        self.criterion = criterion
+        self.iterations = iterations
+        self.training_epochs = training_epochs
+        super().__init__(model, config_list)
+
+        self.Z = {name: wrapper.module.weight.data.clone().detach() for name, wrapper in self.get_modules_wrapper().items()}
+        self.U = {name: torch.zeros_like(z).to(z.device) for name, z in self.Z.items()}
+
+    def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
+        schema_list = [deepcopy(NORMAL_SCHEMA), deepcopy(INTERNAL_SCHEMA)]
+        for schema in schema_list:
+            schema.update({SchemaOptional('rho'): And(float, lambda n: n > 0)})
+        schema_list.append(deepcopy(EXCLUDE_SCHEMA))
+        schema = CompressorSchema(schema_list, model, _logger)
+        schema.validate(config_list)
+
+    def criterion_patch(self, origin_criterion: Callable[[Tensor, Tensor], Tensor]):
+        def patched_criterion(output: Tensor, target: Tensor):
+            penalty = torch.tensor(0.0).to(output.device)
+            for name, wrapper in self.get_modules_wrapper().items():
+                rho = wrapper.config['rho']
+                penalty += (rho / 2) * torch.sqrt(torch.norm(wrapper.module.weight - self.Z[name] + self.U[name]))
+            return origin_criterion(output, target) + penalty
+        return patched_criterion
+
+    def reset_tools(self):
+        if self.data_collector is None:
+            self.data_collector = WeightTrainerBasedDataCollector(self, self.trainer, self.optimizer, self.criterion,
+                                                                  self.training_epochs, criterion_patch=self.criterion_patch)
+        else:
+            self.data_collector.reset()
+        if self.metrics_calculator is None:
+            self.metrics_calculator = NormMetricsCalculator()
+        if self.sparsity_allocator is None:
+            self.sparsity_allocator = NormalSparsityAllocator(self)
+
+    def compress(self) -> Tuple[Module, Dict]:
+        """
+        Returns
+        -------
+        Tuple[Module, Dict]
+            Return the wrapped model and mask.
+        """
+        for i in range(self.iterations):
+            _logger.info('======= ADMM Iteration %d Start =======', i)
+            data = self.data_collector.collect()
+
+            for name, weight in data.items():
+                self.Z[name] = weight + self.U[name]
+            metrics = self.metrics_calculator.calculate_metrics(self.Z)
+            masks = self.sparsity_allocator.generate_sparsity(metrics)
+            for name, mask in masks.items():
+                self.Z[name] = self.Z[name].mul(mask['weight'])
+                self.U[name] = self.U[name] + data[name] - self.Z[name]
+
+        metrics = self.metrics_calculator.calculate_metrics(data)
+        masks = self.sparsity_allocator.generate_sparsity(metrics)
+        self.load_masks(masks)
+        return self.bound_model, masks
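For readers following the compress() loop above, the per-iteration updates it performs can be restated compactly. This restatement is not part of the patch: W is the weight returned by the data collector after training, Z the auxiliary variable, U the scaled dual variable, and Pi the magnitude-based masking projection produced by NormMetricsCalculator and NormalSparsityAllocator.

    W^{k+1} \approx \arg\min_{W} \; \mathcal{L}(W) + \sum_{i} \frac{\rho_i}{2} \sqrt{\lVert W_i - Z_i^{k} + U_i^{k} \rVert}
    Z^{k+1} = \Pi\left(W^{k+1} + U^{k}\right)
    U^{k+1} = U^{k} + W^{k+1} - Z^{k+1}

The first step is approximated by training for `training_epochs` epochs with the loss returned by criterion_patch; note that the penalty is implemented literally as (rho / 2) * sqrt(||W - Z + U||), as written in the diff, rather than solved exactly.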
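A rough usage sketch, not taken from the patch or the NNI docs: the ADMMPruner arguments follow the constructor signature added above, the import targets the file touched by this commit, and MyModel, get_train_loader and the hyperparameter values are hypothetical placeholders.

    import torch.nn.functional as F
    from typing import Callable
    from torch import Tensor
    from torch.nn import Module
    from torch.optim import SGD, Optimizer
    from nni.algorithms.compression.v2.pytorch.pruning.basic_pruner import ADMMPruner

    model = MyModel()                    # hypothetical model to be pruned
    train_loader = get_train_loader()    # hypothetical DataLoader

    def trainer(model: Module, optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor]):
        # One epoch of ordinary training; per the docstring, the pruner runs this for
        # `training_epochs` epochs in each ADMM iteration, passing a criterion already
        # patched with the ADMM penalty (see criterion_patch above).
        model.train()
        device = next(model.parameters()).device
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()

    # `rho` is optional in the schema but scales the per-layer ADMM penalty, so it is set explicitly here.
    config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d'], 'rho': 1e-3}]
    optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)

    pruner = ADMMPruner(model, config_list, trainer, optimizer, F.cross_entropy,
                        iterations=5, training_epochs=1)
    pruned_model, masks = pruner.compress()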