Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
d1bc0cfc
Unverified
Commit
d1bc0cfc
authored
Mar 24, 2020
by
chicm-ms
Committed by
GitHub
Mar 24, 2020
Browse files
Add model compression config validation (#2219)
parent
e0b692c9
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
335 additions
and
12 deletions
+335
-12
src/sdk/pynni/nni/compression/torch/activation_rank_filter_pruners.py
...i/nni/compression/torch/activation_rank_filter_pruners.py
+20
-0
src/sdk/pynni/nni/compression/torch/compressor.py
src/sdk/pynni/nni/compression/torch/compressor.py
+13
-0
src/sdk/pynni/nni/compression/torch/pruners.py
src/sdk/pynni/nni/compression/torch/pruners.py
+79
-11
src/sdk/pynni/nni/compression/torch/quantizers.py
src/sdk/pynni/nni/compression/torch/quantizers.py
+75
-0
src/sdk/pynni/nni/compression/torch/utils.py
src/sdk/pynni/nni/compression/torch/utils.py
+53
-0
src/sdk/pynni/nni/compression/torch/weight_rank_filter_pruners.py
...pynni/nni/compression/torch/weight_rank_filter_pruners.py
+20
-0
src/sdk/pynni/tests/test_compressor.py
src/sdk/pynni/tests/test_compressor.py
+74
-0
src/sdk/pynni/tests/test_pruners.py
src/sdk/pynni/tests/test_pruners.py
+1
-1
No files found.
src/sdk/pynni/nni/compression/torch/activation_rank_filter_pruners.py
View file @
d1bc0cfc
...
@@ -3,6 +3,8 @@
...
@@ -3,6 +3,8 @@
import
logging
import
logging
import
torch
import
torch
from
schema
import
And
,
Optional
from
.utils
import
CompressorSchema
from
.compressor
import
Pruner
from
.compressor
import
Pruner
__all__
=
[
'ActivationAPoZRankFilterPruner'
,
'ActivationMeanRankFilterPruner'
]
__all__
=
[
'ActivationAPoZRankFilterPruner'
,
'ActivationMeanRankFilterPruner'
]
...
@@ -50,6 +52,24 @@ class ActivationRankFilterPruner(Pruner):
...
@@ -50,6 +52,24 @@ class ActivationRankFilterPruner(Pruner):
else
:
else
:
self
.
activation
=
None
self
.
activation
=
None
def
validate_config
(
self
,
model
,
config_list
):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
"""
schema
=
CompressorSchema
([{
'sparsity'
:
And
(
float
,
lambda
n
:
0
<
n
<
1
),
Optional
(
'op_types'
):
[
str
],
Optional
(
'op_names'
):
[
str
]
}],
model
,
logger
)
schema
.
validate
(
config_list
)
def
get_mask
(
self
,
base_mask
,
activations
,
num_prune
):
def
get_mask
(
self
,
base_mask
,
activations
,
num_prune
):
raise
NotImplementedError
(
'{} get_mask is not implemented'
.
format
(
self
.
__class__
.
__name__
))
raise
NotImplementedError
(
'{} get_mask is not implemented'
.
format
(
self
.
__class__
.
__name__
))
...
...
src/sdk/pynni/nni/compression/torch/compressor.py
View file @
d1bc0cfc
...
@@ -40,6 +40,9 @@ class Compressor:
...
@@ -40,6 +40,9 @@ class Compressor:
optimizer: pytorch optimizer
optimizer: pytorch optimizer
optimizer used to train the model
optimizer used to train the model
"""
"""
assert
isinstance
(
model
,
torch
.
nn
.
Module
)
self
.
validate_config
(
model
,
config_list
)
self
.
bound_model
=
model
self
.
bound_model
=
model
self
.
config_list
=
config_list
self
.
config_list
=
config_list
self
.
optimizer
=
optimizer
self
.
optimizer
=
optimizer
...
@@ -54,9 +57,17 @@ class Compressor:
...
@@ -54,9 +57,17 @@ class Compressor:
for
layer
,
config
in
self
.
_detect_modules_to_compress
():
for
layer
,
config
in
self
.
_detect_modules_to_compress
():
wrapper
=
self
.
_wrap_modules
(
layer
,
config
)
wrapper
=
self
.
_wrap_modules
(
layer
,
config
)
self
.
modules_wrapper
.
append
(
wrapper
)
self
.
modules_wrapper
.
append
(
wrapper
)
if
not
self
.
modules_wrapper
:
_logger
.
warning
(
'Nothing is configured to compress, please check your model and config_list'
)
self
.
_wrap_model
()
self
.
_wrap_model
()
def
validate_config
(
self
,
model
,
config_list
):
"""
subclass can optionally implement this method to check if config_list if valid
"""
pass
def
_detect_modules_to_compress
(
self
):
def
_detect_modules_to_compress
(
self
):
"""
"""
detect all modules should be compressed, and save the result in `self.modules_to_compress`.
detect all modules should be compressed, and save the result in `self.modules_to_compress`.
...
@@ -65,6 +76,8 @@ class Compressor:
...
@@ -65,6 +76,8 @@ class Compressor:
if
self
.
modules_to_compress
is
None
:
if
self
.
modules_to_compress
is
None
:
self
.
modules_to_compress
=
[]
self
.
modules_to_compress
=
[]
for
name
,
module
in
self
.
bound_model
.
named_modules
():
for
name
,
module
in
self
.
bound_model
.
named_modules
():
if
module
==
self
.
bound_model
:
continue
layer
=
LayerInfo
(
name
,
module
)
layer
=
LayerInfo
(
name
,
module
)
config
=
self
.
select_config
(
layer
)
config
=
self
.
select_config
(
layer
)
if
config
is
not
None
:
if
config
is
not
None
:
...
...
src/sdk/pynni/nni/compression/torch/pruners.py
View file @
d1bc0cfc
...
@@ -4,7 +4,9 @@
...
@@ -4,7 +4,9 @@
import
copy
import
copy
import
logging
import
logging
import
torch
import
torch
from
schema
import
And
,
Optional
from
.compressor
import
Pruner
from
.compressor
import
Pruner
from
.utils
import
CompressorSchema
__all__
=
[
'LevelPruner'
,
'AGP_Pruner'
,
'SlimPruner'
,
'LotteryTicketPruner'
]
__all__
=
[
'LevelPruner'
,
'AGP_Pruner'
,
'SlimPruner'
,
'LotteryTicketPruner'
]
...
@@ -31,6 +33,23 @@ class LevelPruner(Pruner):
...
@@ -31,6 +33,23 @@ class LevelPruner(Pruner):
super
().
__init__
(
model
,
config_list
,
optimizer
)
super
().
__init__
(
model
,
config_list
,
optimizer
)
self
.
set_wrappers_attribute
(
"if_calculated"
,
False
)
self
.
set_wrappers_attribute
(
"if_calculated"
,
False
)
def
validate_config
(
self
,
model
,
config_list
):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
List on pruning configs
"""
schema
=
CompressorSchema
([{
'sparsity'
:
And
(
float
,
lambda
n
:
0
<
n
<
1
),
Optional
(
'op_types'
):
[
str
],
Optional
(
'op_names'
):
[
str
]
}],
model
,
logger
)
schema
.
validate
(
config_list
)
def
calc_mask
(
self
,
wrapper
,
**
kwargs
):
def
calc_mask
(
self
,
wrapper
,
**
kwargs
):
"""
"""
Calculate the mask of given layer
Calculate the mask of given layer
...
@@ -90,6 +109,27 @@ class AGP_Pruner(Pruner):
...
@@ -90,6 +109,27 @@ class AGP_Pruner(Pruner):
self
.
now_epoch
=
0
self
.
now_epoch
=
0
self
.
set_wrappers_attribute
(
"if_calculated"
,
False
)
self
.
set_wrappers_attribute
(
"if_calculated"
,
False
)
def
validate_config
(
self
,
model
,
config_list
):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
List on pruning configs
"""
schema
=
CompressorSchema
([{
'initial_sparsity'
:
And
(
float
,
lambda
n
:
0
<=
n
<=
1
),
'final_sparsity'
:
And
(
float
,
lambda
n
:
0
<=
n
<=
1
),
'start_epoch'
:
And
(
int
,
lambda
n
:
n
>=
0
),
'end_epoch'
:
And
(
int
,
lambda
n
:
n
>=
0
),
'frequency'
:
And
(
int
,
lambda
n
:
n
>
0
),
Optional
(
'op_types'
):
[
str
],
Optional
(
'op_names'
):
[
str
]
}],
model
,
logger
)
schema
.
validate
(
config_list
)
def
calc_mask
(
self
,
wrapper
,
**
kwargs
):
def
calc_mask
(
self
,
wrapper
,
**
kwargs
):
"""
"""
Calculate the mask of given layer.
Calculate the mask of given layer.
...
@@ -208,6 +248,24 @@ class SlimPruner(Pruner):
...
@@ -208,6 +248,24 @@ class SlimPruner(Pruner):
self
.
global_threshold
=
torch
.
topk
(
all_bn_weights
.
view
(
-
1
),
k
,
largest
=
False
)[
0
].
max
()
self
.
global_threshold
=
torch
.
topk
(
all_bn_weights
.
view
(
-
1
),
k
,
largest
=
False
)[
0
].
max
()
self
.
set_wrappers_attribute
(
"if_calculated"
,
False
)
self
.
set_wrappers_attribute
(
"if_calculated"
,
False
)
def
validate_config
(
self
,
model
,
config_list
):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
"""
schema
=
CompressorSchema
([{
'sparsity'
:
And
(
float
,
lambda
n
:
0
<
n
<
1
),
'op_types'
:
[
'BatchNorm2d'
],
Optional
(
'op_names'
):
[
str
]
}],
model
,
logger
)
schema
.
validate
(
config_list
)
def
calc_mask
(
self
,
wrapper
,
**
kwargs
):
def
calc_mask
(
self
,
wrapper
,
**
kwargs
):
"""
"""
Calculate the mask of given layer.
Calculate the mask of given layer.
...
@@ -273,7 +331,7 @@ class LotteryTicketPruner(Pruner):
...
@@ -273,7 +331,7 @@ class LotteryTicketPruner(Pruner):
"""
"""
super
().
__init__
(
model
,
config_list
,
optimizer
)
super
().
__init__
(
model
,
config_list
,
optimizer
)
self
.
curr_prune_iteration
=
None
self
.
curr_prune_iteration
=
None
self
.
prune_iterations
=
self
.
_validate_config
(
config_list
)
self
.
prune_iterations
=
config_list
[
0
][
'prune_iterations'
]
# save init weights and optimizer
# save init weights and optimizer
self
.
reset_weights
=
reset_weights
self
.
reset_weights
=
reset_weights
...
@@ -286,16 +344,26 @@ class LotteryTicketPruner(Pruner):
...
@@ -286,16 +344,26 @@ class LotteryTicketPruner(Pruner):
if
lr_scheduler
is
not
None
:
if
lr_scheduler
is
not
None
:
self
.
_scheduler_state
=
copy
.
deepcopy
(
lr_scheduler
.
state_dict
())
self
.
_scheduler_state
=
copy
.
deepcopy
(
lr_scheduler
.
state_dict
())
def
_validate_config
(
self
,
config_list
):
def
validate_config
(
self
,
model
,
config_list
):
prune_iterations
=
None
"""
for
config
in
config_list
:
Parameters
assert
'prune_iterations'
in
config
,
'prune_iterations must exist in your config'
----------
assert
'sparsity'
in
config
,
'sparsity must exist in your config'
model : torch.nn.module
if
prune_iterations
is
not
None
:
Model to be pruned
assert
prune_iterations
==
config
[
config_list : list
'prune_iterations'
],
'The values of prune_iterations must be equal in your config'
Supported keys:
prune_iterations
=
config
[
'prune_iterations'
]
- prune_iterations : The number of rounds for the iterative pruning.
return
prune_iterations
- sparsity : The final sparsity when the compression is done.
"""
schema
=
CompressorSchema
([{
'sparsity'
:
And
(
float
,
lambda
n
:
0
<
n
<
1
),
'prune_iterations'
:
And
(
int
,
lambda
n
:
n
>
0
),
Optional
(
'op_types'
):
[
str
],
Optional
(
'op_names'
):
[
str
]
}],
model
,
logger
)
schema
.
validate
(
config_list
)
assert
len
(
set
([
x
[
'prune_iterations'
]
for
x
in
config_list
]))
==
1
,
'The values of prune_iterations must be equal in your config'
def
_calc_sparsity
(
self
,
sparsity
):
def
_calc_sparsity
(
self
,
sparsity
):
keep_ratio_once
=
(
1
-
sparsity
)
**
(
1
/
self
.
prune_iterations
)
keep_ratio_once
=
(
1
-
sparsity
)
**
(
1
/
self
.
prune_iterations
)
...
...
src/sdk/pynni/nni/compression/torch/quantizers.py
View file @
d1bc0cfc
...
@@ -3,6 +3,8 @@
...
@@ -3,6 +3,8 @@
import
logging
import
logging
import
torch
import
torch
from
schema
import
Schema
,
And
,
Or
,
Optional
from
.utils
import
CompressorSchema
from
.compressor
import
Quantizer
,
QuantGrad
,
QuantType
from
.compressor
import
Quantizer
,
QuantGrad
,
QuantType
__all__
=
[
'NaiveQuantizer'
,
'QAT_Quantizer'
,
'DoReFaQuantizer'
,
'BNNQuantizer'
]
__all__
=
[
'NaiveQuantizer'
,
'QAT_Quantizer'
,
'DoReFaQuantizer'
,
'BNNQuantizer'
]
...
@@ -17,6 +19,16 @@ class NaiveQuantizer(Quantizer):
...
@@ -17,6 +19,16 @@ class NaiveQuantizer(Quantizer):
super
().
__init__
(
model
,
config_list
,
optimizer
)
super
().
__init__
(
model
,
config_list
,
optimizer
)
self
.
layer_scale
=
{}
self
.
layer_scale
=
{}
def
validate_config
(
self
,
model
,
config_list
):
schema
=
CompressorSchema
([{
Optional
(
'quant_types'
):
[
'weight'
],
Optional
(
'quant_bits'
):
Or
(
8
,
{
'weight'
:
8
}),
Optional
(
'op_types'
):
[
str
],
Optional
(
'op_names'
):
[
str
]
}],
model
,
logger
)
schema
.
validate
(
config_list
)
def
quantize_weight
(
self
,
weight
,
wrapper
,
**
kwargs
):
def
quantize_weight
(
self
,
weight
,
wrapper
,
**
kwargs
):
new_scale
=
weight
.
abs
().
max
()
/
127
new_scale
=
weight
.
abs
().
max
()
/
127
scale
=
max
(
self
.
layer_scale
.
get
(
wrapper
.
name
,
0
),
new_scale
)
scale
=
max
(
self
.
layer_scale
.
get
(
wrapper
.
name
,
0
),
new_scale
)
...
@@ -137,6 +149,28 @@ class QAT_Quantizer(Quantizer):
...
@@ -137,6 +149,28 @@ class QAT_Quantizer(Quantizer):
layer
.
module
.
register_buffer
(
'tracked_max_biased'
,
torch
.
zeros
(
1
))
layer
.
module
.
register_buffer
(
'tracked_max_biased'
,
torch
.
zeros
(
1
))
layer
.
module
.
register_buffer
(
'tracked_max'
,
torch
.
zeros
(
1
))
layer
.
module
.
register_buffer
(
'tracked_max'
,
torch
.
zeros
(
1
))
def
validate_config
(
self
,
model
,
config_list
):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list of dict
List of configurations
"""
schema
=
CompressorSchema
([{
Optional
(
'quant_types'
):
Schema
([
lambda
x
:
x
in
[
'weight'
,
'output'
]]),
Optional
(
'quant_bits'
):
Or
(
And
(
int
,
lambda
n
:
0
<
n
<
32
),
Schema
({
Optional
(
'weight'
):
And
(
int
,
lambda
n
:
0
<
n
<
32
),
Optional
(
'output'
):
And
(
int
,
lambda
n
:
0
<
n
<
32
),
})),
Optional
(
'quant_start_step'
):
And
(
int
,
lambda
n
:
n
>=
0
),
Optional
(
'op_types'
):
[
str
],
Optional
(
'op_names'
):
[
str
]
}],
model
,
logger
)
schema
.
validate
(
config_list
)
def
_quantize
(
self
,
bits
,
op
,
real_val
):
def
_quantize
(
self
,
bits
,
op
,
real_val
):
"""
"""
quantize real value.
quantize real value.
...
@@ -233,6 +267,26 @@ class DoReFaQuantizer(Quantizer):
...
@@ -233,6 +267,26 @@ class DoReFaQuantizer(Quantizer):
def
__init__
(
self
,
model
,
config_list
,
optimizer
=
None
):
def
__init__
(
self
,
model
,
config_list
,
optimizer
=
None
):
super
().
__init__
(
model
,
config_list
,
optimizer
)
super
().
__init__
(
model
,
config_list
,
optimizer
)
def
validate_config
(
self
,
model
,
config_list
):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list of dict
List of configurations
"""
schema
=
CompressorSchema
([{
Optional
(
'quant_types'
):
Schema
([
lambda
x
:
x
in
[
'weight'
]]),
Optional
(
'quant_bits'
):
Or
(
And
(
int
,
lambda
n
:
0
<
n
<
32
),
Schema
({
Optional
(
'weight'
):
And
(
int
,
lambda
n
:
0
<
n
<
32
)
})),
Optional
(
'op_types'
):
[
str
],
Optional
(
'op_names'
):
[
str
]
}],
model
,
logger
)
schema
.
validate
(
config_list
)
def
quantize_weight
(
self
,
weight
,
wrapper
,
**
kwargs
):
def
quantize_weight
(
self
,
weight
,
wrapper
,
**
kwargs
):
weight_bits
=
get_bits_length
(
wrapper
.
config
,
'weight'
)
weight_bits
=
get_bits_length
(
wrapper
.
config
,
'weight'
)
out
=
weight
.
tanh
()
out
=
weight
.
tanh
()
...
@@ -264,6 +318,27 @@ class BNNQuantizer(Quantizer):
...
@@ -264,6 +318,27 @@ class BNNQuantizer(Quantizer):
super
().
__init__
(
model
,
config_list
,
optimizer
)
super
().
__init__
(
model
,
config_list
,
optimizer
)
self
.
quant_grad
=
ClipGrad
self
.
quant_grad
=
ClipGrad
def
validate_config
(
self
,
model
,
config_list
):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list of dict
List of configurations
"""
schema
=
CompressorSchema
([{
Optional
(
'quant_types'
):
Schema
([
lambda
x
:
x
in
[
'weight'
,
'output'
]]),
Optional
(
'quant_bits'
):
Or
(
And
(
int
,
lambda
n
:
0
<
n
<
32
),
Schema
({
Optional
(
'weight'
):
And
(
int
,
lambda
n
:
0
<
n
<
32
),
Optional
(
'output'
):
And
(
int
,
lambda
n
:
0
<
n
<
32
),
})),
Optional
(
'op_types'
):
[
str
],
Optional
(
'op_names'
):
[
str
]
}],
model
,
logger
)
schema
.
validate
(
config_list
)
def
quantize_weight
(
self
,
weight
,
wrapper
,
**
kwargs
):
def
quantize_weight
(
self
,
weight
,
wrapper
,
**
kwargs
):
out
=
torch
.
sign
(
weight
)
out
=
torch
.
sign
(
weight
)
# remove zeros
# remove zeros
...
...
src/sdk/pynni/nni/compression/torch/utils.py
0 → 100644
View file @
d1bc0cfc
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
schema
import
Schema
,
And
,
SchemaError
def
validate_op_names
(
model
,
op_names
,
logger
):
found_names
=
set
(
map
(
lambda
x
:
x
[
0
],
model
.
named_modules
()))
not_found_op_names
=
list
(
set
(
op_names
)
-
found_names
)
if
not_found_op_names
:
logger
.
warning
(
'op_names %s not found in model'
,
not_found_op_names
)
return
True
def
validate_op_types
(
model
,
op_types
,
logger
):
found_types
=
set
([
'default'
])
|
set
(
map
(
lambda
x
:
type
(
x
[
1
]).
__name__
,
model
.
named_modules
()))
not_found_op_types
=
list
(
set
(
op_types
)
-
found_types
)
if
not_found_op_types
:
logger
.
warning
(
'op_types %s not found in model'
,
not_found_op_types
)
return
True
def
validate_op_types_op_names
(
data
):
if
not
(
'op_types'
in
data
or
'op_names'
in
data
):
raise
SchemaError
(
'Either op_types or op_names must be specified.'
)
return
True
class
CompressorSchema
:
def
__init__
(
self
,
data_schema
,
model
,
logger
):
assert
isinstance
(
data_schema
,
list
)
and
len
(
data_schema
)
<=
1
self
.
data_schema
=
data_schema
self
.
compressor_schema
=
Schema
(
self
.
_modify_schema
(
data_schema
,
model
,
logger
))
def
_modify_schema
(
self
,
data_schema
,
model
,
logger
):
if
not
data_schema
:
return
data_schema
for
k
in
data_schema
[
0
]:
old_schema
=
data_schema
[
0
][
k
]
if
k
==
'op_types'
or
(
isinstance
(
k
,
Schema
)
and
k
.
_schema
==
'op_types'
):
new_schema
=
And
(
old_schema
,
lambda
n
:
validate_op_types
(
model
,
n
,
logger
))
data_schema
[
0
][
k
]
=
new_schema
if
k
==
'op_names'
or
(
isinstance
(
k
,
Schema
)
and
k
.
_schema
==
'op_names'
):
new_schema
=
And
(
old_schema
,
lambda
n
:
validate_op_names
(
model
,
n
,
logger
))
data_schema
[
0
][
k
]
=
new_schema
data_schema
[
0
]
=
And
(
data_schema
[
0
],
lambda
d
:
validate_op_types_op_names
(
d
))
return
data_schema
def
validate
(
self
,
data
):
self
.
compressor_schema
.
validate
(
data
)
src/sdk/pynni/nni/compression/torch/weight_rank_filter_pruners.py
View file @
d1bc0cfc
...
@@ -3,6 +3,8 @@
...
@@ -3,6 +3,8 @@
import
logging
import
logging
import
torch
import
torch
from
schema
import
And
,
Optional
from
.utils
import
CompressorSchema
from
.compressor
import
Pruner
from
.compressor
import
Pruner
__all__
=
[
'L1FilterPruner'
,
'L2FilterPruner'
,
'FPGMPruner'
]
__all__
=
[
'L1FilterPruner'
,
'L2FilterPruner'
,
'FPGMPruner'
]
...
@@ -31,6 +33,24 @@ class WeightRankFilterPruner(Pruner):
...
@@ -31,6 +33,24 @@ class WeightRankFilterPruner(Pruner):
super
().
__init__
(
model
,
config_list
,
optimizer
)
super
().
__init__
(
model
,
config_list
,
optimizer
)
self
.
set_wrappers_attribute
(
"if_calculated"
,
False
)
self
.
set_wrappers_attribute
(
"if_calculated"
,
False
)
def
validate_config
(
self
,
model
,
config_list
):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
"""
schema
=
CompressorSchema
([{
'sparsity'
:
And
(
float
,
lambda
n
:
0
<
n
<
1
),
Optional
(
'op_types'
):
[
'Conv2d'
],
Optional
(
'op_names'
):
[
str
]
}],
model
,
logger
)
schema
.
validate
(
config_list
)
def
get_mask
(
self
,
base_mask
,
weight
,
num_prune
):
def
get_mask
(
self
,
base_mask
,
weight
,
num_prune
):
raise
NotImplementedError
(
'{} get_mask is not implemented'
.
format
(
self
.
__class__
.
__name__
))
raise
NotImplementedError
(
'{} get_mask is not implemented'
.
format
(
self
.
__class__
.
__name__
))
...
...
src/sdk/pynni/tests/test_compressor.py
View file @
d1bc0cfc
...
@@ -6,6 +6,7 @@ import numpy as np
...
@@ -6,6 +6,7 @@ import numpy as np
import
tensorflow
as
tf
import
tensorflow
as
tf
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
schema
import
nni.compression.torch
as
torch_compressor
import
nni.compression.torch
as
torch_compressor
import
math
import
math
...
@@ -267,6 +268,79 @@ class CompressorTestCase(TestCase):
...
@@ -267,6 +268,79 @@ class CompressorTestCase(TestCase):
assert
math
.
isclose
(
model
.
relu
.
module
.
tracked_min_biased
,
0.002
,
abs_tol
=
eps
)
assert
math
.
isclose
(
model
.
relu
.
module
.
tracked_min_biased
,
0.002
,
abs_tol
=
eps
)
assert
math
.
isclose
(
model
.
relu
.
module
.
tracked_max_biased
,
0.00998
,
abs_tol
=
eps
)
assert
math
.
isclose
(
model
.
relu
.
module
.
tracked_max_biased
,
0.00998
,
abs_tol
=
eps
)
def
test_torch_pruner_validation
(
self
):
# test bad configuraiton
pruner_classes
=
[
torch_compressor
.
__dict__
[
x
]
for
x
in
\
[
'LevelPruner'
,
'SlimPruner'
,
'FPGMPruner'
,
'L1FilterPruner'
,
'L2FilterPruner'
,
'AGP_Pruner'
,
\
'ActivationMeanRankFilterPruner'
,
'ActivationAPoZRankFilterPruner'
]]
bad_configs
=
[
[
{
'sparsity'
:
'0.2'
},
{
'sparsity'
:
0.6
}
],
[
{
'sparsity'
:
0.2
},
{
'sparsity'
:
1.6
}
],
[
{
'sparsity'
:
0.2
,
'op_types'
:
'default'
},
{
'sparsity'
:
0.6
}
],
[
{
'sparsity'
:
0.2
},
{
'sparsity'
:
0.6
,
'op_names'
:
'abc'
}
]
]
model
=
TorchModel
()
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.01
)
for
pruner_class
in
pruner_classes
:
for
config_list
in
bad_configs
:
try
:
pruner_class
(
model
,
config_list
,
optimizer
)
print
(
config_list
)
assert
False
,
'Validation error should be raised for bad configuration'
except
schema
.
SchemaError
:
pass
except
:
print
(
'FAILED:'
,
pruner_class
,
config_list
)
raise
def
test_torch_quantizer_validation
(
self
):
# test bad configuraiton
quantizer_classes
=
[
torch_compressor
.
__dict__
[
x
]
for
x
in
\
[
'NaiveQuantizer'
,
'QAT_Quantizer'
,
'DoReFaQuantizer'
,
'BNNQuantizer'
]]
bad_configs
=
[
[
{
'bad_key'
:
'abc'
}
],
[
{
'quant_types'
:
'abc'
}
],
[
{
'quant_bits'
:
34
}
],
[
{
'op_types'
:
'default'
}
],
[
{
'quant_bits'
:
{
'abc'
:
123
}}
]
]
model
=
TorchModel
()
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.01
)
for
quantizer_class
in
quantizer_classes
:
for
config_list
in
bad_configs
:
try
:
quantizer_class
(
model
,
config_list
,
optimizer
)
print
(
config_list
)
assert
False
,
'Validation error should be raised for bad configuration'
except
schema
.
SchemaError
:
pass
except
:
print
(
'FAILED:'
,
quantizer_class
,
config_list
)
raise
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
main
()
main
()
src/sdk/pynni/tests/test_pruners.py
View file @
d1bc0cfc
...
@@ -34,7 +34,7 @@ prune_config = {
...
@@ -34,7 +34,7 @@ prune_config = {
'agp'
:
{
'agp'
:
{
'pruner_class'
:
AGP_Pruner
,
'pruner_class'
:
AGP_Pruner
,
'config_list'
:
[{
'config_list'
:
[{
'initial_sparsity'
:
0
,
'initial_sparsity'
:
0
.
,
'final_sparsity'
:
0.8
,
'final_sparsity'
:
0.8
,
'start_epoch'
:
0
,
'start_epoch'
:
0
,
'end_epoch'
:
10
,
'end_epoch'
:
10
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment