Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
fd551c86
"vscode:/vscode.git/clone" did not exist on "5a19a6c6705fe83db2e3517a2d2f473586901743"
Commit
fd551c86
authored
Oct 16, 2019
by
Tang Lang
Committed by
QuanluZhang
Oct 16, 2019
Browse files
fix builtin pruners bug (#1612)
* fix builtin pruners bug
parent
d6b61e2f
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
25 additions
and
23 deletions
+25
-23
src/sdk/pynni/nni/compression/torch/builtin_pruners.py
src/sdk/pynni/nni/compression/torch/builtin_pruners.py
+25
-23
No files found.
src/sdk/pynni/nni/compression/torch/builtin_pruners.py
View file @
fd551c86
...
@@ -2,7 +2,7 @@ import logging
...
@@ -2,7 +2,7 @@ import logging
import
torch
import
torch
from
.compressor
import
Pruner
from
.compressor
import
Pruner
__all__
=
[
'LevelPruner'
,
'AGP_Pruner'
,
'SensitivityPruner'
]
__all__
=
[
'LevelPruner'
,
'AGP_Pruner'
,
'SensitivityPruner'
]
logger
=
logging
.
getLogger
(
'torch pruner'
)
logger
=
logging
.
getLogger
(
'torch pruner'
)
...
@@ -10,6 +10,7 @@ logger = logging.getLogger('torch pruner')
...
@@ -10,6 +10,7 @@ logger = logging.getLogger('torch pruner')
class
LevelPruner
(
Pruner
):
class
LevelPruner
(
Pruner
):
"""Prune to an exact pruning level specification
"""Prune to an exact pruning level specification
"""
"""
def
__init__
(
self
,
config_list
):
def
__init__
(
self
,
config_list
):
"""
"""
config_list: supported keys:
config_list: supported keys:
...
@@ -21,9 +22,9 @@ class LevelPruner(Pruner):
...
@@ -21,9 +22,9 @@ class LevelPruner(Pruner):
w_abs
=
weight
.
abs
()
w_abs
=
weight
.
abs
()
k
=
int
(
weight
.
numel
()
*
config
[
'sparsity'
])
k
=
int
(
weight
.
numel
()
*
config
[
'sparsity'
])
if
k
==
0
:
if
k
==
0
:
return
torch
.
ones
(
weight
.
shape
)
return
torch
.
ones
(
weight
.
shape
)
.
type_as
(
weight
)
threshold
=
torch
.
topk
(
w_abs
.
view
(
-
1
),
k
,
largest
=
False
).
values
.
max
()
threshold
=
torch
.
topk
(
w_abs
.
view
(
-
1
),
k
,
largest
=
False
).
values
.
max
()
return
torch
.
gt
(
w_abs
,
threshold
).
type
(
weight
.
type
()
)
return
torch
.
gt
(
w_abs
,
threshold
).
type
_as
(
weight
)
class
AGP_Pruner
(
Pruner
):
class
AGP_Pruner
(
Pruner
):
...
@@ -35,12 +36,13 @@ class AGP_Pruner(Pruner):
...
@@ -35,12 +36,13 @@ class AGP_Pruner(Pruner):
Learning of Phones and other Consumer Devices,
Learning of Phones and other Consumer Devices,
https://arxiv.org/pdf/1710.01878.pdf
https://arxiv.org/pdf/1710.01878.pdf
"""
"""
def
__init__
(
self
,
config_list
):
def
__init__
(
self
,
config_list
):
"""
"""
config_list: supported keys:
config_list: supported keys:
- initial_sparsity
- initial_sparsity
- final_sparsity: you should make sure initial_sparsity <= final_sparsity
- final_sparsity: you should make sure initial_sparsity <= final_sparsity
- start_epoch: start epoch numer begin update mask
- start_epoch: start epoch num
b
er begin update mask
- end_epoch: end epoch number stop update mask, you should make sure start_epoch <= end_epoch
- end_epoch: end epoch number stop update mask, you should make sure start_epoch <= end_epoch
- frequency: if you want update every 2 epoch, you can set it 2
- frequency: if you want update every 2 epoch, you can set it 2
"""
"""
...
@@ -49,15 +51,15 @@ class AGP_Pruner(Pruner):
...
@@ -49,15 +51,15 @@ class AGP_Pruner(Pruner):
self
.
now_epoch
=
1
self
.
now_epoch
=
1
def
calc_mask
(
self
,
weight
,
config
,
op_name
,
**
kwargs
):
def
calc_mask
(
self
,
weight
,
config
,
op_name
,
**
kwargs
):
mask
=
self
.
mask_list
.
get
(
op_name
,
torch
.
ones
(
weight
.
shape
))
mask
=
self
.
mask_list
.
get
(
op_name
,
torch
.
ones
(
weight
.
shape
)
.
type_as
(
weight
)
)
target_sparsity
=
self
.
compute_target_sparsity
(
config
)
target_sparsity
=
self
.
compute_target_sparsity
(
config
)
k
=
int
(
weight
.
numel
()
*
target_sparsity
)
k
=
int
(
weight
.
numel
()
*
target_sparsity
)
if
k
==
0
or
target_sparsity
>=
1
or
target_sparsity
<=
0
:
if
k
==
0
or
target_sparsity
>=
1
or
target_sparsity
<=
0
:
return
mask
return
mask
# if we want to generate new mask, we should update weigth first
# if we want to generate new mask, we should update weigth first
w_abs
=
weight
.
abs
()
*
mask
w_abs
=
weight
.
abs
()
*
mask
threshold
=
torch
.
topk
(
w_abs
.
view
(
-
1
),
k
,
largest
=
False
).
values
.
max
()
threshold
=
torch
.
topk
(
w_abs
.
view
(
-
1
),
k
,
largest
=
False
).
values
.
max
()
new_mask
=
torch
.
gt
(
w_abs
,
threshold
).
type
(
weight
.
type
()
)
new_mask
=
torch
.
gt
(
w_abs
,
threshold
).
type
_as
(
weight
)
self
.
mask_list
[
op_name
]
=
new_mask
self
.
mask_list
[
op_name
]
=
new_mask
return
new_mask
return
new_mask
...
@@ -74,11 +76,11 @@ class AGP_Pruner(Pruner):
...
@@ -74,11 +76,11 @@ class AGP_Pruner(Pruner):
if
end_epoch
<=
self
.
now_epoch
:
if
end_epoch
<=
self
.
now_epoch
:
return
final_sparsity
return
final_sparsity
span
=
((
end_epoch
-
start_epoch
-
1
)
//
freq
)
*
freq
span
=
((
end_epoch
-
start_epoch
-
1
)
//
freq
)
*
freq
assert
span
>
0
assert
span
>
0
target_sparsity
=
(
final_sparsity
+
target_sparsity
=
(
final_sparsity
+
(
initial_sparsity
-
final_sparsity
)
*
(
initial_sparsity
-
final_sparsity
)
*
(
1.0
-
((
self
.
now_epoch
-
start_epoch
)
/
span
))
**
3
)
(
1.0
-
((
self
.
now_epoch
-
start_epoch
)
/
span
))
**
3
)
return
target_sparsity
return
target_sparsity
def
update_epoch
(
self
,
epoch
):
def
update_epoch
(
self
,
epoch
):
...
@@ -93,6 +95,7 @@ class SensitivityPruner(Pruner):
...
@@ -93,6 +95,7 @@ class SensitivityPruner(Pruner):
I.e.: "The pruning threshold is chosen as a quality parameter multiplied
I.e.: "The pruning threshold is chosen as a quality parameter multiplied
by the standard deviation of a layers weights."
by the standard deviation of a layers weights."
"""
"""
def
__init__
(
self
,
config_list
):
def
__init__
(
self
,
config_list
):
"""
"""
config_list: supported keys:
config_list: supported keys:
...
@@ -101,18 +104,17 @@ class SensitivityPruner(Pruner):
...
@@ -101,18 +104,17 @@ class SensitivityPruner(Pruner):
super
().
__init__
(
config_list
)
super
().
__init__
(
config_list
)
self
.
mask_list
=
{}
self
.
mask_list
=
{}
def
calc_mask
(
self
,
weight
,
config
,
op_name
,
**
kwargs
):
def
calc_mask
(
self
,
weight
,
config
,
op_name
,
**
kwargs
):
mask
=
self
.
mask_list
.
get
(
op_name
,
torch
.
ones
(
weight
.
shape
))
mask
=
self
.
mask_list
.
get
(
op_name
,
torch
.
ones
(
weight
.
shape
)
.
type_as
(
weight
)
)
# if we want to generate new mask, we should update weig
t
h first
# if we want to generate new mask, we should update weigh
t
first
weight
=
weight
*
mask
weight
=
weight
*
mask
target_sparsity
=
config
[
'sparsity'
]
*
torch
.
std
(
weight
).
item
()
target_sparsity
=
config
[
'sparsity'
]
*
torch
.
std
(
weight
).
item
()
k
=
int
(
weight
.
numel
()
*
target_sparsity
)
k
=
int
(
weight
.
numel
()
*
target_sparsity
)
if
k
==
0
:
if
k
==
0
:
return
mask
return
mask
w_abs
=
weight
.
abs
()
w_abs
=
weight
.
abs
()
threshold
=
torch
.
topk
(
w_abs
.
view
(
-
1
),
k
,
largest
=
False
).
values
.
max
()
threshold
=
torch
.
topk
(
w_abs
.
view
(
-
1
),
k
,
largest
=
False
).
values
.
max
()
new_mask
=
torch
.
gt
(
w_abs
,
threshold
).
type
(
weight
.
type
()
)
new_mask
=
torch
.
gt
(
w_abs
,
threshold
).
type
_as
(
weight
)
self
.
mask_list
[
op_name
]
=
new_mask
self
.
mask_list
[
op_name
]
=
new_mask
return
new_mask
return
new_mask
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment