Commit 503a3579
Authored Nov 25, 2019 by Tang Lang; committed by chicm-ms, Nov 25, 2019
add pruner unit test (#1771)

* add pruner unit test
* modify pruners to be compatible with torch 0.4.1
parent 8ac61b77

Showing 4 changed files with 92 additions and 17 deletions:
docs/en_US/Compressor/SlimPruner.md                      (+1, -1)
examples/model_compress/slim_pruner_torch_vgg19.py       (+1, -1)
src/sdk/pynni/nni/compression/torch/builtin_pruners.py   (+5, -5)
src/sdk/pynni/tests/test_compressor.py                   (+85, -10)
docs/en_US/Compressor/SlimPruner.md

@@ -34,6 +34,6 @@ We implemented one of the experiments in ['Learning Efficient Convolutional Netw

 | Model         | Error(paper/ours) | Parameters | Pruned    |
 | ------------- | ----------------- | ---------- | --------- |
 | VGGNet        | 6.34/6.40         | 20.04M     |           |
-| Pruned-VGGNet | 6.20/6.39         | 2.03M      | 88.5%     |
+| Pruned-VGGNet | 6.20/6.26         | 2.03M      | 88.5%     |

 The experiments code can be found at [examples/model_compress](https://github.com/microsoft/nni/tree/master/examples/model_compress/)
examples/model_compress/slim_pruner_torch_vgg19.py

@@ -169,7 +169,7 @@ def main():
     new_model.to(device)
     new_model.load_state_dict(torch.load('pruned_vgg19_cifar10.pth'))
     test(new_model, device, test_loader)
-    # top1 = 93.61%
+    # top1 = 93.74%


 if __name__ == '__main__':
src/sdk/pynni/nni/compression/torch/builtin_pruners.py

@@ -47,7 +47,7 @@ class LevelPruner(Pruner):
             k = int(weight.numel() * config['sparsity'])
             if k == 0:
                 return torch.ones(weight.shape).type_as(weight)
-            threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
+            threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
             mask = torch.gt(w_abs, threshold).type_as(weight)
             self.mask_dict.update({op_name: mask})
             self.if_init_list.update({op_name: False})

@@ -108,7 +108,7 @@ class AGP_Pruner(Pruner):
             return mask
         # if we want to generate new mask, we should update weigth first
         w_abs = weight.abs() * mask
-        threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
+        threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
         new_mask = torch.gt(w_abs, threshold).type_as(weight)
         self.mask_dict.update({op_name: new_mask})
         self.if_init_list.update({op_name: False})

@@ -336,7 +336,7 @@ class L1FilterPruner(Pruner):
             if k == 0:
                 return torch.ones(weight.shape).type_as(weight)
             w_abs_structured = w_abs.view(filters, -1).sum(dim=1)
-            threshold = torch.topk(w_abs_structured.view(-1), k, largest=False).values.max()
+            threshold = torch.topk(w_abs_structured.view(-1), k, largest=False)[0].max()
             mask = torch.gt(w_abs_structured, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
         finally:
             self.mask_dict.update({layer.name: mask})
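The three hunks above (and one more in the SlimPruner hunk below) make the same one-line change: in older PyTorch releases such as 0.4.1, torch.topk returns a plain (values, indices) tuple with no .values attribute, while indexing with [0] retrieves the values tensor on both old and new versions. A minimal sketch of the thresholding idiom, with invented numbers:

    import torch

    w_abs = torch.tensor([0.5, 0.1, 0.9, 0.3])  # stand-in for weight.abs()
    k = 2                                        # e.g. int(w_abs.numel() * sparsity)

    # [0] works on every torch version; .values only exists where topk
    # returns a namedtuple.
    threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()  # 0.3
    mask = torch.gt(w_abs, threshold)            # [True, False, True, False]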
@@ -370,10 +370,10 @@ class SlimPruner(Pruner):
         config = config_list[0]
         for (layer, config) in self.detect_modules_to_compress():
             assert layer.type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
-            weight_list.append(layer.module.weight.data.clone())
+            weight_list.append(layer.module.weight.data.abs().clone())
         all_bn_weights = torch.cat(weight_list)
         k = int(all_bn_weights.shape[0] * config['sparsity'])
-        self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False).values.max()
+        self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()

     def calc_mask(self, layer, config):
         """
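The added abs() matters because BN scale factors can be negative, and channel importance in network slimming is their magnitude; without it, strongly negative scale factors would land in the bottom-k and skew the global threshold. A sketch with invented values:

    import torch

    bn_weight = torch.tensor([-0.9, -0.1, 0.2, 0.8])  # stand-in BN scale factors
    k = int(bn_weight.shape[0] * 0.5)                  # sparsity = 0.5 -> k = 2

    # Signed values: bottom-k is [-0.9, -0.1], wrongly marking -0.9
    # (a strong channel) as least important.
    bad_threshold = torch.topk(bn_weight, k, largest=False)[0].max()         # -0.1

    # Magnitudes: bottom-k is [0.1, 0.2], so the channels nearest zero go.
    good_threshold = torch.topk(bn_weight.abs(), k, largest=False)[0].max()  # 0.2
    mask = torch.gt(bn_weight.abs(), good_threshold)   # keeps -0.9 and 0.8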
src/sdk/pynni/tests/test_compressor.py

@@ -8,6 +8,7 @@ import nni.compression.torch as torch_compressor
 if tf.__version__ >= '2.0':
     import nni.compression.tensorflow as tf_compressor

 def get_tf_model():
     model = tf.keras.models.Sequential([
         tf.keras.layers.Conv2D(filters=5, kernel_size=7, input_shape=[28, 28, 1], activation='relu', padding="SAME"),

@@ -24,38 +25,45 @@ def get_tf_model():
                   metrics=["accuracy"])
     return model

 class TorchModel(torch.nn.Module):
     def __init__(self):
         super().__init__()
         self.conv1 = torch.nn.Conv2d(1, 5, 5, 1)
+        self.bn1 = torch.nn.BatchNorm2d(5)
         self.conv2 = torch.nn.Conv2d(5, 10, 5, 1)
+        self.bn2 = torch.nn.BatchNorm2d(10)
         self.fc1 = torch.nn.Linear(4 * 4 * 10, 100)
         self.fc2 = torch.nn.Linear(100, 10)

     def forward(self, x):
-        x = F.relu(self.conv1(x))
+        x = F.relu(self.bn1(self.conv1(x)))
         x = F.max_pool2d(x, 2, 2)
-        x = F.relu(self.conv2(x))
+        x = F.relu(self.bn2(self.conv2(x)))
         x = F.max_pool2d(x, 2, 2)
         x = x.view(-1, 4 * 4 * 10)
         x = F.relu(self.fc1(x))
         x = self.fc2(x)
         return F.log_softmax(x, dim=1)
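The new bn1/bn2 layers give the test model BatchNorm2d modules for the new SlimPruner test below to prune; batch norm does not change spatial size, so the 4 * 4 * 10 flatten still holds. A quick shape check (my own sketch, not part of the commit):

    import torch
    import torch.nn.functional as F

    # 28x28 input -> 5x5 conv -> 24x24 -> 2x2 pool -> 12x12
    # -> 5x5 conv -> 8x8 -> 2x2 pool -> 4x4, over 10 channels.
    x = torch.randn(1, 1, 28, 28)
    x = F.max_pool2d(torch.nn.Conv2d(1, 5, 5, 1)(x), 2, 2)
    x = F.max_pool2d(torch.nn.Conv2d(5, 10, 5, 1)(x), 2, 2)
    print(x.shape)  # torch.Size([1, 10, 4, 4]) -> view(-1, 4 * 4 * 10)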
 def tf2(func):
     def test_tf2_func(*args):
         if tf.__version__ >= '2.0':
             func(*args)
     return test_tf2_func

 k1 = [[1] * 3] * 3
 k2 = [[2] * 3] * 3
 k3 = [[3] * 3] * 3
 k4 = [[4] * 3] * 3
 k5 = [[5] * 3] * 3
 w = [[k1, k2, k3, k4, k5]] * 10

 class CompressorTestCase(TestCase):
     def test_torch_level_pruner(self):
         model = TorchModel()
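For reference, the module-level constants k1 through k5 and w (context lines here, used by the existing pruner tests) stack five constant 3x3 planes into 10 identical filters; a quick check of the resulting shape and of the [90., 0., 0., 0., 90.] expectation in the next hunk (my own sketch):

    import numpy as np

    k1 = [[1] * 3] * 3                   # 3x3 plane of ones; k2..k5 use 2..5
    w = [[k1, [[2] * 3] * 3, [[3] * 3] * 3, [[4] * 3] * 3, [[5] * 3] * 3]] * 10
    print(np.array(w).shape)             # (10, 5, 3, 3)
    # A channel kept across all 10 filters contributes 10 * 3 * 3 = 90 ones
    # to masks.sum((0, 2, 3)); a pruned channel contributes 0.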
@@ -74,7 +82,7 @@ class CompressorTestCase(TestCase):
             'quant_bits': {
                 'weight': 8,
             },
-            'op_types':['Conv2d', 'Linear']
+            'op_types': ['Conv2d', 'Linear']
         }]
         torch_compressor.NaiveQuantizer(model, configure_list).compress()

@@ -133,6 +141,73 @@ class CompressorTestCase(TestCase):
         assert all(masks.sum((0, 2, 3)) == np.array([90., 0., 0., 0., 90.]))

+    def test_torch_l1filter_pruner(self):
+        """
+        Filters with the minimum sum of the weights' L1 norm are pruned in this paper:
+        PRUNING FILTERS FOR EFFICIENT CONVNETS,
+        https://arxiv.org/abs/1608.08710
+        So if sparsity is 0.2, the expected masks should mask out filter 0, this can be verified through:
+        `all(torch.sum(mask1, (1, 2, 3)).numpy() == np.array([0., 27., 27., 27., 27.]))`
+        If sparsity is 0.6, the expected masks should mask out filter 0,1,2, this can be verified through:
+        `all(torch.sum(mask2, (1, 2, 3)).numpy() == np.array([0., 0., 0., 27., 27.]))`
+        """
+        w = np.array([np.zeros((3, 3, 3)), np.ones((3, 3, 3)), np.ones((3, 3, 3)) * 2,
+                      np.ones((3, 3, 3)) * 3, np.ones((3, 3, 3)) * 4])
+        model = TorchModel()
+        config_list = [{'sparsity': 0.2, 'op_names': ['conv1']}, {'sparsity': 0.6, 'op_names': ['conv2']}]
+        pruner = torch_compressor.L1FilterPruner(model, config_list)
+
+        model.conv1.weight.data = torch.tensor(w).float()
+        model.conv2.weight.data = torch.tensor(w).float()
+        layer1 = torch_compressor.compressor.LayerInfo('conv1', model.conv1)
+        mask1 = pruner.calc_mask(layer1, config_list[0])
+        layer2 = torch_compressor.compressor.LayerInfo('conv2', model.conv2)
+        mask2 = pruner.calc_mask(layer2, config_list[1])
+        assert all(torch.sum(mask1, (1, 2, 3)).numpy() == np.array([0., 27., 27., 27., 27.]))
+        assert all(torch.sum(mask2, (1, 2, 3)).numpy() == np.array([0., 0., 0., 27., 27.]))
+
+    def test_torch_slim_pruner(self):
+        """
+        Scale factors with minimum l1 norm in the BN layers are pruned in this paper:
+        Learning Efficient Convolutional Networks through Network Slimming,
+        https://arxiv.org/pdf/1708.06519.pdf
+        So if sparsity is 0.2, the expected masks should mask out channel 0, this can be verified through:
+        `all(mask1.numpy() == np.array([0., 1., 1., 1., 1.]))`
+        `all(mask2.numpy() == np.array([0., 1., 1., 1., 1.]))`
+        If sparsity is 0.6, the expected masks should mask out channel 0,1,2, this can be verified through:
+        `all(mask1.numpy() == np.array([0., 0., 0., 1., 1.]))`
+        `all(mask2.numpy() == np.array([0., 0., 0., 1., 1.]))`
+        """
+        w = np.array([0, 1, 2, 3, 4])
+        model = TorchModel()
+        config_list = [{'sparsity': 0.2, 'op_types': ['BatchNorm2d']}]
+        model.bn1.weight.data = torch.tensor(w).float()
+        model.bn2.weight.data = torch.tensor(-w).float()
+        pruner = torch_compressor.SlimPruner(model, config_list)
+
+        layer1 = torch_compressor.compressor.LayerInfo('bn1', model.bn1)
+        mask1 = pruner.calc_mask(layer1, config_list[0])
+        layer2 = torch_compressor.compressor.LayerInfo('bn2', model.bn2)
+        mask2 = pruner.calc_mask(layer2, config_list[0])
+        assert all(mask1.numpy() == np.array([0., 1., 1., 1., 1.]))
+        assert all(mask2.numpy() == np.array([0., 1., 1., 1., 1.]))
+
+        config_list = [{'sparsity': 0.6, 'op_types': ['BatchNorm2d']}]
+        model.bn1.weight.data = torch.tensor(w).float()
+        model.bn2.weight.data = torch.tensor(w).float()
+        pruner = torch_compressor.SlimPruner(model, config_list)
+
+        layer1 = torch_compressor.compressor.LayerInfo('bn1', model.bn1)
+        mask1 = pruner.calc_mask(layer1, config_list[0])
+        layer2 = torch_compressor.compressor.LayerInfo('bn2', model.bn2)
+        mask2 = pruner.calc_mask(layer2, config_list[0])
+        assert all(mask1.numpy() == np.array([0., 0., 0., 1., 1.]))
+        assert all(mask2.numpy() == np.array([0., 0., 0., 1., 1.]))

 if __name__ == '__main__':
     main()
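As a cross-check of the expected masks in the two new tests, the bottom-k thresholding can be reproduced by hand (my own sketch, assuming calc_mask keeps values strictly above the threshold):

    import numpy as np

    # test_torch_l1filter_pruner: five filters, constant 0..4 over a 3x3x3
    # kernel, so the per-filter L1 norms are [0, 27, 54, 81, 108].
    filter_l1 = np.array([0., 1., 2., 3., 4.]) * 27
    for sparsity in (0.2, 0.6):                       # conv1, then conv2
        k = int(len(filter_l1) * sparsity)            # prune 1, then 3 filters
        threshold = np.sort(filter_l1)[:k].max()      # 0, then 54
        print((filter_l1 > threshold).astype(float))  # [0,1,1,1,1]; [0,0,0,1,1]

    # test_torch_slim_pruner: the threshold is global over the magnitudes of
    # both BN layers' scale factors, i.e. [0..4] twice.
    scales = np.concatenate([np.arange(5.), np.arange(5.)])
    for sparsity in (0.2, 0.6):
        k = int(scales.shape[0] * sparsity)           # 2, then 6 values cut
        threshold = np.sort(scales)[:k].max()         # 0, then 2
        print((np.arange(5.) > threshold).astype(float))  # [0,1,1,1,1]; [0,0,0,1,1]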