Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
d0a9b106
Unverified
Commit
d0a9b106
authored
Aug 11, 2020
by
Guoxin
Committed by
GitHub
Aug 11, 2020
Browse files
fix IT pruning example issue (#2772)
parent
654e8242
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
33 additions
and
17 deletions
+33
-17
src/sdk/pynni/nni/compression/torch/pruning/one_shot.py
src/sdk/pynni/nni/compression/torch/pruning/one_shot.py
+26
-10
src/sdk/pynni/tests/test_compressor.py
src/sdk/pynni/tests/test_compressor.py
+6
-4
src/sdk/pynni/tests/test_pruners.py
src/sdk/pynni/tests/test_pruners.py
+1
-3
No files found.
src/sdk/pynni/nni/compression/torch/pruning/one_shot.py
View file @
d0a9b106
...
...
@@ -94,9 +94,11 @@ class LevelPruner(OneshotPruner):
Supported keys:
- sparsity : This is to specify the sparsity operations to be compressed to.
- op_types : Operation types to prune.
optimizer: torch.optim.Optimizer
Optimizer used to train model
"""
def
__init__
(
self
,
model
,
config_list
):
super
().
__init__
(
model
,
config_list
,
pruning_algorithm
=
'level'
)
def
__init__
(
self
,
model
,
config_list
,
optimizer
=
None
):
super
().
__init__
(
model
,
config_list
,
pruning_algorithm
=
'level'
,
optimizer
=
optimizer
)
class
SlimPruner
(
OneshotPruner
):
"""
...
...
@@ -108,9 +110,11 @@ class SlimPruner(OneshotPruner):
Supported keys:
- sparsity : This is to specify the sparsity operations to be compressed to.
- op_types : Only BatchNorm2d is supported in Slim Pruner.
optimizer: torch.optim.Optimizer
Optimizer used to train model
"""
def
__init__
(
self
,
model
,
config_list
):
super
().
__init__
(
model
,
config_list
,
pruning_algorithm
=
'slim'
)
def
__init__
(
self
,
model
,
config_list
,
optimizer
=
None
):
super
().
__init__
(
model
,
config_list
,
pruning_algorithm
=
'slim'
,
optimizer
=
optimizer
)
def
validate_config
(
self
,
model
,
config_list
):
schema
=
CompressorSchema
([{
...
...
@@ -147,9 +151,11 @@ class L1FilterPruner(_StructuredFilterPruner):
Supported keys:
- sparsity : This is to specify the sparsity operations to be compressed to.
- op_types : Only Conv2d is supported in L1FilterPruner.
optimizer: torch.optim.Optimizer
Optimizer used to train model
"""
def
__init__
(
self
,
model
,
config_list
):
super
().
__init__
(
model
,
config_list
,
pruning_algorithm
=
'l1'
)
def
__init__
(
self
,
model
,
config_list
,
optimizer
=
None
):
super
().
__init__
(
model
,
config_list
,
pruning_algorithm
=
'l1'
,
optimizer
=
optimizer
)
class
L2FilterPruner
(
_StructuredFilterPruner
):
"""
...
...
@@ -161,9 +167,11 @@ class L2FilterPruner(_StructuredFilterPruner):
Supported keys:
- sparsity : This is to specify the sparsity operations to be compressed to.
- op_types : Only Conv2d is supported in L2FilterPruner.
optimizer: torch.optim.Optimizer
Optimizer used to train model
"""
def
__init__
(
self
,
model
,
config_list
):
super
().
__init__
(
model
,
config_list
,
pruning_algorithm
=
'l2'
)
def
__init__
(
self
,
model
,
config_list
,
optimizer
=
None
):
super
().
__init__
(
model
,
config_list
,
pruning_algorithm
=
'l2'
,
optimizer
=
optimizer
)
class
FPGMPruner
(
_StructuredFilterPruner
):
"""
...
...
@@ -175,9 +183,11 @@ class FPGMPruner(_StructuredFilterPruner):
Supported keys:
- sparsity : This is to specify the sparsity operations to be compressed to.
- op_types : Only Conv2d is supported in FPGM Pruner.
optimizer: torch.optim.Optimizer
Optimizer used to train model
"""
def
__init__
(
self
,
model
,
config_list
):
super
().
__init__
(
model
,
config_list
,
pruning_algorithm
=
'fpgm'
)
def
__init__
(
self
,
model
,
config_list
,
optimizer
=
None
):
super
().
__init__
(
model
,
config_list
,
pruning_algorithm
=
'fpgm'
,
optimizer
=
optimizer
)
class
TaylorFOWeightFilterPruner
(
_StructuredFilterPruner
):
"""
...
...
@@ -189,6 +199,8 @@ class TaylorFOWeightFilterPruner(_StructuredFilterPruner):
Supported keys:
- sparsity : How much percentage of convolutional filters are to be pruned.
- op_types : Currently only Conv2d is supported in TaylorFOWeightFilterPruner.
optimizer: torch.optim.Optimizer
Optimizer used to train model
"""
def
__init__
(
self
,
model
,
config_list
,
optimizer
=
None
,
statistics_batch_num
=
1
):
super
().
__init__
(
model
,
config_list
,
pruning_algorithm
=
'taylorfo'
,
optimizer
=
optimizer
,
statistics_batch_num
=
statistics_batch_num
)
...
...
@@ -203,6 +215,8 @@ class ActivationAPoZRankFilterPruner(_StructuredFilterPruner):
Supported keys:
- sparsity : How much percentage of convolutional filters are to be pruned.
- op_types : Only Conv2d is supported in ActivationAPoZRankFilterPruner.
optimizer: torch.optim.Optimizer
Optimizer used to train model
"""
def
__init__
(
self
,
model
,
config_list
,
optimizer
=
None
,
activation
=
'relu'
,
statistics_batch_num
=
1
):
super
().
__init__
(
model
,
config_list
,
pruning_algorithm
=
'apoz'
,
optimizer
=
optimizer
,
\
...
...
@@ -218,6 +232,8 @@ class ActivationMeanRankFilterPruner(_StructuredFilterPruner):
Supported keys:
- sparsity : How much percentage of convolutional filters are to be pruned.
- op_types : Only Conv2d is supported in ActivationMeanRankFilterPruner.
optimizer: torch.optim.Optimizer
Optimizer used to train model
"""
def
__init__
(
self
,
model
,
config_list
,
optimizer
=
None
,
activation
=
'relu'
,
statistics_batch_num
=
1
):
super
().
__init__
(
model
,
config_list
,
pruning_algorithm
=
'mean_activation'
,
optimizer
=
optimizer
,
\
...
...
src/sdk/pynni/tests/test_compressor.py
View file @
d0a9b106
...
...
@@ -88,8 +88,9 @@ class CompressorTestCase(TestCase):
def
test_torch_level_pruner
(
self
):
model
=
TorchModel
()
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.01
,
momentum
=
0.5
)
configure_list
=
[{
'sparsity'
:
0.8
,
'op_types'
:
[
'default'
]}]
torch_compressor
.
LevelPruner
(
model
,
configure_list
).
compress
()
torch_compressor
.
LevelPruner
(
model
,
configure_list
,
optimizer
).
compress
()
@
tf2
def
test_tf_level_pruner
(
self
):
...
...
@@ -128,7 +129,7 @@ class CompressorTestCase(TestCase):
model
=
TorchModel
()
config_list
=
[{
'sparsity'
:
0.6
,
'op_types'
:
[
'Conv2d'
]},
{
'sparsity'
:
0.2
,
'op_types'
:
[
'Conv2d'
]}]
pruner
=
torch_compressor
.
FPGMPruner
(
model
,
config_list
)
pruner
=
torch_compressor
.
FPGMPruner
(
model
,
config_list
,
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.01
)
)
model
.
conv2
.
module
.
weight
.
data
=
torch
.
tensor
(
w
).
float
()
masks
=
pruner
.
calc_mask
(
model
.
conv2
)
...
...
@@ -314,7 +315,7 @@ class CompressorTestCase(TestCase):
def
test_torch_pruner_validation
(
self
):
# test bad configuraiton
pruner_classes
=
[
torch_compressor
.
__dict__
[
x
]
for
x
in
\
[
'LevelPruner'
,
'SlimPruner'
,
'FPGMPruner'
,
'L1FilterPruner'
,
'L2FilterPruner'
,
\
[
'LevelPruner'
,
'SlimPruner'
,
'FPGMPruner'
,
'L1FilterPruner'
,
'L2FilterPruner'
,
'AGPPruner'
,
\
'ActivationMeanRankFilterPruner'
,
'ActivationAPoZRankFilterPruner'
]]
bad_configs
=
[
...
...
@@ -336,10 +337,11 @@ class CompressorTestCase(TestCase):
]
]
model
=
TorchModel
()
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.01
)
for
pruner_class
in
pruner_classes
:
for
config_list
in
bad_configs
:
try
:
pruner_class
(
model
,
config_list
)
pruner_class
(
model
,
config_list
,
optimizer
)
print
(
config_list
)
assert
False
,
'Validation error should be raised for bad configuration'
except
schema
.
SchemaError
:
...
...
src/sdk/pynni/tests/test_pruners.py
View file @
d0a9b106
...
...
@@ -192,9 +192,7 @@ def pruners_test(pruner_names=['level', 'agp', 'slim', 'fpgm', 'l1', 'l2', 'tayl
pruner
=
prune_config
[
pruner_name
][
'pruner_class'
](
model
,
config_list
,
trainer
=
prune_config
[
pruner_name
][
'trainer'
])
elif
pruner_name
==
'autocompress'
:
pruner
=
prune_config
[
pruner_name
][
'pruner_class'
](
model
,
config_list
,
trainer
=
prune_config
[
pruner_name
][
'trainer'
],
evaluator
=
prune_config
[
pruner_name
][
'evaluator'
],
dummy_input
=
x
)
elif
pruner_name
in
[
'level'
,
'slim'
,
'fpgm'
,
'l1'
,
'l2'
]:
pruner
=
prune_config
[
pruner_name
][
'pruner_class'
](
model
,
config_list
)
elif
pruner_name
in
[
'agp'
,
'taylorfo'
,
'mean_activation'
,
'apoz'
]:
else
:
pruner
=
prune_config
[
pruner_name
][
'pruner_class'
](
model
,
config_list
,
optimizer
)
pruner
.
compress
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment