Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
fd551c86
Commit
fd551c86
authored
Oct 16, 2019
by
Tang Lang
Committed by
QuanluZhang
Oct 16, 2019
Browse files
fix builtin pruners bug (#1612)
* fix builtin pruners bug
parent
d6b61e2f
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
25 additions
and
23 deletions
+25
-23
src/sdk/pynni/nni/compression/torch/builtin_pruners.py
src/sdk/pynni/nni/compression/torch/builtin_pruners.py
+25
-23
No files found.
src/sdk/pynni/nni/compression/torch/builtin_pruners.py
View file @
fd551c86
...
@@ -2,7 +2,7 @@ import logging
...
@@ -2,7 +2,7 @@ import logging
import
torch
import
torch
from
.compressor
import
Pruner
from
.compressor
import
Pruner
__all__
=
[
'LevelPruner'
,
'AGP_Pruner'
,
'SensitivityPruner'
]
__all__
=
[
'LevelPruner'
,
'AGP_Pruner'
,
'SensitivityPruner'
]
logger
=
logging
.
getLogger
(
'torch pruner'
)
logger
=
logging
.
getLogger
(
'torch pruner'
)
...
@@ -10,6 +10,7 @@ logger = logging.getLogger('torch pruner')
...
@@ -10,6 +10,7 @@ logger = logging.getLogger('torch pruner')
class
LevelPruner
(
Pruner
):
class
LevelPruner
(
Pruner
):
"""Prune to an exact pruning level specification
"""Prune to an exact pruning level specification
"""
"""
def
__init__
(
self
,
config_list
):
def
__init__
(
self
,
config_list
):
"""
"""
config_list: supported keys:
config_list: supported keys:
...
@@ -21,9 +22,9 @@ class LevelPruner(Pruner):
...
@@ -21,9 +22,9 @@ class LevelPruner(Pruner):
w_abs
=
weight
.
abs
()
w_abs
=
weight
.
abs
()
k
=
int
(
weight
.
numel
()
*
config
[
'sparsity'
])
k
=
int
(
weight
.
numel
()
*
config
[
'sparsity'
])
if
k
==
0
:
if
k
==
0
:
return
torch
.
ones
(
weight
.
shape
)
return
torch
.
ones
(
weight
.
shape
)
.
type_as
(
weight
)
threshold
=
torch
.
topk
(
w_abs
.
view
(
-
1
),
k
,
largest
=
False
).
values
.
max
()
threshold
=
torch
.
topk
(
w_abs
.
view
(
-
1
),
k
,
largest
=
False
).
values
.
max
()
return
torch
.
gt
(
w_abs
,
threshold
).
type
(
weight
.
type
()
)
return
torch
.
gt
(
w_abs
,
threshold
).
type
_as
(
weight
)
class
AGP_Pruner
(
Pruner
):
class
AGP_Pruner
(
Pruner
):
...
@@ -35,12 +36,13 @@ class AGP_Pruner(Pruner):
...
@@ -35,12 +36,13 @@ class AGP_Pruner(Pruner):
Learning of Phones and other Consumer Devices,
Learning of Phones and other Consumer Devices,
https://arxiv.org/pdf/1710.01878.pdf
https://arxiv.org/pdf/1710.01878.pdf
"""
"""
def
__init__
(
self
,
config_list
):
def
__init__
(
self
,
config_list
):
"""
"""
config_list: supported keys:
config_list: supported keys:
- initial_sparsity
- initial_sparsity
- final_sparsity: you should make sure initial_sparsity <= final_sparsity
- final_sparsity: you should make sure initial_sparsity <= final_sparsity
- start_epoch: start epoch number begin update mask
- start_epoch: start epoch number begin update mask
- end_epoch: end epoch number stop update mask, you should make sure start_epoch <= end_epoch
- end_epoch: end epoch number stop update mask, you should make sure start_epoch <= end_epoch
- frequency: if you want update every 2 epoch, you can set it 2
- frequency: if you want update every 2 epoch, you can set it 2
"""
"""
...
@@ -49,15 +51,15 @@ class AGP_Pruner(Pruner):
...
@@ -49,15 +51,15 @@ class AGP_Pruner(Pruner):
self
.
now_epoch
=
1
self
.
now_epoch
=
1
def calc_mask(self, weight, config, op_name, **kwargs):
    """Compute the AGP pruning mask for one layer.

    Parameters
    ----------
    weight : torch.Tensor
        The layer's weight tensor; magnitudes (after applying the
        previous mask) are ranked to decide which elements to prune.
    config : dict
        Schedule configuration, forwarded to
        ``self.compute_target_sparsity`` to get this epoch's sparsity.
    op_name : str
        Key identifying the layer in ``self.mask_list``.

    Returns
    -------
    torch.Tensor
        A 0/1 mask with the same dtype and device as ``weight``
        (``type_as``); 0 marks a pruned element.
    """
    # Previous mask for this op, or all-ones on the first call.
    # type_as keeps the default mask on the same dtype/device as the
    # weight it multiplies (the bug fixed here: a plain torch.ones
    # could mismatch a CUDA/half weight).
    mask = self.mask_list.get(op_name, torch.ones(weight.shape).type_as(weight))
    target_sparsity = self.compute_target_sparsity(config)
    k = int(weight.numel() * target_sparsity)
    # Degenerate schedule values mean "nothing to (un)prune": keep the
    # current mask unchanged.
    if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
        return mask
    # To generate a new mask we should apply the old mask first, so
    # weights already pruned stay at magnitude 0 and remain pruned.
    w_abs = weight.abs() * mask
    # Threshold = the largest magnitude among the k smallest entries;
    # strictly-greater comparison prunes exactly those k (up to ties).
    threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
    new_mask = torch.gt(w_abs, threshold).type_as(weight)
    self.mask_list[op_name] = new_mask
    return new_mask
...
@@ -74,11 +76,11 @@ class AGP_Pruner(Pruner):
...
@@ -74,11 +76,11 @@ class AGP_Pruner(Pruner):
if
end_epoch
<=
self
.
now_epoch
:
if
end_epoch
<=
self
.
now_epoch
:
return
final_sparsity
return
final_sparsity
span
=
((
end_epoch
-
start_epoch
-
1
)
//
freq
)
*
freq
span
=
((
end_epoch
-
start_epoch
-
1
)
//
freq
)
*
freq
assert
span
>
0
assert
span
>
0
target_sparsity
=
(
final_sparsity
+
target_sparsity
=
(
final_sparsity
+
(
initial_sparsity
-
final_sparsity
)
*
(
initial_sparsity
-
final_sparsity
)
*
(
1.0
-
((
self
.
now_epoch
-
start_epoch
)
/
span
))
**
3
)
(
1.0
-
((
self
.
now_epoch
-
start_epoch
)
/
span
))
**
3
)
return
target_sparsity
return
target_sparsity
def
update_epoch
(
self
,
epoch
):
def
update_epoch
(
self
,
epoch
):
...
@@ -93,6 +95,7 @@ class SensitivityPruner(Pruner):
...
@@ -93,6 +95,7 @@ class SensitivityPruner(Pruner):
I.e.: "The pruning threshold is chosen as a quality parameter multiplied
I.e.: "The pruning threshold is chosen as a quality parameter multiplied
by the standard deviation of a layers weights."
by the standard deviation of a layers weights."
"""
"""
def
__init__
(
self
,
config_list
):
def
__init__
(
self
,
config_list
):
"""
"""
config_list: supported keys:
config_list: supported keys:
...
@@ -101,18 +104,17 @@ class SensitivityPruner(Pruner):
...
@@ -101,18 +104,17 @@ class SensitivityPruner(Pruner):
super
().
__init__
(
config_list
)
super
().
__init__
(
config_list
)
self
.
mask_list
=
{}
self
.
mask_list
=
{}
def calc_mask(self, weight, config, op_name, **kwargs):
    """Compute the sensitivity pruning mask for one layer.

    The pruning threshold is ``config['sparsity']`` multiplied by the
    standard deviation of the layer's (already-masked) weights, per the
    sensitivity rule this pruner implements.

    Parameters
    ----------
    weight : torch.Tensor
        The layer's weight tensor.
    config : dict
        Must contain ``'sparsity'``, the quality parameter multiplied
        by the weights' standard deviation.
    op_name : str
        Key identifying the layer in ``self.mask_list``.

    Returns
    -------
    torch.Tensor
        A 0/1 mask with the same dtype and device as ``weight``
        (``type_as``); 0 marks a pruned element.
    """
    # Previous mask, or all-ones on the first call; type_as keeps the
    # default mask on the weight's dtype/device (the bug fixed here).
    mask = self.mask_list.get(op_name, torch.ones(weight.shape).type_as(weight))
    # To generate a new mask we should apply the old mask first, so
    # weights already pruned stay at magnitude 0 and remain pruned.
    weight = weight * mask
    target_sparsity = config['sparsity'] * torch.std(weight).item()
    k = int(weight.numel() * target_sparsity)
    if k == 0:
        return mask
    w_abs = weight.abs()
    # Threshold = the largest magnitude among the k smallest entries.
    threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
    new_mask = torch.gt(w_abs, threshold).type_as(weight)
    self.mask_list[op_name] = new_mask
    return new_mask
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment