Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
d452a166
Unverified
Commit
d452a166
authored
Feb 05, 2020
by
QuanluZhang
Committed by
GitHub
Feb 05, 2020
Browse files
update lottery ticket pruner based on refactored compression code (#1989)
parent
6b0ecee6
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
56 additions
and
31 deletions
+56
-31
examples/model_compress/lottery_torch_mnist_fc.py
examples/model_compress/lottery_torch_mnist_fc.py
+2
-0
examples/model_compress/multi_gpu.py
examples/model_compress/multi_gpu.py
+1
-1
src/sdk/pynni/nni/compression/torch/compressor.py
src/sdk/pynni/nni/compression/torch/compressor.py
+30
-4
src/sdk/pynni/nni/compression/torch/pruners.py
src/sdk/pynni/nni/compression/torch/pruners.py
+23
-26
No files found.
examples/model_compress/lottery_torch_mnist_fc.py
View file @
d452a166
...
@@ -71,6 +71,8 @@ if __name__ == '__main__':
...
@@ -71,6 +71,8 @@ if __name__ == '__main__':
pruner
=
LotteryTicketPruner
(
model
,
configure_list
,
optimizer
)
pruner
=
LotteryTicketPruner
(
model
,
configure_list
,
optimizer
)
pruner
.
compress
()
pruner
.
compress
()
#model = nn.DataParallel(model)
for
i
in
pruner
.
get_prune_iterations
():
for
i
in
pruner
.
get_prune_iterations
():
pruner
.
prune_iteration_start
()
pruner
.
prune_iteration_start
()
loss
=
0
loss
=
0
...
...
examples/model_compress/multi_gpu.py
View file @
d452a166
...
@@ -69,7 +69,7 @@ if __name__ == '__main__':
...
@@ -69,7 +69,7 @@ if __name__ == '__main__':
train_loader
=
torch
.
utils
.
data
.
DataLoader
(
traindataset
,
batch_size
=
60
,
shuffle
=
True
,
num_workers
=
10
,
drop_last
=
False
)
train_loader
=
torch
.
utils
.
data
.
DataLoader
(
traindataset
,
batch_size
=
60
,
shuffle
=
True
,
num_workers
=
10
,
drop_last
=
False
)
test_loader
=
torch
.
utils
.
data
.
DataLoader
(
testdataset
,
batch_size
=
60
,
shuffle
=
False
,
num_workers
=
10
,
drop_last
=
True
)
test_loader
=
torch
.
utils
.
data
.
DataLoader
(
testdataset
,
batch_size
=
60
,
shuffle
=
False
,
num_workers
=
10
,
drop_last
=
True
)
device
=
torch
.
device
(
"cuda
: 0
"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
model
=
fc1
()
model
=
fc1
()
criterion
=
nn
.
CrossEntropyLoss
()
criterion
=
nn
.
CrossEntropyLoss
()
...
...
src/sdk/pynni/nni/compression/torch/compressor.py
View file @
d452a166
...
@@ -41,6 +41,7 @@ class Compressor:
...
@@ -41,6 +41,7 @@ class Compressor:
self
.
modules_to_compress
=
None
self
.
modules_to_compress
=
None
self
.
modules_wrapper
=
None
self
.
modules_wrapper
=
None
self
.
buffers
=
{}
self
.
buffers
=
{}
self
.
is_wrapped
=
False
def
detect_modules_to_compress
(
self
):
def
detect_modules_to_compress
(
self
):
"""
"""
...
@@ -63,6 +64,7 @@ class Compressor:
...
@@ -63,6 +64,7 @@ class Compressor:
"""
"""
for
wrapper
in
reversed
(
self
.
get_modules_wrapper
()):
for
wrapper
in
reversed
(
self
.
get_modules_wrapper
()):
_setattr
(
self
.
bound_model
,
wrapper
.
name
,
wrapper
)
_setattr
(
self
.
bound_model
,
wrapper
.
name
,
wrapper
)
self
.
is_wrapped
=
True
def
_unwrap_model
(
self
):
def
_unwrap_model
(
self
):
"""
"""
...
@@ -71,6 +73,7 @@ class Compressor:
...
@@ -71,6 +73,7 @@ class Compressor:
"""
"""
for
wrapper
in
self
.
get_modules_wrapper
():
for
wrapper
in
self
.
get_modules_wrapper
():
_setattr
(
self
.
bound_model
,
wrapper
.
name
,
wrapper
.
module
)
_setattr
(
self
.
bound_model
,
wrapper
.
name
,
wrapper
.
module
)
self
.
is_wrapped
=
False
def
compress
(
self
):
def
compress
(
self
):
"""
"""
...
@@ -263,7 +266,7 @@ class Pruner(Compressor):
...
@@ -263,7 +266,7 @@ class Pruner(Compressor):
def
__init__
(
self
,
model
,
config_list
):
def
__init__
(
self
,
model
,
config_list
):
super
().
__init__
(
model
,
config_list
)
super
().
__init__
(
model
,
config_list
)
def
calc_mask
(
self
,
layer
,
config
):
def
calc_mask
(
self
,
layer
,
config
,
**
kwargs
):
"""
"""
Pruners should overload this method to provide mask for weight tensors.
Pruners should overload this method to provide mask for weight tensors.
The mask must have the same shape and type comparing to the weight.
The mask must have the same shape and type comparing to the weight.
...
@@ -291,9 +294,12 @@ class Pruner(Compressor):
...
@@ -291,9 +294,12 @@ class Pruner(Compressor):
the configuration for generating the mask
the configuration for generating the mask
"""
"""
_logger
.
info
(
"compressing module %s."
,
layer
.
name
)
_logger
.
info
(
"compressing module %s."
,
layer
.
name
)
return
PrunerModuleWrapper
(
layer
.
module
,
layer
.
name
,
layer
.
type
,
config
,
self
)
wrapper
=
PrunerModuleWrapper
(
layer
.
module
,
layer
.
name
,
layer
.
type
,
config
,
self
)
assert
hasattr
(
layer
.
module
,
'weight'
)
wrapper
.
to
(
layer
.
module
.
weight
.
device
)
return
wrapper
def
export_model
(
self
,
model_path
,
mask_path
=
None
,
onnx_path
=
None
,
input_shape
=
None
):
def
export_model
(
self
,
model_path
,
mask_path
=
None
,
onnx_path
=
None
,
input_shape
=
None
,
device
=
None
):
"""
"""
Export pruned model weights, masks and onnx model(optional)
Export pruned model weights, masks and onnx model(optional)
...
@@ -307,6 +313,9 @@ class Pruner(Compressor):
...
@@ -307,6 +313,9 @@ class Pruner(Compressor):
(optional) path to save onnx model
(optional) path to save onnx model
input_shape : list or tuple
input_shape : list or tuple
input shape to onnx model
input shape to onnx model
device : torch.device
device of the model, used to place the dummy input tensor for exporting onnx file.
the tensor is placed on cpu if ```device``` is None
"""
"""
# if self.detect_modules_to_compress() and not self.mask_dict:
# if self.detect_modules_to_compress() and not self.mask_dict:
# _logger.warning('You may not use self.mask_dict in base Pruner class to record masks')
# _logger.warning('You may not use self.mask_dict in base Pruner class to record masks')
...
@@ -335,12 +344,29 @@ class Pruner(Compressor):
...
@@ -335,12 +344,29 @@ class Pruner(Compressor):
if
onnx_path
is
not
None
:
if
onnx_path
is
not
None
:
assert
input_shape
is
not
None
,
'input_shape must be specified to export onnx model'
assert
input_shape
is
not
None
,
'input_shape must be specified to export onnx model'
# input info needed
# input info needed
if
device
is
None
:
device
=
torch
.
device
(
'cpu'
)
input_data
=
torch
.
Tensor
(
*
input_shape
)
input_data
=
torch
.
Tensor
(
*
input_shape
)
torch
.
onnx
.
export
(
self
.
bound_model
,
input_data
,
onnx_path
)
torch
.
onnx
.
export
(
self
.
bound_model
,
input_data
.
to
(
device
)
,
onnx_path
)
_logger
.
info
(
'Model in onnx with input shape %s saved to %s'
,
input_data
.
shape
,
onnx_path
)
_logger
.
info
(
'Model in onnx with input shape %s saved to %s'
,
input_data
.
shape
,
onnx_path
)
self
.
_wrap_model
()
self
.
_wrap_model
()
def
load_model_state_dict
(
self
,
model_state
):
"""
Load the state dict saved from unwrapped model.
Parameters:
-----------
model_state : dict
state dict saved from unwrapped model
"""
if
self
.
is_wrapped
:
self
.
_unwrap_model
()
self
.
bound_model
.
load_state_dict
(
model_state
)
self
.
_wrap_model
()
else
:
self
.
bound_model
.
load_state_dict
(
model_state
)
class
QuantizerModuleWrapper
(
torch
.
nn
.
Module
):
class
QuantizerModuleWrapper
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
module
,
module_name
,
module_type
,
config
,
quantizer
):
def
__init__
(
self
,
module
,
module_name
,
module_type
,
config
,
quantizer
):
...
...
src/sdk/pynni/nni/compression/torch/pruners.py
View file @
d452a166
...
@@ -290,38 +290,23 @@ class LotteryTicketPruner(Pruner):
...
@@ -290,38 +290,23 @@ class LotteryTicketPruner(Pruner):
prune_iterations
=
config
[
'prune_iterations'
]
prune_iterations
=
config
[
'prune_iterations'
]
return
prune_iterations
return
prune_iterations
def
_print_masks
(
self
,
print_mask
=
False
):
torch
.
set_printoptions
(
threshold
=
1000
)
for
op_name
in
self
.
mask_dict
.
keys
():
mask
=
self
.
mask_dict
[
op_name
]
print
(
'op name: '
,
op_name
)
if
print_mask
:
print
(
'mask: '
,
mask
)
# calculate current sparsity
mask_num
=
mask
[
'weight'
].
sum
().
item
()
mask_size
=
mask
[
'weight'
].
numel
()
print
(
'sparsity: '
,
1
-
mask_num
/
mask_size
)
torch
.
set_printoptions
(
profile
=
'default'
)
def
_calc_sparsity
(
self
,
sparsity
):
def
_calc_sparsity
(
self
,
sparsity
):
keep_ratio_once
=
(
1
-
sparsity
)
**
(
1
/
self
.
prune_iterations
)
keep_ratio_once
=
(
1
-
sparsity
)
**
(
1
/
self
.
prune_iterations
)
curr_keep_ratio
=
keep_ratio_once
**
self
.
curr_prune_iteration
curr_keep_ratio
=
keep_ratio_once
**
self
.
curr_prune_iteration
return
max
(
1
-
curr_keep_ratio
,
0
)
return
max
(
1
-
curr_keep_ratio
,
0
)
def
_calc_mask
(
self
,
weight
,
sparsity
,
op_name
):
def
_calc_mask
(
self
,
weight
,
sparsity
,
curr_w_mask
):
if
self
.
curr_prune_iteration
==
0
:
if
self
.
curr_prune_iteration
==
0
:
mask
=
torch
.
ones
(
weight
.
shape
).
type_as
(
weight
)
mask
=
torch
.
ones
(
weight
.
shape
).
type_as
(
weight
)
else
:
else
:
curr_sparsity
=
self
.
_calc_sparsity
(
sparsity
)
curr_sparsity
=
self
.
_calc_sparsity
(
sparsity
)
assert
self
.
mask_dict
.
get
(
op_name
)
is
not
None
w_abs
=
weight
.
abs
()
*
curr_w_mask
curr_mask
=
self
.
mask_dict
.
get
(
op_name
)
w_abs
=
weight
.
abs
()
*
curr_mask
[
'weight'
]
k
=
int
(
w_abs
.
numel
()
*
curr_sparsity
)
k
=
int
(
w_abs
.
numel
()
*
curr_sparsity
)
threshold
=
torch
.
topk
(
w_abs
.
view
(
-
1
),
k
,
largest
=
False
).
values
.
max
()
threshold
=
torch
.
topk
(
w_abs
.
view
(
-
1
),
k
,
largest
=
False
).
values
.
max
()
mask
=
torch
.
gt
(
w_abs
,
threshold
).
type_as
(
weight
)
mask
=
torch
.
gt
(
w_abs
,
threshold
).
type_as
(
weight
)
return
{
'weight'
:
mask
}
return
{
'weight'
:
mask
}
def
calc_mask
(
self
,
layer
,
config
):
def
calc_mask
(
self
,
layer
,
config
,
**
kwargs
):
"""
"""
Generate mask for the given ``weight``.
Generate mask for the given ``weight``.
...
@@ -331,15 +316,17 @@ class LotteryTicketPruner(Pruner):
...
@@ -331,15 +316,17 @@ class LotteryTicketPruner(Pruner):
The layer to be pruned
The layer to be pruned
config : dict
config : dict
Pruning configurations for this weight
Pruning configurations for this weight
kwargs : dict
Auxiliary information
Returns
Returns
-------
-------
tensor
tensor
The mask for this weight
The mask for this weight, it is ```None``` because this pruner
calculates and assigns masks in ```prune_iteration_start```,
no need to do anything in this function.
"""
"""
assert
self
.
mask_dict
.
get
(
layer
.
name
)
is
not
None
,
'Please call iteration_start before training'
return
None
mask
=
self
.
mask_dict
[
layer
.
name
]
return
mask
def
get_prune_iterations
(
self
):
def
get_prune_iterations
(
self
):
"""
"""
...
@@ -364,16 +351,26 @@ class LotteryTicketPruner(Pruner):
...
@@ -364,16 +351,26 @@ class LotteryTicketPruner(Pruner):
self
.
curr_prune_iteration
+=
1
self
.
curr_prune_iteration
+=
1
assert
self
.
curr_prune_iteration
<
self
.
prune_iterations
+
1
,
'Exceed the configured prune_iterations'
assert
self
.
curr_prune_iteration
<
self
.
prune_iterations
+
1
,
'Exceed the configured prune_iterations'
modules_wrapper
=
self
.
get_modules_wrapper
()
modules_to_compress
=
self
.
detect_modules_to_compress
()
modules_to_compress
=
self
.
detect_modules_to_compress
()
for
layer
,
config
in
modules_to_compress
:
for
layer
,
config
in
modules_to_compress
:
module_wrapper
=
None
for
wrapper
in
modules_wrapper
:
if
wrapper
.
name
==
layer
.
name
:
module_wrapper
=
wrapper
break
assert
module_wrapper
is
not
None
sparsity
=
config
.
get
(
'sparsity'
)
sparsity
=
config
.
get
(
'sparsity'
)
mask
=
self
.
_calc_mask
(
layer
.
module
.
weight
.
data
,
sparsity
,
layer
.
name
)
mask
=
self
.
_calc_mask
(
layer
.
module
.
weight
.
data
,
sparsity
,
module_wrapper
.
weight_mask
)
self
.
mask_dict
.
update
({
layer
.
name
:
mask
})
# TODO: directly use weight_mask is not good
self
.
_print_masks
()
module_wrapper
.
weight_mask
.
copy_
(
mask
[
'weight'
])
# there is no mask for bias
# reinit weights back to original after new masks are generated
# reinit weights back to original after new masks are generated
if
self
.
reset_weights
:
if
self
.
reset_weights
:
self
.
_model
.
load_state_dict
(
self
.
_model_state
)
# should use this member function to reset model weights
self
.
load_model_state_dict
(
self
.
_model_state
)
self
.
_optimizer
.
load_state_dict
(
self
.
_optimizer_state
)
self
.
_optimizer
.
load_state_dict
(
self
.
_optimizer_state
)
if
self
.
_lr_scheduler
is
not
None
:
if
self
.
_lr_scheduler
is
not
None
:
self
.
_lr_scheduler
.
load_state_dict
(
self
.
_scheduler_state
)
self
.
_lr_scheduler
.
load_state_dict
(
self
.
_scheduler_state
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment