OpenDAS / nni

Commit c7d58033
authored Feb 10, 2020 by chicm-ms, committed by GitHub on Feb 10, 2020
Fix pruners for DataParallel support (#2003)
parent 4e21e721

Showing 5 changed files with 44 additions and 35 deletions (+44, -35)
examples/model_compress/fpgm_torch_mnist.py                         +13  -6
src/sdk/pynni/nni/compression/torch/compressor.py                    +1  -2
src/sdk/pynni/nni/compression/torch/pruners.py                      +26 -22
src/sdk/pynni/nni/compression/torch/weight_rank_filter_pruners.py    +4  -3
src/sdk/pynni/tests/test_compressor.py                               +0  -2
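For context on the commit message: this change set moves the per-layer "mask already calculated" flag into an if_calculated tensor registered on each wrapped module (replacing Python-side dicts such as if_init_list and mask_calculated_ops), so the flag travels with the module when torch.nn.DataParallel replicates it and can be reset in place by update_epoch. A minimal usage sketch of the scenario this targets; the model, sparsity value, op_types entry, and device handling are illustrative assumptions, not taken from the commit:

# Hedged sketch: prune a model with NNI, then run it under DataParallel.
# Only FPGMPruner/compress() appear in this diff; everything else is an assumption.
import torch
import torch.nn as nn
from nni.compression.torch import FPGMPruner

model = nn.Sequential(nn.Conv2d(1, 20, 5), nn.ReLU(), nn.Conv2d(20, 50, 5))
config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]

pruner = FPGMPruner(model, config_list)
model = pruner.compress()            # Conv2d layers are wrapped in PrunerModuleWrapper

if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)   # replicas carry the wrappers' mask/flag buffers with them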
examples/model_compress/fpgm_torch_mnist.py (view file @ c7d58033)
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 from torchvision import datasets, transforms
 from nni.compression.torch import FPGMPruner
@@ -6,10 +7,10 @@ from nni.compression.torch import FPGMPruner
 class Mnist(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
-        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
-        self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
-        self.fc2 = torch.nn.Linear(500, 10)
+        self.conv1 = nn.Conv2d(1, 20, 5, 1)
+        self.conv2 = nn.Conv2d(20, 50, 5, 1)
+        self.fc1 = nn.Linear(4 * 4 * 50, 500)
+        self.fc2 = nn.Linear(500, 10)

     def forward(self, x):
         x = F.relu(self.conv1(x))
@@ -27,8 +28,14 @@ class Mnist(torch.nn.Module):
         return num_zero_filters, num_filters, float(num_zero_filters) / num_filters

     def print_conv_filter_sparsity(self):
-        conv1_data = self._get_conv_weight_sparsity(self.conv1)
-        conv2_data = self._get_conv_weight_sparsity(self.conv2)
+        if isinstance(self.conv1, nn.Conv2d):
+            conv1_data = self._get_conv_weight_sparsity(self.conv1)
+            conv2_data = self._get_conv_weight_sparsity(self.conv2)
+        else:
+            # self.conv1 is wrapped as PrunerModuleWrapper
+            conv1_data = self._get_conv_weight_sparsity(self.conv1.module)
+            conv2_data = self._get_conv_weight_sparsity(self.conv2.module)

         print('conv1: num zero filters: {}, num filters: {}, sparsity: {:.4f}'.format(conv1_data[0], conv1_data[1], conv1_data[2]))
         print('conv2: num zero filters: {}, num filters: {}, sparsity: {:.4f}'.format(conv2_data[0], conv2_data[1], conv2_data[2]))
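The print_conv_filter_sparsity change is needed because, once the pruner has compressed the model, self.conv1 is no longer a plain nn.Conv2d but a PrunerModuleWrapper that keeps the original layer as .module. A hedged sketch of the same unwrapping idiom for an arbitrary layer; the helper name is made up for illustration:

import torch.nn as nn

def unwrap_conv(layer):
    # Hypothetical helper: after pruner.compress(), pruned layers are
    # PrunerModuleWrapper instances holding the original nn.Conv2d as .module;
    # plain, unwrapped layers pass through unchanged.
    return layer if isinstance(layer, nn.Conv2d) else layer.module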
src/sdk/pynni/nni/compression/torch/compressor.py (view file @ c7d58033)
@@ -246,7 +246,7 @@ class PrunerModuleWrapper(torch.nn.Module):
         self.module.weight.data = self.module.weight.data.mul_(self.weight_mask)
         # apply mask to bias
         if hasattr(self.module, 'bias') and self.module.bias is not None:
-            if mask is not None:
+            if mask is not None and 'bias' in mask:
                 self.bias_mask.copy_(mask['bias'])
             self.module.bias.data = self.module.bias.data.mul_(self.bias_mask)
         return self.module(*inputs)
@@ -565,4 +565,3 @@ def _check_weight(module):
         return isinstance(module.weight.data, torch.Tensor)
     except AttributeError:
         return False
-
\ No newline at end of file
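The added 'bias' in mask guard matters because the pruners in this tree typically return a mask dict that carries only a 'weight' entry (see the AGP_Pruner change below); with the old check, a layer that has a bias would go on to read mask['bias'] and raise KeyError. A minimal illustration with an arbitrary shape:

import torch

mask = {'weight': torch.ones(20, 1, 5, 5)}   # a weight-only mask, as AGP_Pruner returns

# Old guard: `if mask is not None:` would proceed to mask['bias'] and fail.
# New guard simply skips the bias copy when no bias mask is provided.
if mask is not None and 'bias' in mask:
    print("bias mask present")               # not reached for a weight-only mask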
src/sdk/pynni/nni/compression/torch/pruners.py (view file @ c7d58033)
@@ -83,17 +83,20 @@ class AGP_Pruner(Pruner):
         super().__init__(model, config_list)
         self.now_epoch = 0
-        self.if_init_list = {}
+        self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable

-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
-        Calculate the mask of given layer
+        Calculate the mask of given layer.
+        Scale factors with the smallest absolute value in the BN layer are masked.
         Parameters
         ----------
         layer : LayerInfo
             the layer to instrument the compression operation
         config : dict
             layer's pruning config
+        kwargs: dict
+            buffers registered in __init__ function
         Returns
         -------
         dict
@@ -101,24 +104,26 @@ class AGP_Pruner(Pruner):
         """
         weight = layer.module.weight.data
-        op_name = layer.name
         start_epoch = config.get('start_epoch', 0)
         freq = config.get('frequency', 1)
-        if self.now_epoch >= start_epoch and self.if_init_list.get(op_name, True) \
-                and (self.now_epoch - start_epoch) % freq == 0:
-            mask = self.mask_dict.get(op_name, {'weight': torch.ones(weight.shape).type_as(weight)})
-            target_sparsity = self.compute_target_sparsity(config)
-            k = int(weight.numel() * target_sparsity)
-            if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
-                return mask
-            # if we want to generate new mask, we should update weigth first
-            w_abs = weight.abs() * mask
-            threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
-            new_mask = {'weight': torch.gt(w_abs, threshold).type_as(weight)}
-            self.mask_dict.update({op_name: new_mask})
-            self.if_init_list.update({op_name: False})
-        else:
-            new_mask = self.mask_dict.get(op_name, {'weight': torch.ones(weight.shape).type_as(weight)})
+
+        if_calculated = kwargs["if_calculated"]
+        if if_calculated:
+            return None
+        if not (self.now_epoch >= start_epoch and (self.now_epoch - start_epoch) % freq == 0):
+            return None
+
+        mask = {'weight': torch.ones(weight.shape).type_as(weight)}
+        target_sparsity = self.compute_target_sparsity(config)
+        k = int(weight.numel() * target_sparsity)
+        if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
+            return mask
+        # if we want to generate new mask, we should update weigth first
+        w_abs = weight.abs()
+        threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
+        new_mask = {'weight': torch.gt(w_abs, threshold).type_as(weight)}
+        if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
+
         return new_mask

     def compute_target_sparsity(self, config):
@@ -164,9 +169,8 @@ class AGP_Pruner(Pruner):
         if epoch > 0:
             self.now_epoch = epoch
-            for k in self.if_init_list.keys():
-                self.if_init_list[k] = True
+            for wrapper in self.get_modules_wrapper():
+                wrapper.registered_buffers['if_calculated'].copy_(torch.tensor(0)) # pylint: disable=not-callable

 class SlimPruner(Pruner):
     """
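The reworked AGP_Pruner keeps its "mask already computed this epoch" flag in the if_calculated buffer passed in through **kwargs rather than in the per-op if_init_list dict. A self-contained sketch of just that flag mechanic; the helper function is illustrative, not NNI code:

import torch

if_calculated = torch.tensor(0)   # buffer registered on each PrunerModuleWrapper

def need_new_mask(flag):
    # Mirrors the early return in the new calc_mask: a truthy flag means
    # the mask for this layer was already produced this epoch, so return None.
    return not bool(flag)

assert need_new_mask(if_calculated)          # first call of the epoch: compute a mask
if_calculated.copy_(torch.tensor(1))         # calc_mask marks the layer as done, in place
assert not need_new_mask(if_calculated)      # later calls this epoch reuse the stored mask
if_calculated.copy_(torch.tensor(0))         # update_epoch resets the flag for the next epoch
assert need_new_mask(if_calculated)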
src/sdk/pynni/nni/compression/torch/weight_rank_filter_pruners.py (view file @ c7d58033)
@@ -27,7 +27,7 @@ class WeightRankFilterPruner(Pruner):
         """
         super().__init__(model, config_list)
-        self.register_buffer("if_calculated", torch.tensor(False)) # pylint: disable=not-callable
+        self.register_buffer("if_calculated", torch.tensor(0)) # pylint: disable=not-callable

     def get_mask(self, base_mask, weight, num_prune):
         raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))
@@ -69,7 +69,7 @@ class WeightRankFilterPruner(Pruner):
                 return mask
            mask = self.get_mask(mask, weight, num_prune)
         finally:
-            if_calculated.copy_(torch.tensor(True)) # pylint: disable=not-callable
+            if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
         return mask
@@ -257,4 +257,5 @@ class FPGMPruner(WeightRankFilterPruner):
         return x.sum()

     def update_epoch(self, epoch):
-        self.mask_calculated_ops = set()
+        for wrapper in self.get_modules_wrapper():
+            wrapper.registered_buffers['if_calculated'].copy_(torch.tensor(0)) # pylint: disable=not-callable
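With update_epoch now clearing every wrapper's if_calculated buffer, the training loop is expected to call it at each epoch boundary so that filter masks are recomputed once per epoch. A smoke-test style sketch, assuming nni and torch are installed; the model, shapes, and sparsity are arbitrary choices:

import torch
import torch.nn as nn
from nni.compression.torch import FPGMPruner

model = nn.Sequential(nn.Conv2d(1, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))
pruner = FPGMPruner(model, [{'sparsity': 0.5, 'op_types': ['Conv2d']}])
model = pruner.compress()

model(torch.randn(2, 1, 28, 28))   # first forward: masks computed, if_calculated set to 1
pruner.update_epoch(1)             # resets if_calculated on every wrapper
model(torch.randn(2, 1, 28, 28))   # masks are recomputed for the new epoch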
src/sdk/pynni/tests/test_compressor.py (view file @ c7d58033)
@@ -138,7 +138,6 @@ class CompressorTestCase(TestCase):
         masks = pruner.calc_mask(layer, config_list[0], if_calculated=torch.tensor(0))
         assert all(torch.sum(masks['weight'], (1, 2, 3)).numpy() == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))
-        pruner.update_epoch(1)
         model.conv2.weight.data = torch.tensor(w).float()
         masks = pruner.calc_mask(layer, config_list[1], if_calculated=torch.tensor(0))
         assert all(torch.sum(masks['weight'], (1, 2, 3)).numpy() == np.array([45., 45., 0., 0., 0., 0., 0., 0., 45., 45.]))
@@ -159,7 +158,6 @@ class CompressorTestCase(TestCase):
         assert all(masks.sum((1)) == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))
-        pruner.update_epoch(1)
         model.layers[2].set_weights([weights[0], weights[1].numpy()])
         masks = pruner.calc_mask(layer, config_list[1]).numpy()
         masks = masks.reshape((-1, masks.shape[-1])).transpose([1, 0])