Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
5eb95c2d
Unverified
Commit
5eb95c2d
authored
Mar 17, 2020
by
chicm-ms
Committed by
GitHub
Mar 17, 2020
Browse files
Fix pruners (#2153)
parent
e9f54647
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
166 additions
and
4 deletions
+166
-4
src/sdk/pynni/nni/compression/torch/weight_rank_filter_pruners.py
...pynni/nni/compression/torch/weight_rank_filter_pruners.py
+4
-4
src/sdk/pynni/tests/test_pruners.py
src/sdk/pynni/tests/test_pruners.py
+162
-0
No files found.
src/sdk/pynni/nni/compression/torch/weight_rank_filter_pruners.py
View file @
5eb95c2d
...
@@ -124,9 +124,9 @@ class L1FilterPruner(WeightRankFilterPruner):
...
@@ -124,9 +124,9 @@ class L1FilterPruner(WeightRankFilterPruner):
w_abs_structured
=
w_abs
.
view
(
filters
,
-
1
).
sum
(
dim
=
1
)
w_abs_structured
=
w_abs
.
view
(
filters
,
-
1
).
sum
(
dim
=
1
)
threshold
=
torch
.
topk
(
w_abs_structured
.
view
(
-
1
),
num_prune
,
largest
=
False
)[
0
].
max
()
threshold
=
torch
.
topk
(
w_abs_structured
.
view
(
-
1
),
num_prune
,
largest
=
False
)[
0
].
max
()
mask_weight
=
torch
.
gt
(
w_abs_structured
,
threshold
)[:,
None
,
None
,
None
].
expand_as
(
weight
).
type_as
(
weight
)
mask_weight
=
torch
.
gt
(
w_abs_structured
,
threshold
)[:,
None
,
None
,
None
].
expand_as
(
weight
).
type_as
(
weight
)
mask_bias
=
torch
.
gt
(
w_abs_structured
,
threshold
).
type_as
(
weight
)
mask_bias
=
torch
.
gt
(
w_abs_structured
,
threshold
).
type_as
(
weight
)
.
detach
()
if
base_mask
[
'bias_mask'
]
is
not
None
else
None
return
{
'weight_mask'
:
mask_weight
.
detach
(),
'bias_mask'
:
mask_bias
.
detach
()
}
return
{
'weight_mask'
:
mask_weight
.
detach
(),
'bias_mask'
:
mask_bias
}
class
L2FilterPruner
(
WeightRankFilterPruner
):
class
L2FilterPruner
(
WeightRankFilterPruner
):
...
@@ -172,9 +172,9 @@ class L2FilterPruner(WeightRankFilterPruner):
...
@@ -172,9 +172,9 @@ class L2FilterPruner(WeightRankFilterPruner):
w_l2_norm
=
torch
.
sqrt
((
w
**
2
).
sum
(
dim
=
1
))
w_l2_norm
=
torch
.
sqrt
((
w
**
2
).
sum
(
dim
=
1
))
threshold
=
torch
.
topk
(
w_l2_norm
.
view
(
-
1
),
num_prune
,
largest
=
False
)[
0
].
max
()
threshold
=
torch
.
topk
(
w_l2_norm
.
view
(
-
1
),
num_prune
,
largest
=
False
)[
0
].
max
()
mask_weight
=
torch
.
gt
(
w_l2_norm
,
threshold
)[:,
None
,
None
,
None
].
expand_as
(
weight
).
type_as
(
weight
)
mask_weight
=
torch
.
gt
(
w_l2_norm
,
threshold
)[:,
None
,
None
,
None
].
expand_as
(
weight
).
type_as
(
weight
)
mask_bias
=
torch
.
gt
(
w_l2_norm
,
threshold
).
type_as
(
weight
)
mask_bias
=
torch
.
gt
(
w_l2_norm
,
threshold
).
type_as
(
weight
)
.
detach
()
if
base_mask
[
'bias_mask'
]
is
not
None
else
None
return
{
'weight_mask'
:
mask_weight
.
detach
(),
'bias_mask'
:
mask_bias
.
detach
()
}
return
{
'weight_mask'
:
mask_weight
.
detach
(),
'bias_mask'
:
mask_bias
}
class
FPGMPruner
(
WeightRankFilterPruner
):
class
FPGMPruner
(
WeightRankFilterPruner
):
...
...
src/sdk/pynni/tests/test_pruners.py
0 → 100644
View file @
5eb95c2d
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
os
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
math
from
unittest
import
TestCase
,
main
from
nni.compression.torch
import
LevelPruner
,
SlimPruner
,
FPGMPruner
,
L1FilterPruner
,
\
L2FilterPruner
,
AGP_Pruner
,
ActivationMeanRankFilterPruner
,
ActivationAPoZRankFilterPruner
def validate_sparsity(wrapper, sparsity, bias=False):
    """Assert that a wrapped layer's mask(s) achieve the target sparsity.

    Checks the fraction of zero entries in ``wrapper.weight_mask`` (and in
    ``wrapper.bias_mask`` when *bias* is truthy and a bias mask exists)
    against *sparsity*, allowing an absolute tolerance of 0.1.
    """
    masks_to_check = [wrapper.weight_mask]
    if bias and wrapper.bias_mask is not None:
        masks_to_check.append(wrapper.bias_mask)
    for mask in masks_to_check:
        # Fraction of mask entries that were zeroed out by the pruner.
        zero_fraction = (mask == 0).sum().item() / mask.numel()
        msg = 'actual sparsity: {:.2f}, target sparsity: {:.2f}'.format(zero_fraction, sparsity)
        assert math.isclose(zero_fraction, sparsity, abs_tol=0.1), msg
# Registry of per-pruner test configurations, keyed by a short name used by
# pruners_test(). Each entry supplies:
#   'pruner_class' - the nni pruner class under test
#   'config_list'  - the config list handed to that pruner's constructor
#   'validators'   - callables taking the pruned model and asserting the
#                    achieved sparsity on specific layers (an empty list
#                    means no sparsity check is performed for that pruner)
prune_config = {
    'level': {
        'pruner_class': LevelPruner,
        'config_list': [{
            'sparsity': 0.5,
            'op_types': ['default'],
        }],
        # Level pruning is unstructured, so both conv and fc weight masks
        # are checked; bias masks are not checked (third argument False).
        'validators': [
            lambda model: validate_sparsity(model.conv1, 0.5, False),
            lambda model: validate_sparsity(model.fc, 0.5, False)
        ]
    },
    'agp': {
        'pruner_class': AGP_Pruner,
        # AGP ramps sparsity from 'initial_sparsity' to 'final_sparsity'
        # between 'start_epoch' and 'end_epoch', updating every 'frequency'
        # epochs; no validator is registered for it here.
        'config_list': [{
            'initial_sparsity': 0,
            'final_sparsity': 0.8,
            'start_epoch': 0,
            'end_epoch': 10,
            'frequency': 1,
            'op_types': ['default']
        }],
        'validators': []
    },
    'slim': {
        # Slim prunes BatchNorm scaling factors, hence the BatchNorm2d op type
        # and the validator on model.bn1.
        'pruner_class': SlimPruner,
        'config_list': [{
            'sparsity': 0.7,
            'op_types': ['BatchNorm2d']
        }],
        'validators': [
            lambda model: validate_sparsity(model.bn1, 0.7, model.bias)
        ]
    },
    'fpgm': {
        'pruner_class': FPGMPruner,
        'config_list': [{
            'sparsity': 0.5,
            'op_types': ['Conv2d']
        }],
        'validators': [
            lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
        ]
    },
    'l1': {
        'pruner_class': L1FilterPruner,
        'config_list': [{
            'sparsity': 0.5,
            'op_types': ['Conv2d'],
        }],
        'validators': [
            lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
        ]
    },
    'l2': {
        'pruner_class': L2FilterPruner,
        'config_list': [{
            'sparsity': 0.5,
            'op_types': ['Conv2d'],
        }],
        'validators': [
            lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
        ]
    },
    'mean_activation': {
        'pruner_class': ActivationMeanRankFilterPruner,
        'config_list': [{
            'sparsity': 0.5,
            'op_types': ['Conv2d'],
        }],
        'validators': [
            lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
        ]
    },
    'apoz': {
        'pruner_class': ActivationAPoZRankFilterPruner,
        'config_list': [{
            'sparsity': 0.5,
            'op_types': ['Conv2d'],
        }],
        'validators': [
            lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
        ]
    }
}
class Model(nn.Module):
    """Minimal conv -> batchnorm -> pool -> linear network used as the
    pruning target in these tests.

    The *bias* flag controls whether conv1 and fc carry bias terms, and is
    also stored on the instance so validators can decide whether to check
    bias masks.
    """

    def __init__(self, bias=True):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 8, kernel_size=3, padding=1, bias=bias)
        self.bn1 = nn.BatchNorm2d(8)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 2, bias=bias)
        # Remember the flag itself; read by the sparsity validators.
        self.bias = bias

    def forward(self, x):
        features = self.conv1(x)
        normalized = self.bn1(features)
        pooled = self.pool(normalized)
        # Flatten (batch, 8, 1, 1) down to (batch, 8) for the classifier.
        flattened = pooled.view(x.size(0), -1)
        return self.fc(flattened)
