Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
d452a166
Unverified
Commit
d452a166
authored
Feb 05, 2020
by
QuanluZhang
Committed by
GitHub
Feb 05, 2020
Browse files
update lottery ticket pruner based on refactored compression code (#1989)
parent
6b0ecee6
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
56 additions
and
31 deletions
+56
-31
examples/model_compress/lottery_torch_mnist_fc.py
examples/model_compress/lottery_torch_mnist_fc.py
+2
-0
examples/model_compress/multi_gpu.py
examples/model_compress/multi_gpu.py
+1
-1
src/sdk/pynni/nni/compression/torch/compressor.py
src/sdk/pynni/nni/compression/torch/compressor.py
+30
-4
src/sdk/pynni/nni/compression/torch/pruners.py
src/sdk/pynni/nni/compression/torch/pruners.py
+23
-26
No files found.
examples/model_compress/lottery_torch_mnist_fc.py
View file @
d452a166
...
...
@@ -71,6 +71,8 @@ if __name__ == '__main__':
pruner
=
LotteryTicketPruner
(
model
,
configure_list
,
optimizer
)
pruner
.
compress
()
#model = nn.DataParallel(model)
for
i
in
pruner
.
get_prune_iterations
():
pruner
.
prune_iteration_start
()
loss
=
0
...
...
examples/model_compress/multi_gpu.py
View file @
d452a166
...
...
@@ -69,7 +69,7 @@ if __name__ == '__main__':
train_loader
=
torch
.
utils
.
data
.
DataLoader
(
traindataset
,
batch_size
=
60
,
shuffle
=
True
,
num_workers
=
10
,
drop_last
=
False
)
test_loader
=
torch
.
utils
.
data
.
DataLoader
(
testdataset
,
batch_size
=
60
,
shuffle
=
False
,
num_workers
=
10
,
drop_last
=
True
)
device
=
torch
.
device
(
"cuda
: 0
"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
model
=
fc1
()
criterion
=
nn
.
CrossEntropyLoss
()
...
...
src/sdk/pynni/nni/compression/torch/compressor.py
View file @
d452a166
...
...
@@ -41,6 +41,7 @@ class Compressor:
self
.
modules_to_compress
=
None
self
.
modules_wrapper
=
None
self
.
buffers
=
{}
self
.
is_wrapped
=
False
def
detect_modules_to_compress
(
self
):
"""
...
...
@@ -63,6 +64,7 @@ class Compressor:
"""
for
wrapper
in
reversed
(
self
.
get_modules_wrapper
()):
_setattr
(
self
.
bound_model
,
wrapper
.
name
,
wrapper
)
self
.
is_wrapped
=
True
def
_unwrap_model
(
self
):
"""
...
...
@@ -71,6 +73,7 @@ class Compressor:
"""
for
wrapper
in
self
.
get_modules_wrapper
():
_setattr
(
self
.
bound_model
,
wrapper
.
name
,
wrapper
.
module
)
self
.
is_wrapped
=
False
def
compress
(
self
):
"""
...
...
@@ -263,7 +266,7 @@ class Pruner(Compressor):
def
__init__
(
self
,
model
,
config_list
):
super
().
__init__
(
model
,
config_list
)
def
calc_mask
(
self
,
layer
,
config
):
def
calc_mask
(
self
,
layer
,
config
,
**
kwargs
):
"""
Pruners should overload this method to provide mask for weight tensors.
The mask must have the same shape and type comparing to the weight.
...
...
@@ -291,9 +294,12 @@ class Pruner(Compressor):
the configuration for generating the mask
"""
_logger
.
info
(
"compressing module %s."
,
layer
.
name
)
return
PrunerModuleWrapper
(
layer
.
module
,
layer
.
name
,
layer
.
type
,
config
,
self
)
wrapper
=
PrunerModuleWrapper
(
layer
.
module
,
layer
.
name
,
layer
.
type
,
config
,
self
)
assert
hasattr
(
layer
.
module
,
'weight'
)
wrapper
.
to
(
layer
.
module
.
weight
.
device
)
return
wrapper
def
export_model
(
self
,
model_path
,
mask_path
=
None
,
onnx_path
=
None
,
input_shape
=
None
):
def
export_model
(
self
,
model_path
,
mask_path
=
None
,
onnx_path
=
None
,
input_shape
=
None
,
device
=
None
):
"""
Export pruned model weights, masks and onnx model(optional)
...
...
@@ -307,6 +313,9 @@ class Pruner(Compressor):
(optional) path to save onnx model
input_shape : list or tuple
input shape to onnx model
device : torch.device
device of the model, used to place the dummy input tensor for exporting onnx file.
the tensor is placed on cpu if ```device``` is None
"""
# if self.detect_modules_to_compress() and not self.mask_dict:
# _logger.warning('You may not use self.mask_dict in base Pruner class to record masks')
...
...
@@ -335,12 +344,29 @@ class Pruner(Compressor):
if
onnx_path
is
not
None
:
assert
input_shape
is
not
None
,
'input_shape must be specified to export onnx model'
# input info needed
if
device
is
None
:
device
=
torch
.
device
(
'cpu'
)
input_data
=
torch
.
Tensor
(
*
input_shape
)
torch
.
onnx
.
export
(
self
.
bound_model
,
input_data
,
onnx_path
)
torch
.
onnx
.
export
(
self
.
bound_model
,
input_data
.
to
(
device
)
,
onnx_path
)
_logger
.
info
(
'Model in onnx with input shape %s saved to %s'
,
input_data
.
shape
,
onnx_path
)
self
.
_wrap_model
()
def
load_model_state_dict
(
self
,
model_state
):
"""
Load the state dict saved from unwrapped model.
Parameters:
-----------
model_state : dict
state dict saved from unwrapped model
"""
if
self
.
is_wrapped
:
self
.
_unwrap_model
()
self
.
bound_model
.
load_state_dict
(
model_state
)
self
.
_wrap_model
()
else
:
self
.
bound_model
.
load_state_dict
(
model_state
)
class
QuantizerModuleWrapper
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
module
,
module_name
,
module_type
,
config
,
quantizer
):
...
...
src/sdk/pynni/nni/compression/torch/pruners.py
View file @
d452a166
...
...
@@ -290,38 +290,23 @@ class LotteryTicketPruner(Pruner):
prune_iterations
=
config
[
'prune_iterations'
]
return
prune_iterations
def
_print_masks
(
self
,
print_mask
=
False
):
torch
.
set_printoptions
(
threshold
=
1000
)
for
op_name
in
self
.
mask_dict
.
keys
():
mask
=
self
.
mask_dict
[
op_name
]
print
(
'op name: '
,
op_name
)
if
print_mask
:
print
(
'mask: '
,
mask
)
# calculate current sparsity
mask_num
=
mask
[
'weight'
].
sum
().
item
()
mask_size
=
mask
[
'weight'
].
numel
()
print
(
'sparsity: '
,
1
-
mask_num
/
mask_size
)
torch
.
set_printoptions
(
profile
=
'default'
)
def
_calc_sparsity
(
self
,
sparsity
):
keep_ratio_once
=
(
1
-
sparsity
)
**
(
1
/
self
.
prune_iterations
)
curr_keep_ratio
=
keep_ratio_once
**
self
.
curr_prune_iteration
return
max
(
1
-
curr_keep_ratio
,
0
)
def
_calc_mask
(
self
,
weight
,
sparsity
,
op_name
):
def
_calc_mask
(
self
,
weight
,
sparsity
,
curr_w_mask
):
if
self
.
curr_prune_iteration
==
0
:
mask
=
torch
.
ones
(
weight
.
shape
).
type_as
(
weight
)
else
:
curr_sparsity
=
self
.
_calc_sparsity
(
sparsity
)
assert
self
.
mask_dict
.
get
(
op_name
)
is
not
None
curr_mask
=
self
.
mask_dict
.
get
(
op_name
)
w_abs
=
weight
.
abs
()
*
curr_mask
[
'weight'
]
w_abs
=
weight
.
abs
()
*
curr_w_mask
k
=
int
(
w_abs
.
numel
()
*
curr_sparsity
)
threshold
=
torch
.
topk
(
w_abs
.
view
(
-
1
),
k
,
largest
=
False
).
values
.
max
()
mask
=
torch
.
gt
(
w_abs
,
threshold
).
type_as
(
weight
)
return
{
'weight'
:
mask
}
def
calc_mask
(
self
,
layer
,
config
):
def
calc_mask
(
self
,
layer
,
config
,
**
kwargs
):
"""
Generate mask for the given ``weight``.
...
...
@@ -331,15 +316,17 @@ class LotteryTicketPruner(Pruner):
The layer to be pruned
config : dict
Pruning configurations for this weight
kwargs : dict
Auxiliary information
Returns
-------
tensor
The mask for this weight
The mask for this weight, it is ```None``` because this pruner
calculates and assigns masks in ```prune_iteration_start```,
no need to do anything in this function.
"""
assert
self
.
mask_dict
.
get
(
layer
.
name
)
is
not
None
,
'Please call iteration_start before training'
mask
=
self
.
mask_dict
[
layer
.
name
]
return
mask
return
None
def
get_prune_iterations
(
self
):
"""
...
...
@@ -364,16 +351,26 @@ class LotteryTicketPruner(Pruner):
self
.
curr_prune_iteration
+=
1
assert
self
.
curr_prune_iteration
<
self
.
prune_iterations
+
1
,
'Exceed the configured prune_iterations'
modules_wrapper
=
self
.
get_modules_wrapper
()
modules_to_compress
=
self
.
detect_modules_to_compress
()
for
layer
,
config
in
modules_to_compress
:
module_wrapper
=
None
for
wrapper
in
modules_wrapper
:
if
wrapper
.
name
==
layer
.
name
:
module_wrapper
=
wrapper
break
assert
module_wrapper
is
not
None
sparsity
=
config
.
get
(
'sparsity'
)
mask
=
self
.
_calc_mask
(
layer
.
module
.
weight
.
data
,
sparsity
,
layer
.
name
)
self
.
mask_dict
.
update
({
layer
.
name
:
mask
})
self
.
_print_masks
()
mask
=
self
.
_calc_mask
(
layer
.
module
.
weight
.
data
,
sparsity
,
module_wrapper
.
weight_mask
)
# TODO: directly use weight_mask is not good
module_wrapper
.
weight_mask
.
copy_
(
mask
[
'weight'
])
# there is no mask for bias
# reinit weights back to original after new masks are generated
if
self
.
reset_weights
:
self
.
_model
.
load_state_dict
(
self
.
_model_state
)
# should use this member function to reset model weights
self
.
load_model_state_dict
(
self
.
_model_state
)
self
.
_optimizer
.
load_state_dict
(
self
.
_optimizer_state
)
if
self
.
_lr_scheduler
is
not
None
:
self
.
_lr_scheduler
.
load_state_dict
(
self
.
_scheduler_state
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment