OpenDAS / nni · Commit 2566badb

Unverified commit 2566badb, authored Mar 23, 2022 by J-shang; committed via GitHub on Mar 23, 2022.

[Model Compression] Pruning Wrapper Refactor (#4488)

parent 8d5f643c
Showing 9 changed files with 145 additions and 108 deletions (+145, -108):
    nni/algorithms/compression/v2/pytorch/base/compressor.py                   +1   -8
    nni/algorithms/compression/v2/pytorch/base/pruner.py                       +82  -10
    nni/algorithms/compression/v2/pytorch/pruning/basic_pruner.py              +1   -1
    nni/algorithms/compression/v2/pytorch/pruning/movement_pruner.py           +5   -80
    nni/algorithms/compression/v2/pytorch/pruning/tools/base.py                +1   -1
    nni/algorithms/compression/v2/pytorch/pruning/tools/data_collector.py      +2   -2
    nni/algorithms/compression/v2/pytorch/pruning/tools/sparsity_allocator.py  +1   -1
    test/ut/compression/v2/test_pruning_tools_torch.py                         +5   -5
    test/ut/compression/v2/test_pruning_wrapper.py (new file)                  +47  -0
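Taken together, the diffs below move the prunable `weight` (and `bias`) off the wrapped module and onto the wrapper itself as trainable `Parameter`s; the original module keeps only a masked tensor that the wrapper refreshes on every forward pass. Downstream tools accordingly switch from `wrapper.module.weight` to `wrapper.weight` wherever they need the trainable tensor:

    # before this commit: the trainable weight lives on the wrapped module
    weight = wrapper.module.weight   # Parameter
    # after this commit: the wrapper owns the trainable weight, and
    # wrapper.module.weight is a buffer holding the masked copy
    weight = wrapper.weight          # Parameter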
nni/algorithms/compression/v2/pytorch/base/compressor.py

@@ -257,14 +257,7 @@ class Compressor:
         Dict[str, str]
             Return a dict `{original_model_parameter_name: wrapped_model_parameter_name}`
         """
-        if self.is_wrapped:
-            wrapped_param_names = {id(param): name for name, param in self.bound_model.named_parameters()}
-            self._unwrap_model()
-            parameter_name_map = {name: wrapped_param_names[id(param)] for name, param in self.bound_model.named_parameters()}
-            self._wrap_model()
-            return parameter_name_map
-        else:
-            raise Exception('When only the model is wrapped can get the parameter_name_map.')
+        raise NotImplementedError()

     def _wrap_modules(self, layer: LayerInfo, config: Dict):
         """
nni/algorithms/compression/v2/pytorch/base/pruner.py

@@ -6,9 +6,9 @@ from typing import Dict, List, Optional, Tuple
 import torch
 from torch import Tensor
-from torch.nn import Module
+from torch.nn import Module, Parameter

-from .compressor import Compressor, LayerInfo
+from .compressor import Compressor, LayerInfo, _setattr

 _logger = logging.getLogger(__name__)
@@ -27,31 +27,57 @@ class PrunerModuleWrapper(Module):
         The configurations that users specify for compression.
     module_name
         The name of the module to compress, wrapper module shares same name.
-    pruner
-        The pruner used to calculate mask.
     """

-    def __init__(self, module: Module, module_name: str, config: Dict, pruner: Compressor):
+    def __init__(self, module: Module, module_name: str, config: Dict):
         super().__init__()
         # origin layer information
         self.module = module
         self.name = module_name
-        # config and pruner
+        # config information
         self.config = config
-        self.pruner = pruner
+        self.weight = Parameter(torch.empty(self.module.weight.size()))
         # register buffer for mask
         self.register_buffer("weight_mask", torch.ones(self.module.weight.shape))
         if hasattr(self.module, 'bias') and self.module.bias is not None:
             self.register_buffer("bias_mask", torch.ones(self.module.bias.shape))
+            self.bias = Parameter(torch.empty(self.module.bias.size()))
         else:
             self.register_buffer("bias_mask", None)

+    def _weight2buffer(self):
+        """
+        When using this wrapper to inference, call `_weight2buffer()` to make original weight untrainable.
+        The best place to call this function is in `Pruner._wrap_model()`.
+        """
+        self.weight.data = self.module.weight.data
+        delattr(self.module, 'weight')
+        self.module.register_buffer('weight', self.weight.data)
+        if hasattr(self.module, 'bias') and self.module.bias is not None:
+            self.bias.data = self.module.bias.data
+            delattr(self.module, 'bias')
+            self.module.register_buffer('bias', self.bias.data)
+
+    def _weight2parameter(self):
+        """
+        When don't need to record score or need to export the model, call `_weight2parameter()` to make the original weight trainable.
+        The best place to call this function is in `Pruner._unwrap_model()`.
+        """
+        delattr(self.module, 'weight')
+        self.module.weight = Parameter(torch.empty(self.weight.size()))
+        self.module.weight.data = torch.mul(self.weight, self.weight_mask)
+        if hasattr(self.module, 'bias') and self.module.bias is not None:
+            delattr(self.module, 'bias')
+            self.module.bias = Parameter(torch.empty(self.bias.size()))
+            self.module.bias.data = torch.mul(self.bias, self.bias_mask)
+
     def forward(self, *inputs):
         # apply mask to weight, bias
-        self.module.weight.data = self.module.weight.data.mul_(self.weight_mask)
+        self.module.weight = torch.mul(self.weight, self.weight_mask)
         if hasattr(self.module, 'bias') and self.module.bias is not None:
-            self.module.bias.data = self.module.bias.data.mul_(self.bias_mask)
+            self.module.bias = torch.mul(self.bias, self.bias_mask)
         return self.module(*inputs)
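The `_weight2buffer()` / `_weight2parameter()` pair relies on a generic PyTorch mechanism: `delattr` removes a `Parameter` from the module's parameter dict, after which `register_buffer` can reinstall a plain tensor under the same name, so the optimizer and `named_parameters()` stop seeing it. A minimal self-contained sketch of that round trip on a plain `torch.nn.Linear` (names here are illustrative, not NNI API):

    import torch
    from torch.nn import Linear, Parameter

    layer = Linear(4, 2)
    saved = Parameter(layer.weight.data.clone())

    # parameter -> buffer: `weight` disappears from named_parameters()
    delattr(layer, 'weight')
    layer.register_buffer('weight', saved.data)
    assert 'weight' not in dict(layer.named_parameters())
    assert 'weight' in dict(layer.named_buffers())

    # buffer -> parameter: restore a trainable weight
    delattr(layer, 'weight')
    layer.weight = Parameter(saved.data.clone())
    assert 'weight' in dict(layer.named_parameters())

Note also that the new `forward` masks out of place (`torch.mul(self.weight, self.weight_mask)`), so gradients reach the wrapper's own `weight`; the old in-place `mul_` on `.data` permanently zeroed the pruned entries of the stored weight and hid the masking from autograd.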
@@ -75,12 +101,58 @@ class Pruner(Compressor):
         The configuration for generating the mask.
         """
         _logger.debug("Module detected to compress : %s.", layer.name)
-        wrapper = PrunerModuleWrapper(layer.module, layer.name, config, self)
+        wrapper = PrunerModuleWrapper(layer.module, layer.name, config)
         assert hasattr(layer.module, 'weight'), "module %s does not have 'weight' attribute" % layer.name
         # move newly registered buffers to the same device of weight
         wrapper.to(layer.module.weight.device)
         return wrapper
+    # The following `_wrap_model`, `_unwrap_model`, `get_origin2wrapped_parameter_name_map` can merge to `Compressor`,
+    # if quantizer use the similar structure wrapper.
+    def _wrap_model(self):
+        """
+        Wrap all modules that needed to be compressed.
+        Different from the parent function, call `wrapper._weight2buffer()` after replace the origin module to wrapper.
+        """
+        if not self.is_wrapped:
+            for _, wrapper in reversed(self.get_modules_wrapper().items()):
+                _setattr(self.bound_model, wrapper.name, wrapper)
+                wrapper._weight2buffer()
+            self.is_wrapped = True
+
+    def _unwrap_model(self):
+        """
+        Unwrap all modules that needed to be compressed.
+        Different from the parent function, call `wrapper._weight2parameter()` after replace the wrapper to origin module.
+        """
+        if self.is_wrapped:
+            for _, wrapper in self.get_modules_wrapper().items():
+                _setattr(self.bound_model, wrapper.name, wrapper.module)
+                wrapper._weight2parameter()
+            self.is_wrapped = False
+
+    def get_origin2wrapped_parameter_name_map(self) -> Dict[str, str]:
+        """
+        Get the name mapping of parameters from original model to wrapped model.
+
+        Returns
+        -------
+        Dict[str, str]
+            Return a dict `{original_model_parameter_name: wrapped_model_parameter_name}`
+        """
+        if self.is_wrapped:
+            wrapped_param_names = {id(param): name for name, param in self.bound_model.named_parameters()}
+            self._unwrap_model()
+            parameter_name_map = {}
+            for name, param in self.bound_model.named_parameters():
+                # If the parameter under a wrapped module is named `xxx.weight` or `xxx.bias`, its name will not change after wrap.
+                # Any other parameter of a wrapped module, `xxx.param`, will change to `xxx.module.param` after wrap.
+                parameter_name_map[name] = wrapped_param_names[id(param)] if id(param) in wrapped_param_names else name
+            self._wrap_model()
+            return parameter_name_map
+        else:
+            raise Exception('When only the model is wrapped can get the parameter_name_map.')
+
     def load_masks(self, masks: Dict[str, Dict[str, Tensor]]):
         """
         Load an exist masks on the wrapper. You can train the model with an exist masks after load the masks.
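As the comments above state, only parameters exposed on a wrapped module as `weight`/`bias` keep their names; any other parameter of a wrapped module gains a `.module` segment. A hypothetical illustration (layer and parameter names invented for the example):

    # For a model whose layer `conv1` is wrapped and also owns a custom
    # parameter `scale`, while `fc` is left unwrapped, the map would be:
    # {
    #     'conv1.weight': 'conv1.weight',        # wrapper re-exposes `weight` under the same name
    #     'conv1.bias':   'conv1.bias',
    #     'conv1.scale':  'conv1.module.scale',  # other parameters move under `.module`
    #     'fc.weight':    'fc.weight',           # unwrapped layers are unchanged
    #     'fc.bias':      'fc.bias',
    # }
    name_map = pruner.get_origin2wrapped_parameter_name_map()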
nni/algorithms/compression/v2/pytorch/pruning/basic_pruner.py

@@ -999,7 +999,7 @@ class TaylorFOWeightPruner(BasicPruner):
         return (weight_tensor.detach() * grad.detach()).data.pow(2)

     def reset_tools(self):
-        hook_targets = {layer_info.name: layer_info.module.weight for layer_info, _ in self._detect_modules_to_compress()}
+        hook_targets = {name: wrapper.weight for name, wrapper in self.get_modules_wrapper().items()}
         collector_info = HookCollectorInfo(hook_targets, 'tensor', self._collector)
         if self.data_collector is None:
             self.data_collector = SingleHookTrainerBasedDataCollector(self, self.trainer, self.optimizer_helper, self.criterion, ...
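The collector hooks need the tensor that actually receives gradients; after this refactor that is `wrapper.weight`, while `wrapper.module.weight` is a buffer with no `.grad`. The hooks registered for these 'tensor' targets behave like `Tensor.register_hook`; a self-contained sketch of the mechanism (not NNI's internal code):

    import torch

    weight = torch.nn.Parameter(torch.randn(3, 3))
    grads = []

    # the hook fires with d(loss)/d(weight) during backward
    handle = weight.register_hook(lambda grad: grads.append(grad.clone()))

    loss = (weight * 2).sum()
    loss.backward()
    handle.remove()

    assert torch.equal(grads[0], torch.full((3, 3), 2.0))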
nni/algorithms/compression/v2/pytorch/pruning/movement_pruner.py

@@ -10,7 +10,7 @@ from torch import autograd, Tensor
 from torch.nn import Module, Parameter
 from torch.optim import Optimizer, Adam

-from nni.algorithms.compression.v2.pytorch.base.compressor import Compressor, _setattr, LayerInfo
+from nni.algorithms.compression.v2.pytorch.base import PrunerModuleWrapper, LayerInfo
 from nni.algorithms.compression.v2.pytorch.pruning.basic_pruner import BasicPruner, NORMAL_SCHEMA, EXCLUDE_SCHEMA, INTERNAL_SCHEMA
 from nni.algorithms.compression.v2.pytorch.utils import CompressorSchema, OptimizerConstructHelper
 from nni.common.serializer import Traceable

@@ -25,7 +25,7 @@ from .tools import (
 _logger = logging.getLogger(__name__)

-class PrunerScoredModuleWrapper(Module):
+class PrunerScoredModuleWrapper(PrunerModuleWrapper):
     """
     Wrap a module to enable data parallel, forward method customization and buffer registeration.
     Different from `PrunerModuleWrapper`, `PrunerScoredModuleWrapper` will record the gradient.
@@ -38,56 +38,12 @@ class PrunerScoredModuleWrapper(Module):
         The configurations that users specify for compression.
     module_name
         The name of the module to compress, wrapper module shares same name.
-    pruner
-        The pruner used to calculate mask.
     """

-    def __init__(self, module: Module, module_name: str, config: Dict, pruner: Compressor):
-        super().__init__()
-        # origin layer information
-        self.module = module
-        self.name = module_name
-        # config and pruner
-        self.config = config
-        self.pruner = pruner
-
-        self.weight = Parameter(torch.empty(self.module.weight.size()))
+    def __init__(self, module: Module, module_name: str, config: Dict):
+        super().__init__(module, module_name, config)
         self.weight_score = Parameter(torch.empty(self.weight.size()))
         torch.nn.init.constant_(self.weight_score, val=0.0)
-        # register buffer for mask
-        self.register_buffer("weight_mask", torch.ones(self.module.weight.shape))
-        if hasattr(self.module, 'bias') and self.module.bias is not None:
-            self.register_buffer("bias_mask", torch.ones(self.module.bias.shape))
-            self.bias = Parameter(torch.empty(self.module.bias.size()))
-        else:
-            self.register_buffer("bias_mask", None)
-
-    def _weight2buffer(self):
-        """
-        When using this wrapper to inference, call `_weight2buffer()` to make original weight untrainable.
-        The best place to call this function is in `Pruner._wrap_model()`.
-        """
-        self.weight.data = self.module.weight.data
-        delattr(self.module, 'weight')
-        self.module.register_buffer('weight', self.weight.data)
-        if hasattr(self.module, 'bias') and self.module.bias is not None:
-            self.bias.data = self.module.bias.data
-            delattr(self.module, 'bias')
-            self.module.register_buffer('bias', self.bias.data)
-
-    def _weight2parameter(self):
-        """
-        When don't need to record score or need to export the model, call `_weight2parameter()` to make the original weight trainable.
-        The best place to call this function is in `Pruner._unwrap_model()`.
-        """
-        delattr(self.module, 'weight')
-        self.module.weight = Parameter(torch.empty(self.weight.size()))
-        self.module.weight.data = torch.mul(self.weight, self.weight_mask)
-        if hasattr(self.module, 'bias') and self.module.bias is not None:
-            delattr(self.module, 'bias')
-            self.module.bias = Parameter(torch.empty(self.bias.size()))
-            self.module.bias.data = torch.mul(self.bias, self.bias_mask)

     def forward(self, *inputs):
         # apply mask to weight, bias
         self.module.weight = torch.mul(self.weight, _StraightThrough.apply(self.weight_score, self.weight_mask))
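`_StraightThrough` in the context line above is a straight-through estimator: the forward pass emits the hard mask, while the backward pass routes the incoming gradient to `weight_score` unchanged, which is what lets movement pruning train scores through a binarized mask. A minimal self-contained sketch of that pattern (written here for illustration, not NNI's exact class):

    import torch
    from torch import autograd

    class StraightThroughSketch(autograd.Function):
        @staticmethod
        def forward(ctx, score, mask):
            # emit the hard mask; `score` only matters for the backward pass
            return mask

        @staticmethod
        def backward(ctx, grad_output):
            # pass the gradient straight through to `score`; none for `mask`
            return grad_output, None

    score = torch.zeros(4, requires_grad=True)
    mask = torch.tensor([1., 0., 1., 0.])
    out = StraightThroughSketch.apply(score, mask)
    out.sum().backward()
    assert torch.equal(score.grad, torch.ones(4))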
@@ -259,28 +215,6 @@ class MovementPruner(BasicPruner):
         else:
             self.data_collector.reset()

-    def _wrap_model(self):
-        """
-        Wrap all modules that needed to be compressed.
-        Different from the parent function, call `wrapper._weight2buffer()` after replace the origin module to wrapper.
-        """
-        if not self.is_wrapped:
-            for _, wrapper in reversed(self.get_modules_wrapper().items()):
-                _setattr(self.bound_model, wrapper.name, wrapper)
-                wrapper._weight2buffer()
-            self.is_wrapped = True
-
-    def _unwrap_model(self):
-        """
-        Unwrap all modules that needed to be compressed.
-        Different from the parent function, call `wrapper._weight2parameter()` after replace the wrapper to origin module.
-        """
-        if self.is_wrapped:
-            for _, wrapper in self.get_modules_wrapper().items():
-                _setattr(self.bound_model, wrapper.name, wrapper.module)
-                wrapper._weight2parameter()
-            self.is_wrapped = False

     def _wrap_modules(self, layer: LayerInfo, config: Dict):
         """
         Create a wrapper module to replace the original one.
@@ -294,21 +228,12 @@ class MovementPruner(BasicPruner):
         The configuration for generating the mask.
         """
         _logger.debug("Module detected to compress : %s.", layer.name)
-        wrapper = PrunerScoredModuleWrapper(layer.module, layer.name, config, self)
+        wrapper = PrunerScoredModuleWrapper(layer.module, layer.name, config)
         assert hasattr(layer.module, 'weight'), "module %s does not have 'weight' attribute" % layer.name
         # move newly registered buffers to the same device of weight
         wrapper.to(layer.module.weight.device)
         return wrapper

-    def get_origin2wrapped_parameter_name_map(self) -> Dict[str, str]:
-        if self.is_wrapped:
-            self._unwrap_model()
-            parameter_name_map = {name: name for name, _ in self.bound_model.named_parameters()}
-            self._wrap_model()
-            return parameter_name_map
-        else:
-            raise Exception('When only the model is wrapped can get the parameter_name_map.')
-
     def compress(self) -> Tuple[Module, Dict]:
         # sparsity grow from 0
         for _, wrapper in self.get_modules_wrapper().items():
nni/algorithms/compression/v2/pytorch/pruning/tools/base.py

@@ -384,7 +384,7 @@ class SparsityAllocator:
             weight_mask = weight_mask.expand(expand_size).reshape(reshape_size)
         wrapper = self.pruner.get_modules_wrapper()[name]
-        weight_size = wrapper.module.weight.data.size()
+        weight_size = wrapper.weight.data.size()
         if self.dim is None:
             assert weight_mask.size() == weight_size
nni/algorithms/compression/v2/pytorch/pruning/tools/data_collector.py

@@ -24,7 +24,7 @@ class WeightDataCollector(DataCollector):
     def collect(self) -> Dict[str, Tensor]:
         data = {}
         for _, wrapper in self.compressor.get_modules_wrapper().items():
-            data[wrapper.name] = wrapper.module.weight.data
+            data[wrapper.name] = wrapper.weight.data
         return data

@@ -39,7 +39,7 @@ class WeightTrainerBasedDataCollector(TrainerBasedDataCollector):
         data = {}
         for _, wrapper in self.compressor.get_modules_wrapper().items():
-            data[wrapper.name] = wrapper.module.weight.data
+            data[wrapper.name] = wrapper.weight.data
         return data
nni/algorithms/compression/v2/pytorch/pruning/tools/sparsity_allocator.py

@@ -132,7 +132,7 @@ class GlobalSparsityAllocator(SparsityAllocator):
             if self.continuous_mask:
                 metric = metric * self._compress_mask(wrapper.weight_mask)
-            layer_weight_num = wrapper.module.weight.data.numel()
+            layer_weight_num = wrapper.weight.data.numel()
             total_weight_num += layer_weight_num
             expend_times = int(layer_weight_num / metric.numel())
test/ut/compression/v2/test_pruning_tools_torch.py

@@ -83,17 +83,17 @@ class PruningToolsTestCase(unittest.TestCase):
         # Test WeightDataCollector
         data_collector = WeightDataCollector(pruner)
         data = data_collector.collect()
-        assert all(torch.equal(get_module_by_name(model, module_name)[1].module.weight.data, data[module_name]) for module_name in ['conv1', 'conv2'])
+        assert all(torch.equal(get_module_by_name(model, module_name)[1].weight.data, data[module_name]) for module_name in ['conv1', 'conv2'])

         # Test WeightTrainerBasedDataCollector
         def opt_after():
-            model.conv1.module.weight.data = torch.ones(5, 1, 5, 5)
-            model.conv2.module.weight.data = torch.ones(10, 5, 5, 5)
+            model.conv1.weight.data = torch.ones(5, 1, 5, 5)
+            model.conv2.weight.data = torch.ones(10, 5, 5, 5)

         optimizer_helper = OptimizerConstructHelper.from_trace(model, get_optimizer(model))
         data_collector = WeightTrainerBasedDataCollector(pruner, trainer, optimizer_helper, criterion, 1, opt_after_tasks=[opt_after])
         data = data_collector.collect()
-        assert all(torch.equal(get_module_by_name(model, module_name)[1].module.weight.data, data[module_name]) for module_name in ['conv1', 'conv2'])
+        assert all(torch.equal(get_module_by_name(model, module_name)[1].weight.data, data[module_name]) for module_name in ['conv1', 'conv2'])
         assert all(t.numel() == (t == 1).type_as(t).sum().item() for t in data.values())

         # Test SingleHookTrainerBasedDataCollector

@@ -102,7 +102,7 @@ class PruningToolsTestCase(unittest.TestCase):
             if len(buffer) < 2:
                 buffer.append(grad.clone().detach())
             return collect_taylor

-        hook_targets = {'conv1': model.conv1.module.weight, 'conv2': model.conv2.module.weight}
+        hook_targets = {'conv1': model.conv1.weight, 'conv2': model.conv2.weight}
         collector_info = HookCollectorInfo(hook_targets, 'tensor', _collector)
         optimizer_helper = OptimizerConstructHelper.from_trace(model, get_optimizer(model))
test/ut/compression/v2/test_pruning_wrapper.py (new file, mode 100644)

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import unittest
+
+import torch
+import torch.nn.functional as F
+
+from nni.algorithms.compression.v2.pytorch.pruning import L1NormPruner
+
+
+class TorchModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = torch.nn.Conv2d(1, 5, 5, 1)
+        self.bn1 = torch.nn.BatchNorm2d(5)
+        self.conv2 = torch.nn.Conv2d(5, 10, 5, 1)
+        self.bn2 = torch.nn.BatchNorm2d(10)
+        self.fc1 = torch.nn.Linear(4 * 4 * 10, 100)
+        self.fc2 = torch.nn.Linear(100, 10)
+
+    def forward(self, x):
+        x = F.relu(self.bn1(self.conv1(x)))
+        x = F.max_pool2d(x, 2, 2)
+        x = F.relu(self.bn2(self.conv2(x)))
+        x = F.max_pool2d(x, 2, 2)
+        x = x.view(-1, 4 * 4 * 10)
+        x = F.relu(self.fc1(x))
+        x = self.fc2(x)
+        return F.log_softmax(x, dim=1)
+
+
+class PrunerTestCase(unittest.TestCase):
+    def test_pruner_module_wrapper(self):
+        model = TorchModel()
+        conv1_weight = model.conv1.weight.data.clone()
+        conv2_weight = model.conv2.weight.data.clone()
+        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
+        pruner = L1NormPruner(model, config_list)
+        _, masks = pruner.compress()
+        model(torch.rand(10, 1, 28, 28))
+        assert torch.equal(model.conv1.weight.data, conv1_weight)
+        assert torch.equal(model.conv2.weight.data, conv2_weight)
+        assert torch.equal(model.conv1.module.weight.data, conv1_weight * masks['conv1']['weight'])
+        assert torch.equal(model.conv2.module.weight.data, conv2_weight * masks['conv2']['weight'])
+
+
+if __name__ == '__main__':
+    unittest.main()
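The new test pins down the wrapper contract from the user's side: after `compress()` and a forward pass, `model.conv1.weight` (the wrapper's trainable parameter) still holds the original values, while `model.conv1.module.weight` holds the masked copy. Thanks to the `unittest.main()` guard it can be run directly, e.g. `python test/ut/compression/v2/test_pruning_wrapper.py`, in an environment with PyTorch and this NNI revision installed.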