def _train_once(model, optimizer):
    """Run one forward/backward/optimizer step on random MNIST-shaped data."""
    x = torch.randn(2, 1, 28, 28)
    y = torch.tensor([0, 1]).long()
    out = model(x)
    loss = F.cross_entropy(out, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


def pruners_test(pruner_names=('level', 'agp', 'slim', 'fpgm', 'l1', 'l2', 'mean_activation', 'apoz'), bias=True):
    """Exercise each named pruner end-to-end: train a step, compress,
    train another step, export the model/mask/onnx files, then run the
    pruner's sparsity validators and clean up the exported files.

    :param pruner_names: iterable of keys into ``prune_config`` selecting
        which pruners to test. (A tuple default avoids the mutable-default
        pitfall; any iterable of names works.)
    :param bias: whether the test Model's conv/fc layers carry bias terms.
    """
    for pruner_name in pruner_names:
        print('testing {}...'.format(pruner_name))
        model = Model(bias=bias)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        config_list = prune_config[pruner_name]['config_list']
        # One step before compressing so activation/statistics-based pruners
        # have data to rank filters with.
        _train_once(model, optimizer)
        pruner = prune_config[pruner_name]['pruner_class'](model, config_list, optimizer)
        pruner.compress()
        # One more step with masks applied, then export the pruned artifacts.
        _train_once(model, optimizer)
        pruner.export_model('./model_tmp.pth', './mask_tmp.pth', './onnx_tmp.pth',
                            input_shape=(2, 1, 28, 28))
        for v in prune_config[pruner_name]['validators']:
            v(model)
        os.remove('./model_tmp.pth')
        os.remove('./mask_tmp.pth')
        os.remove('./onnx_tmp.pth')
class PrunerTestCase(TestCase):
    """Runs the full pruner sweep twice: once with bias terms, once without."""

    def test_pruners(self):
        # All pruners against a model whose conv/fc layers have biases.
        pruners_test(bias=True)

    def test_pruners_no_bias(self):
        # Same sweep on a bias-free model to cover the bias_mask=None path.
        pruners_test(bias=False)
# Allow running this module directly: unittest's main() discovers and runs
# the TestCase classes defined above.
if __name__ == '__main__':
    main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment