Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
86f8c2ab
You need to sign in or sign up before continuing.
Commit
86f8c2ab
authored
Nov 05, 2019
by
Tang Lang
Committed by
chicm-ms
Nov 05, 2019
Browse files
pruner export (#1674)
parent
6210625b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
63 additions
and
9 deletions
+63
-9
examples/model_compress/main_torch_pruner.py
examples/model_compress/main_torch_pruner.py
+3
-1
src/sdk/pynni/nni/compression/torch/builtin_pruners.py
src/sdk/pynni/nni/compression/torch/builtin_pruners.py
+5
-7
src/sdk/pynni/nni/compression/torch/compressor.py
src/sdk/pynni/nni/compression/torch/compressor.py
+55
-1
No files found.
examples/model_compress/main_torch_pruner.py
View file @
86f8c2ab
...
@@ -66,6 +66,7 @@ def main():
...
@@ -66,6 +66,7 @@ def main():
batch_size
=
1000
,
shuffle
=
True
)
batch_size
=
1000
,
shuffle
=
True
)
model
=
Mnist
()
model
=
Mnist
()
model
.
to
(
device
)
'''you can change this to LevelPruner to implement it
'''you can change this to LevelPruner to implement it
pruner = LevelPruner(configure_list)
pruner = LevelPruner(configure_list)
...
@@ -80,7 +81,7 @@ def main():
...
@@ -80,7 +81,7 @@ def main():
}]
}]
pruner
=
AGP_Pruner
(
model
,
configure_list
)
pruner
=
AGP_Pruner
(
model
,
configure_list
)
pruner
.
compress
()
model
=
pruner
.
compress
()
# you can also use compress(model) method
# you can also use compress(model) method
# like that pruner.compress(model)
# like that pruner.compress(model)
...
@@ -90,6 +91,7 @@ def main():
...
@@ -90,6 +91,7 @@ def main():
print
(
'# Epoch {} #'
.
format
(
epoch
))
print
(
'# Epoch {} #'
.
format
(
epoch
))
train
(
model
,
device
,
train_loader
,
optimizer
)
train
(
model
,
device
,
train_loader
,
optimizer
)
test
(
model
,
device
,
test_loader
)
test
(
model
,
device
,
test_loader
)
pruner
.
export_model
(
'model.pth'
,
'mask.pth'
,
'model.onnx'
,
[
1
,
1
,
28
,
28
])
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
src/sdk/pynni/nni/compression/torch/builtin_pruners.py
View file @
86f8c2ab
...
@@ -17,7 +17,6 @@ class LevelPruner(Pruner):
...
@@ -17,7 +17,6 @@ class LevelPruner(Pruner):
- sparsity
- sparsity
"""
"""
super
().
__init__
(
model
,
config_list
)
super
().
__init__
(
model
,
config_list
)
self
.
mask_list
=
{}
self
.
if_init_list
=
{}
self
.
if_init_list
=
{}
def
calc_mask
(
self
,
layer
,
config
):
def
calc_mask
(
self
,
layer
,
config
):
...
@@ -30,10 +29,10 @@ class LevelPruner(Pruner):
...
@@ -30,10 +29,10 @@ class LevelPruner(Pruner):
return
torch
.
ones
(
weight
.
shape
).
type_as
(
weight
)
return
torch
.
ones
(
weight
.
shape
).
type_as
(
weight
)
threshold
=
torch
.
topk
(
w_abs
.
view
(
-
1
),
k
,
largest
=
False
).
values
.
max
()
threshold
=
torch
.
topk
(
w_abs
.
view
(
-
1
),
k
,
largest
=
False
).
values
.
max
()
mask
=
torch
.
gt
(
w_abs
,
threshold
).
type_as
(
weight
)
mask
=
torch
.
gt
(
w_abs
,
threshold
).
type_as
(
weight
)
self
.
mask_
lis
t
.
update
({
op_name
:
mask
})
self
.
mask_
dic
t
.
update
({
op_name
:
mask
})
self
.
if_init_list
.
update
({
op_name
:
False
})
self
.
if_init_list
.
update
({
op_name
:
False
})
else
:
else
:
mask
=
self
.
mask_
lis
t
[
op_name
]
mask
=
self
.
mask_
dic
t
[
op_name
]
return
mask
return
mask
...
@@ -57,7 +56,6 @@ class AGP_Pruner(Pruner):
...
@@ -57,7 +56,6 @@ class AGP_Pruner(Pruner):
- frequency: if you want update every 2 epoch, you can set it 2
- frequency: if you want update every 2 epoch, you can set it 2
"""
"""
super
().
__init__
(
model
,
config_list
)
super
().
__init__
(
model
,
config_list
)
self
.
mask_list
=
{}
self
.
now_epoch
=
0
self
.
now_epoch
=
0
self
.
if_init_list
=
{}
self
.
if_init_list
=
{}
...
@@ -68,7 +66,7 @@ class AGP_Pruner(Pruner):
...
@@ -68,7 +66,7 @@ class AGP_Pruner(Pruner):
freq
=
config
.
get
(
'frequency'
,
1
)
freq
=
config
.
get
(
'frequency'
,
1
)
if
self
.
now_epoch
>=
start_epoch
and
self
.
if_init_list
.
get
(
op_name
,
True
)
and
(
if
self
.
now_epoch
>=
start_epoch
and
self
.
if_init_list
.
get
(
op_name
,
True
)
and
(
self
.
now_epoch
-
start_epoch
)
%
freq
==
0
:
self
.
now_epoch
-
start_epoch
)
%
freq
==
0
:
mask
=
self
.
mask_
lis
t
.
get
(
op_name
,
torch
.
ones
(
weight
.
shape
).
type_as
(
weight
))
mask
=
self
.
mask_
dic
t
.
get
(
op_name
,
torch
.
ones
(
weight
.
shape
).
type_as
(
weight
))
target_sparsity
=
self
.
compute_target_sparsity
(
config
)
target_sparsity
=
self
.
compute_target_sparsity
(
config
)
k
=
int
(
weight
.
numel
()
*
target_sparsity
)
k
=
int
(
weight
.
numel
()
*
target_sparsity
)
if
k
==
0
or
target_sparsity
>=
1
or
target_sparsity
<=
0
:
if
k
==
0
or
target_sparsity
>=
1
or
target_sparsity
<=
0
:
...
@@ -77,10 +75,10 @@ class AGP_Pruner(Pruner):
...
@@ -77,10 +75,10 @@ class AGP_Pruner(Pruner):
w_abs
=
weight
.
abs
()
*
mask
w_abs
=
weight
.
abs
()
*
mask
threshold
=
torch
.
topk
(
w_abs
.
view
(
-
1
),
k
,
largest
=
False
).
values
.
max
()
threshold
=
torch
.
topk
(
w_abs
.
view
(
-
1
),
k
,
largest
=
False
).
values
.
max
()
new_mask
=
torch
.
gt
(
w_abs
,
threshold
).
type_as
(
weight
)
new_mask
=
torch
.
gt
(
w_abs
,
threshold
).
type_as
(
weight
)
self
.
mask_
lis
t
.
update
({
op_name
:
new_mask
})
self
.
mask_
dic
t
.
update
({
op_name
:
new_mask
})
self
.
if_init_list
.
update
({
op_name
:
False
})
self
.
if_init_list
.
update
({
op_name
:
False
})
else
:
else
:
new_mask
=
self
.
mask_
lis
t
.
get
(
op_name
,
torch
.
ones
(
weight
.
shape
).
type_as
(
weight
))
new_mask
=
self
.
mask_
dic
t
.
get
(
op_name
,
torch
.
ones
(
weight
.
shape
).
type_as
(
weight
))
return
new_mask
return
new_mask
def
compute_target_sparsity
(
self
,
config
):
def
compute_target_sparsity
(
self
,
config
):
...
...
src/sdk/pynni/nni/compression/torch/compressor.py
View file @
86f8c2ab
...
@@ -128,11 +128,23 @@ class Compressor:
...
@@ -128,11 +128,23 @@ class Compressor:
expanded_op_types
.
append
(
op_type
)
expanded_op_types
.
append
(
op_type
)
return
expanded_op_types
return
expanded_op_types
class
Pruner
(
Compressor
):
class
Pruner
(
Compressor
):
"""
"""
Abstract base PyTorch pruner
Prune to an exact pruning level specification
Attributes
----------
mask_dict : dict
Dictionary for saving masks, `key` should be layer name and
`value` should be a tensor which has the same shape with layer's weight
"""
"""
def
__init__
(
self
,
model
,
config_list
):
super
().
__init__
(
model
,
config_list
)
self
.
mask_dict
=
{}
def
calc_mask
(
self
,
layer
,
config
):
def
calc_mask
(
self
,
layer
,
config
):
"""
"""
Pruners should overload this method to provide mask for weight tensors.
Pruners should overload this method to provide mask for weight tensors.
...
@@ -177,6 +189,48 @@ class Pruner(Compressor):
...
@@ -177,6 +189,48 @@ class Pruner(Compressor):
layer
.
module
.
forward
=
new_forward
layer
.
module
.
forward
=
new_forward
def
export_model
(
self
,
model_path
,
mask_path
=
None
,
onnx_path
=
None
,
input_shape
=
None
):
"""
Export pruned model weights, masks and onnx model(optional)
Parameters
----------
model_path : str
path to save pruned model state_dict
mask_path : str
(optional) path to save mask dict
onnx_path : str
(optional) path to save onnx model
input_shape : list or tuple
input shape to onnx model
"""
assert
model_path
is
not
None
,
'model_path must be specified'
for
name
,
m
in
self
.
bound_model
.
named_modules
():
mask
=
self
.
mask_dict
.
get
(
name
)
if
mask
is
not
None
:
mask_sum
=
mask
.
sum
().
item
()
mask_num
=
mask
.
numel
()
_logger
.
info
(
'Layer: %s Sparsity: %.2f'
,
name
,
1
-
mask_sum
/
mask_num
)
print
(
'Layer: %s Sparsity: %.2f'
%
(
name
,
1
-
mask_sum
/
mask_num
))
m
.
weight
.
data
=
m
.
weight
.
data
.
mul
(
mask
)
else
:
_logger
.
info
(
'Layer: %s NOT compressed'
,
name
)
print
(
'Layer: %s NOT compressed'
%
name
)
torch
.
save
(
self
.
bound_model
.
state_dict
(),
model_path
)
_logger
.
info
(
'Model state_dict saved to %s'
,
model_path
)
print
(
'Model state_dict saved to %s'
%
model_path
)
if
mask_path
is
not
None
:
torch
.
save
(
self
.
mask_dict
,
mask_path
)
_logger
.
info
(
'Mask dict saved to %s'
,
mask_path
)
print
(
'Mask dict saved to %s'
%
mask_path
)
if
onnx_path
is
not
None
:
assert
input_shape
is
not
None
,
'input_shape must be specified to export onnx model'
# input info needed
input_data
=
torch
.
Tensor
(
*
input_shape
)
torch
.
onnx
.
export
(
self
.
bound_model
,
input_data
,
onnx_path
)
_logger
.
info
(
'Model in onnx with input shape %s saved to %s'
,
input_data
.
shape
,
onnx_path
)
print
(
'Model in onnx with input shape %s saved to %s'
%
(
input_data
.
shape
,
onnx_path
))
class
Quantizer
(
Compressor
):
class
Quantizer
(
Compressor
):
"""
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment