Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
86f8c2ab
Commit
86f8c2ab
authored
Nov 05, 2019
by
Tang Lang
Committed by
chicm-ms
Nov 05, 2019
Browse files
pruner export (#1674)
parent
6210625b
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
63 additions
and
9 deletions
+63
-9
examples/model_compress/main_torch_pruner.py
examples/model_compress/main_torch_pruner.py
+3
-1
src/sdk/pynni/nni/compression/torch/builtin_pruners.py
src/sdk/pynni/nni/compression/torch/builtin_pruners.py
+5
-7
src/sdk/pynni/nni/compression/torch/compressor.py
src/sdk/pynni/nni/compression/torch/compressor.py
+55
-1
No files found.
examples/model_compress/main_torch_pruner.py
View file @
86f8c2ab
...
...
@@ -66,6 +66,7 @@ def main():
batch_size
=
1000
,
shuffle
=
True
)
model
=
Mnist
()
model
.
to
(
device
)
'''you can change this to LevelPruner to implement it
pruner = LevelPruner(configure_list)
...
...
@@ -80,7 +81,7 @@ def main():
}]
pruner
=
AGP_Pruner
(
model
,
configure_list
)
pruner
.
compress
()
model
=
pruner
.
compress
()
# you can also use compress(model) method
# like that pruner.compress(model)
...
...
@@ -90,6 +91,7 @@ def main():
print
(
'# Epoch {} #'
.
format
(
epoch
))
train
(
model
,
device
,
train_loader
,
optimizer
)
test
(
model
,
device
,
test_loader
)
pruner
.
export_model
(
'model.pth'
,
'mask.pth'
,
'model.onnx'
,
[
1
,
1
,
28
,
28
])
if
__name__
==
'__main__'
:
...
...
src/sdk/pynni/nni/compression/torch/builtin_pruners.py
View file @
86f8c2ab
...
...
@@ -17,7 +17,6 @@ class LevelPruner(Pruner):
- sparsity
"""
super
().
__init__
(
model
,
config_list
)
self
.
mask_list
=
{}
self
.
if_init_list
=
{}
def
calc_mask
(
self
,
layer
,
config
):
...
...
@@ -30,10 +29,10 @@ class LevelPruner(Pruner):
return
torch
.
ones
(
weight
.
shape
).
type_as
(
weight
)
threshold
=
torch
.
topk
(
w_abs
.
view
(
-
1
),
k
,
largest
=
False
).
values
.
max
()
mask
=
torch
.
gt
(
w_abs
,
threshold
).
type_as
(
weight
)
self
.
mask_
lis
t
.
update
({
op_name
:
mask
})
self
.
mask_
dic
t
.
update
({
op_name
:
mask
})
self
.
if_init_list
.
update
({
op_name
:
False
})
else
:
mask
=
self
.
mask_
lis
t
[
op_name
]
mask
=
self
.
mask_
dic
t
[
op_name
]
return
mask
...
...
@@ -57,7 +56,6 @@ class AGP_Pruner(Pruner):
- frequency: if you want update every 2 epoch, you can set it 2
"""
super
().
__init__
(
model
,
config_list
)
self
.
mask_list
=
{}
self
.
now_epoch
=
0
self
.
if_init_list
=
{}
...
...
@@ -68,7 +66,7 @@ class AGP_Pruner(Pruner):
freq
=
config
.
get
(
'frequency'
,
1
)
if
self
.
now_epoch
>=
start_epoch
and
self
.
if_init_list
.
get
(
op_name
,
True
)
and
(
self
.
now_epoch
-
start_epoch
)
%
freq
==
0
:
mask
=
self
.
mask_
lis
t
.
get
(
op_name
,
torch
.
ones
(
weight
.
shape
).
type_as
(
weight
))
mask
=
self
.
mask_
dic
t
.
get
(
op_name
,
torch
.
ones
(
weight
.
shape
).
type_as
(
weight
))
target_sparsity
=
self
.
compute_target_sparsity
(
config
)
k
=
int
(
weight
.
numel
()
*
target_sparsity
)
if
k
==
0
or
target_sparsity
>=
1
or
target_sparsity
<=
0
:
...
...
@@ -77,10 +75,10 @@ class AGP_Pruner(Pruner):
w_abs
=
weight
.
abs
()
*
mask
threshold
=
torch
.
topk
(
w_abs
.
view
(
-
1
),
k
,
largest
=
False
).
values
.
max
()
new_mask
=
torch
.
gt
(
w_abs
,
threshold
).
type_as
(
weight
)
self
.
mask_
lis
t
.
update
({
op_name
:
new_mask
})
self
.
mask_
dic
t
.
update
({
op_name
:
new_mask
})
self
.
if_init_list
.
update
({
op_name
:
False
})
else
:
new_mask
=
self
.
mask_
lis
t
.
get
(
op_name
,
torch
.
ones
(
weight
.
shape
).
type_as
(
weight
))
new_mask
=
self
.
mask_
dic
t
.
get
(
op_name
,
torch
.
ones
(
weight
.
shape
).
type_as
(
weight
))
return
new_mask
def
compute_target_sparsity
(
self
,
config
):
...
...
src/sdk/pynni/nni/compression/torch/compressor.py
View file @
86f8c2ab
...
...
@@ -128,11 +128,23 @@ class Compressor:
expanded_op_types
.
append
(
op_type
)
return
expanded_op_types
class
Pruner
(
Compressor
):
"""
Abstract base PyTorch pruner
Prune to an exact pruning level specification
Attributes
----------
mask_dict : dict
Dictionary for saving masks, `key` should be layer name and
`value` should be a tensor which has the same shape with layer's weight
"""
def
__init__
(
self
,
model
,
config_list
):
super
().
__init__
(
model
,
config_list
)
self
.
mask_dict
=
{}
def
calc_mask
(
self
,
layer
,
config
):
"""
Pruners should overload this method to provide mask for weight tensors.
...
...
@@ -177,6 +189,48 @@ class Pruner(Compressor):
layer
.
module
.
forward
=
new_forward
def
export_model
(
self
,
model_path
,
mask_path
=
None
,
onnx_path
=
None
,
input_shape
=
None
):
"""
Export pruned model weights, masks and onnx model(optional)
Parameters
----------
model_path : str
path to save pruned model state_dict
mask_path : str
(optional) path to save mask dict
onnx_path : str
(optional) path to save onnx model
input_shape : list or tuple
input shape to onnx model
"""
assert
model_path
is
not
None
,
'model_path must be specified'
for
name
,
m
in
self
.
bound_model
.
named_modules
():
mask
=
self
.
mask_dict
.
get
(
name
)
if
mask
is
not
None
:
mask_sum
=
mask
.
sum
().
item
()
mask_num
=
mask
.
numel
()
_logger
.
info
(
'Layer: %s Sparsity: %.2f'
,
name
,
1
-
mask_sum
/
mask_num
)
print
(
'Layer: %s Sparsity: %.2f'
%
(
name
,
1
-
mask_sum
/
mask_num
))
m
.
weight
.
data
=
m
.
weight
.
data
.
mul
(
mask
)
else
:
_logger
.
info
(
'Layer: %s NOT compressed'
,
name
)
print
(
'Layer: %s NOT compressed'
%
name
)
torch
.
save
(
self
.
bound_model
.
state_dict
(),
model_path
)
_logger
.
info
(
'Model state_dict saved to %s'
,
model_path
)
print
(
'Model state_dict saved to %s'
%
model_path
)
if
mask_path
is
not
None
:
torch
.
save
(
self
.
mask_dict
,
mask_path
)
_logger
.
info
(
'Mask dict saved to %s'
,
mask_path
)
print
(
'Mask dict saved to %s'
%
mask_path
)
if
onnx_path
is
not
None
:
assert
input_shape
is
not
None
,
'input_shape must be specified to export onnx model'
# input info needed
input_data
=
torch
.
Tensor
(
*
input_shape
)
torch
.
onnx
.
export
(
self
.
bound_model
,
input_data
,
onnx_path
)
_logger
.
info
(
'Model in onnx with input shape %s saved to %s'
,
input_data
.
shape
,
onnx_path
)
print
(
'Model in onnx with input shape %s saved to %s'
%
(
input_data
.
shape
,
onnx_path
))
class
Quantizer
(
Compressor
):
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment