Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
259aee75
"docs/archive_en_US/Tutorial/SetupNniDeveloperEnvironment.md" did not exist on "7e35d32e2987493838779826155f7434bc30b81c"
Unverified
Commit
259aee75
authored
Jun 02, 2021
by
J-shang
Committed by
GitHub
Jun 02, 2021
Browse files
Update tf pruner example (#3708)
parent
42337dc0
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
42 additions
and
15 deletions
+42
-15
examples/model_compress/pruning/naive_prune_tf.py
examples/model_compress/pruning/naive_prune_tf.py
+42
-15
No files found.
examples/model_compress/pruning/naive_prune_tf.py
View file @
259aee75
...
...
@@ -10,9 +10,9 @@ import argparse
import
tensorflow
as
tf
from
tensorflow.keras
import
Model
from
tensorflow.keras.layers
import
(
Conv2D
,
Dense
,
Dropout
,
Flatten
,
MaxPool2D
)
from
tensorflow.keras.layers
import
(
Conv2D
,
Dense
,
Dropout
,
Flatten
,
MaxPool2D
,
BatchNormalization
)
from
nni.algorithms.compression.tensorflow.pruning
import
LevelPruner
from
nni.algorithms.compression.tensorflow.pruning
import
LevelPruner
,
SlimPruner
class
LeNet
(
Model
):
"""
...
...
@@ -34,8 +34,10 @@ class LeNet(Model):
super
().
__init__
()
self
.
conv1
=
Conv2D
(
filters
=
32
,
kernel_size
=
conv_size
,
activation
=
'relu'
)
self
.
pool1
=
MaxPool2D
(
pool_size
=
2
)
self
.
bn1
=
BatchNormalization
()
self
.
conv2
=
Conv2D
(
filters
=
64
,
kernel_size
=
conv_size
,
activation
=
'relu'
)
self
.
pool2
=
MaxPool2D
(
pool_size
=
2
)
self
.
bn2
=
BatchNormalization
()
self
.
flatten
=
Flatten
()
self
.
fc1
=
Dense
(
units
=
hidden_size
,
activation
=
'relu'
)
self
.
dropout
=
Dropout
(
rate
=
dropout_rate
)
...
...
@@ -45,8 +47,10 @@ class LeNet(Model):
"""Override ``Model.call`` to build LeNet-5 model."""
x
=
self
.
conv1
(
x
)
x
=
self
.
pool1
(
x
)
x
=
self
.
bn1
(
x
)
x
=
self
.
conv2
(
x
)
x
=
self
.
pool2
(
x
)
x
=
self
.
bn2
(
x
)
x
=
self
.
flatten
(
x
)
x
=
self
.
fc1
(
x
)
x
=
self
.
dropout
(
x
)
...
...
@@ -85,12 +89,29 @@ def main(args):
model
=
LeNet
()
print
(
'start training'
)
optimizer
=
tf
.
keras
.
optimizers
.
SGD
(
learning_rate
=
0.1
,
momentum
=
0.9
,
decay
=
1e-4
)
if
args
.
pruner_name
==
'slim'
:
def
slim_loss
(
y_true
,
y_pred
):
loss_1
=
tf
.
keras
.
losses
.
sparse_categorical_crossentropy
(
y_true
=
y_true
,
y_pred
=
y_pred
)
weight_list
=
[]
for
layer
in
[
model
.
bn1
,
model
.
bn2
]:
weight_list
.
append
([
w
for
w
in
layer
.
weights
if
'/gamma:'
in
w
.
name
][
0
].
read_value
())
loss_2
=
0.0001
*
tf
.
reduce_sum
([
tf
.
reduce_sum
(
tf
.
abs
(
w
))
for
w
in
weight_list
])
return
loss_1
+
loss_2
model
.
compile
(
optimizer
=
optimizer
,
loss
=
slim_loss
,
metrics
=
[
'accuracy'
]
)
else
:
model
.
compile
(
optimizer
=
optimizer
,
loss
=
'sparse_categorical_crossentropy'
,
metrics
=
[
'accuracy'
]
)
model
.
fit
(
train_set
[
0
],
train_set
[
1
],
...
...
@@ -103,13 +124,19 @@ def main(args):
optimizer_finetune
=
tf
.
keras
.
optimizers
.
SGD
(
learning_rate
=
0.001
,
momentum
=
0.9
,
decay
=
1e-4
)
# create_pruner
if
args
.
pruner_name
==
'level'
:
prune_config
=
[{
'sparsity'
:
args
.
sparsity
,
'op_types'
:
[
'default'
],
}]
pruner
=
LevelPruner
(
model
,
prune_config
)
# pruner = create_pruner(model, args.pruner_name)
elif
args
.
pruner_name
==
'slim'
:
prune_config
=
[{
'sparsity'
:
args
.
sparsity
,
'op_types'
:
[
'BatchNormalization'
],
}]
pruner
=
SlimPruner
(
model
,
prune_config
)
model
=
pruner
.
compress
()
model
.
compile
(
...
...
@@ -131,7 +158,7 @@ def main(args):
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--pruner_name'
,
type
=
str
,
default
=
'level'
)
parser
.
add_argument
(
'--pruner_name'
,
type
=
str
,
default
=
'level'
,
choices
=
[
'level'
,
'slim'
]
)
parser
.
add_argument
(
'--batch-size'
,
type
=
int
,
default
=
256
)
parser
.
add_argument
(
'--pretrain_epochs'
,
type
=
int
,
default
=
10
)
parser
.
add_argument
(
'--prune_epochs'
,
type
=
int
,
default
=
10
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment