Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
68644f59
Unverified
Commit
68644f59
authored
Jul 13, 2021
by
J-shang
Committed by
GitHub
Jul 13, 2021
Browse files
add exclude config validate in compressor (#3815)
parent
fb3c596b
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
101 additions
and
60 deletions
+101
-60
examples/model_compress/pruning/basic_pruners_torch.py
examples/model_compress/pruning/basic_pruners_torch.py
+5
-1
nni/algorithms/compression/pytorch/pruning/auto_compress_pruner.py
...ithms/compression/pytorch/pruning/auto_compress_pruner.py
+8
-6
nni/algorithms/compression/pytorch/pruning/dependency_aware_pruner.py
...ms/compression/pytorch/pruning/dependency_aware_pruner.py
+3
-6
nni/algorithms/compression/pytorch/pruning/iterative_pruner.py
...lgorithms/compression/pytorch/pruning/iterative_pruner.py
+16
-12
nni/algorithms/compression/pytorch/pruning/lottery_ticket.py
nni/algorithms/compression/pytorch/pruning/lottery_ticket.py
+5
-4
nni/algorithms/compression/pytorch/pruning/net_adapt_pruner.py
...lgorithms/compression/pytorch/pruning/net_adapt_pruner.py
+8
-6
nni/algorithms/compression/pytorch/pruning/one_shot_pruner.py
...algorithms/compression/pytorch/pruning/one_shot_pruner.py
+5
-4
nni/algorithms/compression/pytorch/pruning/sensitivity_pruner.py
...orithms/compression/pytorch/pruning/sensitivity_pruner.py
+8
-6
nni/algorithms/compression/pytorch/pruning/simulated_annealing_pruner.py
...compression/pytorch/pruning/simulated_annealing_pruner.py
+8
-6
nni/algorithms/compression/pytorch/quantization/quantizers.py
...algorithms/compression/pytorch/quantization/quantizers.py
+13
-9
nni/compression/pytorch/utils/config_validation.py
nni/compression/pytorch/utils/config_validation.py
+22
-0
No files found.
examples/model_compress/pruning/basic_pruners_torch.py
View file @
68644f59
...
...
@@ -243,6 +243,7 @@ def main(args):
# Reproduced result in paper 'PRUNING FILTERS FOR EFFICIENT CONVNETS',
# Conv_1, Conv_8, Conv_9, Conv_10, Conv_11, Conv_12 are pruned with 50% sparsity, as 'VGG-16-pruned-A'
# If you want to skip some layer, you can use 'exclude' like follow.
if
args
.
pruner
==
'slim'
:
config_list
=
[{
'sparsity'
:
args
.
sparsity
,
...
...
@@ -252,7 +253,10 @@ def main(args):
config_list
=
[{
'sparsity'
:
args
.
sparsity
,
'op_types'
:
[
'Conv2d'
],
'op_names'
:
[
'feature.0'
,
'feature.24'
,
'feature.27'
,
'feature.30'
,
'feature.34'
,
'feature.37'
]
'op_names'
:
[
'feature.0'
,
'feature.10'
,
'feature.24'
,
'feature.27'
,
'feature.30'
,
'feature.34'
,
'feature.37'
]
},
{
'exclude'
:
True
,
'op_names'
:
[
'feature.10'
]
}]
pruner
=
pruner_cls
(
model
,
config_list
,
**
kw_args
)
...
...
nni/algorithms/compression/pytorch/pruning/auto_compress_pruner.py
View file @
68644f59
...
...
@@ -11,7 +11,7 @@ from nni.utils import OptimizeMode
from
nni.compression.pytorch
import
ModelSpeedup
from
nni.compression.pytorch.compressor
import
Pruner
from
nni.compression.pytorch.utils.config_validation
import
Compresso
rSchema
from
nni.compression.pytorch.utils.config_validation
import
Prune
rSchema
from
.simulated_annealing_pruner
import
SimulatedAnnealingPruner
from
.iterative_pruner
import
ADMMPruner
...
...
@@ -130,16 +130,18 @@ class AutoCompressPruner(Pruner):
"""
if
self
.
_base_algo
==
'level'
:
schema
=
Compresso
rSchema
([{
'sparsity'
:
And
(
float
,
lambda
n
:
0
<
n
<
1
),
schema
=
Prune
rSchema
([{
Optional
(
'sparsity'
)
:
And
(
float
,
lambda
n
:
0
<
n
<
1
),
Optional
(
'op_types'
):
[
str
],
Optional
(
'op_names'
):
[
str
],
Optional
(
'exclude'
):
bool
}],
model
,
_logger
)
elif
self
.
_base_algo
in
[
'l1'
,
'l2'
,
'fpgm'
]:
schema
=
Compresso
rSchema
([{
'sparsity'
:
And
(
float
,
lambda
n
:
0
<
n
<
1
),
schema
=
Prune
rSchema
([{
Optional
(
'sparsity'
)
:
And
(
float
,
lambda
n
:
0
<
n
<
1
),
'op_types'
:
[
'Conv2d'
],
Optional
(
'op_names'
):
[
str
]
Optional
(
'op_names'
):
[
str
],
Optional
(
'exclude'
):
bool
}],
model
,
_logger
)
schema
.
validate
(
config_list
)
...
...
nni/algorithms/compression/pytorch/pruning/dependency_aware_pruner.py
View file @
68644f59
...
...
@@ -2,10 +2,10 @@
# Licensed under the MIT license.
import
logging
from
schema
import
And
,
Optional
,
SchemaError
from
schema
import
And
,
Optional
from
nni.common.graph_utils
import
TorchModuleGraph
from
nni.compression.pytorch.utils.shape_dependency
import
ChannelDependency
,
GroupDependency
from
nni.compression.pytorch.utils.config_validation
import
Compresso
rSchema
from
nni.compression.pytorch.utils.config_validation
import
Prune
rSchema
from
nni.compression.pytorch.compressor
import
Pruner
from
.constants
import
MASKER_DICT
...
...
@@ -82,7 +82,7 @@ class DependencyAwarePruner(Pruner):
self
.
_dependency_update_mask
()
def
validate_config
(
self
,
model
,
config_list
):
schema
=
Compresso
rSchema
([{
schema
=
Prune
rSchema
([{
Optional
(
'sparsity'
):
And
(
float
,
lambda
n
:
0
<
n
<
1
),
Optional
(
'op_types'
):
[
'Conv2d'
],
Optional
(
'op_names'
):
[
str
],
...
...
@@ -90,9 +90,6 @@ class DependencyAwarePruner(Pruner):
}],
model
,
logger
)
schema
.
validate
(
config_list
)
for
config
in
config_list
:
if
'exclude'
not
in
config
and
'sparsity'
not
in
config
:
raise
SchemaError
(
'Either sparisty or exclude must be specified!'
)
def
_supported_dependency_aware
(
self
):
raise
NotImplementedError
...
...
nni/algorithms/compression/pytorch/pruning/iterative_pruner.py
View file @
68644f59
...
...
@@ -5,7 +5,7 @@ import logging
import
copy
import
torch
from
schema
import
And
,
Optional
from
nni.compression.pytorch.utils.config_validation
import
Compresso
rSchema
from
nni.compression.pytorch.utils.config_validation
import
Prune
rSchema
from
.constants
import
MASKER_DICT
from
.dependency_aware_pruner
import
DependencyAwarePruner
...
...
@@ -138,10 +138,11 @@ class AGPPruner(IterativePruner):
config_list : list
List on pruning configs
"""
schema
=
Compresso
rSchema
([{
'sparsity'
:
And
(
float
,
lambda
n
:
0
<=
n
<=
1
),
schema
=
Prune
rSchema
([{
Optional
(
'sparsity'
)
:
And
(
float
,
lambda
n
:
0
<=
n
<=
1
),
Optional
(
'op_types'
):
[
str
],
Optional
(
'op_names'
):
[
str
]
Optional
(
'op_names'
):
[
str
],
Optional
(
'exclude'
):
bool
}],
model
,
logger
)
schema
.
validate
(
config_list
)
...
...
@@ -300,16 +301,18 @@ class ADMMPruner(IterativePruner):
"""
if
self
.
_base_algo
==
'level'
:
schema
=
Compresso
rSchema
([{
'sparsity'
:
And
(
float
,
lambda
n
:
0
<
n
<
1
),
schema
=
Prune
rSchema
([{
Optional
(
'sparsity'
)
:
And
(
float
,
lambda
n
:
0
<
n
<
1
),
Optional
(
'op_types'
):
[
str
],
Optional
(
'op_names'
):
[
str
],
Optional
(
'exclude'
):
bool
}],
model
,
logger
)
elif
self
.
_base_algo
in
[
'l1'
,
'l2'
,
'fpgm'
]:
schema
=
Compresso
rSchema
([{
'sparsity'
:
And
(
float
,
lambda
n
:
0
<
n
<
1
),
schema
=
Prune
rSchema
([{
Optional
(
'sparsity'
)
:
And
(
float
,
lambda
n
:
0
<
n
<
1
),
'op_types'
:
[
'Conv2d'
],
Optional
(
'op_names'
):
[
str
]
Optional
(
'op_names'
):
[
str
],
Optional
(
'exclude'
):
bool
}],
model
,
logger
)
schema
.
validate
(
config_list
)
...
...
@@ -436,10 +439,11 @@ class SlimPruner(IterativePruner):
self
.
patch_optimizer_before
(
self
.
_callback
)
def
validate_config
(
self
,
model
,
config_list
):
schema
=
Compresso
rSchema
([{
'sparsity'
:
And
(
float
,
lambda
n
:
0
<
n
<
1
),
schema
=
Prune
rSchema
([{
Optional
(
'sparsity'
)
:
And
(
float
,
lambda
n
:
0
<
n
<
1
),
'op_types'
:
[
'BatchNorm2d'
],
Optional
(
'op_names'
):
[
str
]
Optional
(
'op_names'
):
[
str
],
Optional
(
'exclude'
):
bool
}],
model
,
logger
)
schema
.
validate
(
config_list
)
...
...
nni/algorithms/compression/pytorch/pruning/lottery_ticket.py
View file @
68644f59
...
...
@@ -5,7 +5,7 @@ import copy
import
logging
import
torch
from
schema
import
And
,
Optional
from
nni.compression.pytorch.utils.config_validation
import
Compresso
rSchema
from
nni.compression.pytorch.utils.config_validation
import
Prune
rSchema
from
nni.compression.pytorch.compressor
import
Pruner
from
.finegrained_pruning_masker
import
LevelPrunerMasker
...
...
@@ -56,11 +56,12 @@ class LotteryTicketPruner(Pruner):
- prune_iterations : The number of rounds for the iterative pruning.
- sparsity : The final sparsity when the compression is done.
"""
schema
=
Compresso
rSchema
([{
'sparsity'
:
And
(
float
,
lambda
n
:
0
<
n
<
1
),
schema
=
Prune
rSchema
([{
Optional
(
'sparsity'
)
:
And
(
float
,
lambda
n
:
0
<
n
<
1
),
'prune_iterations'
:
And
(
int
,
lambda
n
:
n
>
0
),
Optional
(
'op_types'
):
[
str
],
Optional
(
'op_names'
):
[
str
]
Optional
(
'op_names'
):
[
str
],
Optional
(
'exclude'
):
bool
}],
model
,
logger
)
schema
.
validate
(
config_list
)
...
...
nni/algorithms/compression/pytorch/pruning/net_adapt_pruner.py
View file @
68644f59
...
...
@@ -11,7 +11,7 @@ from schema import And, Optional
from
nni.utils
import
OptimizeMode
from
nni.compression.pytorch.compressor
import
Pruner
from
nni.compression.pytorch.utils.config_validation
import
Compresso
rSchema
from
nni.compression.pytorch.utils.config_validation
import
Prune
rSchema
from
nni.compression.pytorch.utils.num_param_counter
import
get_total_num_weights
from
.constants_pruner
import
PRUNER_DICT
...
...
@@ -120,16 +120,18 @@ class NetAdaptPruner(Pruner):
"""
if
self
.
_base_algo
==
'level'
:
schema
=
Compresso
rSchema
([{
'sparsity'
:
And
(
float
,
lambda
n
:
0
<
n
<
1
),
schema
=
Prune
rSchema
([{
Optional
(
'sparsity'
)
:
And
(
float
,
lambda
n
:
0
<
n
<
1
),
Optional
(
'op_types'
):
[
str
],
Optional
(
'op_names'
):
[
str
],
Optional
(
'exclude'
):
bool
}],
model
,
_logger
)
elif
self
.
_base_algo
in
[
'l1'
,
'l2'
,
'fpgm'
]:
schema
=
Compresso
rSchema
([{
'sparsity'
:
And
(
float
,
lambda
n
:
0
<
n
<
1
),
schema
=
Prune
rSchema
([{
Optional
(
'sparsity'
)
:
And
(
float
,
lambda
n
:
0
<
n
<
1
),
'op_types'
:
[
'Conv2d'
],
Optional
(
'op_names'
):
[
str
]
Optional
(
'op_names'
):
[
str
],
Optional
(
'exclude'
):
bool
}],
model
,
_logger
)
schema
.
validate
(
config_list
)
...
...
nni/algorithms/compression/pytorch/pruning/one_shot_pruner.py
View file @
68644f59
...
...
@@ -4,7 +4,7 @@
import
logging
from
schema
import
And
,
Optional
from
nni.compression.pytorch.utils.config_validation
import
Compresso
rSchema
from
nni.compression.pytorch.utils.config_validation
import
Prune
rSchema
from
.dependency_aware_pruner
import
DependencyAwarePruner
__all__
=
[
'LevelPruner'
,
'L1FilterPruner'
,
'L2FilterPruner'
,
'FPGMPruner'
]
...
...
@@ -48,10 +48,11 @@ class OneshotPruner(DependencyAwarePruner):
config_list : list
List on pruning configs
"""
schema
=
Compresso
rSchema
([{
'sparsity'
:
And
(
float
,
lambda
n
:
0
<
n
<
1
),
schema
=
Prune
rSchema
([{
Optional
(
'sparsity'
)
:
And
(
float
,
lambda
n
:
0
<
n
<
1
),
Optional
(
'op_types'
):
[
str
],
Optional
(
'op_names'
):
[
str
]
Optional
(
'op_names'
):
[
str
],
Optional
(
'exclude'
):
bool
}],
model
,
logger
)
schema
.
validate
(
config_list
)
...
...
nni/algorithms/compression/pytorch/pruning/sensitivity_pruner.py
View file @
68644f59
...
...
@@ -9,7 +9,7 @@ import torch
from
schema
import
And
,
Optional
from
nni.compression.pytorch.compressor
import
Pruner
from
nni.compression.pytorch.utils.config_validation
import
Compresso
rSchema
from
nni.compression.pytorch.utils.config_validation
import
Prune
rSchema
from
nni.compression.pytorch.utils.sensitivity_analysis
import
SensitivityAnalysis
from
.constants_pruner
import
PRUNER_DICT
...
...
@@ -146,16 +146,18 @@ class SensitivityPruner(Pruner):
"""
if
self
.
base_algo
==
'level'
:
schema
=
Compresso
rSchema
([{
'sparsity'
:
And
(
float
,
lambda
n
:
0
<
n
<
1
),
schema
=
Prune
rSchema
([{
Optional
(
'sparsity'
)
:
And
(
float
,
lambda
n
:
0
<
n
<
1
),
Optional
(
'op_types'
):
[
str
],
Optional
(
'op_names'
):
[
str
],
Optional
(
'exclude'
):
bool
}],
model
,
_logger
)
elif
self
.
base_algo
in
[
'l1'
,
'l2'
,
'fpgm'
]:
schema
=
Compresso
rSchema
([{
'sparsity'
:
And
(
float
,
lambda
n
:
0
<
n
<
1
),
schema
=
Prune
rSchema
([{
Optional
(
'sparsity'
)
:
And
(
float
,
lambda
n
:
0
<
n
<
1
),
'op_types'
:
[
'Conv2d'
],
Optional
(
'op_names'
):
[
str
]
Optional
(
'op_names'
):
[
str
],
Optional
(
'exclude'
):
bool
}],
model
,
_logger
)
schema
.
validate
(
config_list
)
...
...
nni/algorithms/compression/pytorch/pruning/simulated_annealing_pruner.py
View file @
68644f59
...
...
@@ -13,7 +13,7 @@ from schema import And, Optional
from
nni.utils
import
OptimizeMode
from
nni.compression.pytorch.compressor
import
Pruner
from
nni.compression.pytorch.utils.config_validation
import
Compresso
rSchema
from
nni.compression.pytorch.utils.config_validation
import
Prune
rSchema
from
.constants_pruner
import
PRUNER_DICT
...
...
@@ -115,16 +115,18 @@ class SimulatedAnnealingPruner(Pruner):
"""
if
self
.
_base_algo
==
'level'
:
schema
=
Compresso
rSchema
([{
'sparsity'
:
And
(
float
,
lambda
n
:
0
<
n
<
1
),
schema
=
Prune
rSchema
([{
Optional
(
'sparsity'
)
:
And
(
float
,
lambda
n
:
0
<
n
<
1
),
Optional
(
'op_types'
):
[
str
],
Optional
(
'op_names'
):
[
str
],
Optional
(
'exclude'
):
bool
}],
model
,
_logger
)
elif
self
.
_base_algo
in
[
'l1'
,
'l2'
,
'fpgm'
]:
schema
=
Compresso
rSchema
([{
'sparsity'
:
And
(
float
,
lambda
n
:
0
<
n
<
1
),
schema
=
Prune
rSchema
([{
Optional
(
'sparsity'
)
:
And
(
float
,
lambda
n
:
0
<
n
<
1
),
'op_types'
:
[
'Conv2d'
],
Optional
(
'op_names'
):
[
str
]
Optional
(
'op_names'
):
[
str
],
Optional
(
'exclude'
):
bool
}],
model
,
_logger
)
schema
.
validate
(
config_list
)
...
...
nni/algorithms/compression/pytorch/quantization/quantizers.py
View file @
68644f59
...
...
@@ -5,7 +5,7 @@ import logging
import
copy
import
torch
from
schema
import
Schema
,
And
,
Or
,
Optional
from
nni.compression.pytorch.utils.config_validation
import
Compresso
rSchema
from
nni.compression.pytorch.utils.config_validation
import
Quantize
rSchema
from
nni.compression.pytorch.compressor
import
Quantizer
,
QuantForward
,
QuantGrad
,
QuantType
__all__
=
[
'NaiveQuantizer'
,
'QAT_Quantizer'
,
'DoReFaQuantizer'
,
'BNNQuantizer'
,
'LsqQuantizer'
]
...
...
@@ -22,11 +22,12 @@ class NaiveQuantizer(Quantizer):
self
.
layer_scale
=
{}
def
validate_config
(
self
,
model
,
config_list
):
schema
=
Compresso
rSchema
([{
schema
=
Quantize
rSchema
([{
Optional
(
'quant_types'
):
[
'weight'
],
Optional
(
'quant_bits'
):
Or
(
8
,
{
'weight'
:
8
}),
Optional
(
'op_types'
):
[
str
],
Optional
(
'op_names'
):
[
str
]
Optional
(
'op_names'
):
[
str
],
Optional
(
'exclude'
):
bool
}],
model
,
logger
)
schema
.
validate
(
config_list
)
...
...
@@ -183,7 +184,7 @@ class QAT_Quantizer(Quantizer):
config_list : list of dict
List of configurations
"""
schema
=
Compresso
rSchema
([{
schema
=
Quantize
rSchema
([{
Optional
(
'quant_types'
):
Schema
([
lambda
x
:
x
in
[
'weight'
,
'output'
]]),
Optional
(
'quant_bits'
):
Or
(
And
(
int
,
lambda
n
:
0
<
n
<
32
),
Schema
({
Optional
(
'weight'
):
And
(
int
,
lambda
n
:
0
<
n
<
32
),
...
...
@@ -191,7 +192,8 @@ class QAT_Quantizer(Quantizer):
})),
Optional
(
'quant_start_step'
):
And
(
int
,
lambda
n
:
n
>=
0
),
Optional
(
'op_types'
):
[
str
],
Optional
(
'op_names'
):
[
str
]
Optional
(
'op_names'
):
[
str
],
Optional
(
'exclude'
):
bool
}],
model
,
logger
)
schema
.
validate
(
config_list
)
...
...
@@ -386,13 +388,14 @@ class DoReFaQuantizer(Quantizer):
config_list : list of dict
List of configurations
"""
schema
=
Compresso
rSchema
([{
schema
=
Quantize
rSchema
([{
Optional
(
'quant_types'
):
Schema
([
lambda
x
:
x
in
[
'weight'
]]),
Optional
(
'quant_bits'
):
Or
(
And
(
int
,
lambda
n
:
0
<
n
<
32
),
Schema
({
Optional
(
'weight'
):
And
(
int
,
lambda
n
:
0
<
n
<
32
)
})),
Optional
(
'op_types'
):
[
str
],
Optional
(
'op_names'
):
[
str
]
Optional
(
'op_names'
):
[
str
],
Optional
(
'exclude'
):
bool
}],
model
,
logger
)
schema
.
validate
(
config_list
)
...
...
@@ -493,14 +496,15 @@ class BNNQuantizer(Quantizer):
config_list : list of dict
List of configurations
"""
schema
=
Compresso
rSchema
([{
schema
=
Quantize
rSchema
([{
Optional
(
'quant_types'
):
Schema
([
lambda
x
:
x
in
[
'weight'
,
'output'
]]),
Optional
(
'quant_bits'
):
Or
(
And
(
int
,
lambda
n
:
0
<
n
<
32
),
Schema
({
Optional
(
'weight'
):
And
(
int
,
lambda
n
:
0
<
n
<
32
),
Optional
(
'output'
):
And
(
int
,
lambda
n
:
0
<
n
<
32
),
})),
Optional
(
'op_types'
):
[
str
],
Optional
(
'op_names'
):
[
str
]
Optional
(
'op_names'
):
[
str
],
Optional
(
'exclude'
):
bool
}],
model
,
logger
)
schema
.
validate
(
config_list
)
...
...
nni/compression/pytorch/utils/config_validation.py
View file @
68644f59
...
...
@@ -51,3 +51,25 @@ class CompressorSchema:
def
validate
(
self
,
data
):
self
.
compressor_schema
.
validate
(
data
)
def
validate_exclude_sparsity
(
data
):
if
not
(
'exclude'
in
data
or
'sparsity'
in
data
):
raise
SchemaError
(
'Either sparisty or exclude must be specified.'
)
return
True
def
validate_exclude_quant_types_quant_bits
(
data
):
if
not
(
'exclude'
in
data
or
(
'quant_types'
in
data
and
'quant_bits'
in
data
)):
raise
SchemaError
(
'Either (quant_types and quant_bits) or exclude must be specified.'
)
return
True
class
PrunerSchema
(
CompressorSchema
):
def
_modify_schema
(
self
,
data_schema
,
model
,
logger
):
data_schema
=
super
().
_modify_schema
(
data_schema
,
model
,
logger
)
data_schema
[
0
]
=
And
(
data_schema
[
0
],
lambda
d
:
validate_exclude_sparsity
(
d
))
return
data_schema
class
QuantizerSchema
(
CompressorSchema
):
def
_modify_schema
(
self
,
data_schema
,
model
,
logger
):
data_schema
=
super
().
_modify_schema
(
data_schema
,
model
,
logger
)
data_schema
[
0
]
=
And
(
data_schema
[
0
],
lambda
d
:
validate_exclude_quant_types_quant_bits
(
d
))
return
data_schema
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment