Unverified Commit 259aee75 authored by J-shang, committed by GitHub

Update tf pruner example (#3708)

parent 42337dc0
@@ -10,9 +10,9 @@ import argparse
 import tensorflow as tf
 from tensorflow.keras import Model
-from tensorflow.keras.layers import (Conv2D, Dense, Dropout, Flatten, MaxPool2D)
+from tensorflow.keras.layers import (Conv2D, Dense, Dropout, Flatten, MaxPool2D, BatchNormalization)
-from nni.algorithms.compression.tensorflow.pruning import LevelPruner
+from nni.algorithms.compression.tensorflow.pruning import LevelPruner, SlimPruner


 class LeNet(Model):
     """
@@ -34,8 +34,10 @@ class LeNet(Model):
         super().__init__()
         self.conv1 = Conv2D(filters=32, kernel_size=conv_size, activation='relu')
         self.pool1 = MaxPool2D(pool_size=2)
+        self.bn1 = BatchNormalization()
         self.conv2 = Conv2D(filters=64, kernel_size=conv_size, activation='relu')
         self.pool2 = MaxPool2D(pool_size=2)
+        self.bn2 = BatchNormalization()
         self.flatten = Flatten()
         self.fc1 = Dense(units=hidden_size, activation='relu')
         self.dropout = Dropout(rate=dropout_rate)
@@ -45,8 +47,10 @@ class LeNet(Model):
         """Override ``Model.call`` to build LeNet-5 model."""
         x = self.conv1(x)
         x = self.pool1(x)
+        x = self.bn1(x)
         x = self.conv2(x)
         x = self.pool2(x)
+        x = self.bn2(x)
         x = self.flatten(x)
         x = self.fc1(x)
         x = self.dropout(x)
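The two BatchNormalization layers added above are what make slim pruning applicable here: each BN layer carries a per-channel gamma (scale) variable that can be regularized toward zero and then used to rank channels for pruning. As a minimal standalone sketch (not part of the commit) of the name-based lookup that slim_loss below relies on:

import tensorflow as tf

# Hypothetical check: build a BN layer over a 32-channel feature map and list
# its weights. A BatchNormalization layer owns four per-channel variables; the
# trainable scale factor is the one whose name contains 'gamma'.
bn = tf.keras.layers.BatchNormalization()
bn.build(input_shape=(None, 12, 12, 32))

for w in bn.weights:
    print(w.name, w.shape)   # gamma, beta, moving_mean, moving_variance; each (32,)

gamma = [w for w in bn.weights if '/gamma:' in w.name][0]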
@@ -85,12 +89,29 @@ def main(args):
     model = LeNet()
     print('start training')
     optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9, decay=1e-4)
-    model.compile(
-        optimizer=optimizer,
-        loss='sparse_categorical_crossentropy',
-        metrics=['accuracy']
-    )
+    if args.pruner_name == 'slim':
+        def slim_loss(y_true, y_pred):
+            loss_1 = tf.keras.losses.sparse_categorical_crossentropy(y_true=y_true, y_pred=y_pred)
+            weight_list = []
+            for layer in [model.bn1, model.bn2]:
+                weight_list.append([w for w in layer.weights if '/gamma:' in w.name][0].read_value())
+            loss_2 = 0.0001 * tf.reduce_sum([tf.reduce_sum(tf.abs(w)) for w in weight_list])
+            return loss_1 + loss_2
+        model.compile(
+            optimizer=optimizer,
+            loss=slim_loss,
+            metrics=['accuracy']
+        )
+    else:
+        model.compile(
+            optimizer=optimizer,
+            loss='sparse_categorical_crossentropy',
+            metrics=['accuracy']
+        )
     model.fit(
         train_set[0],
         train_set[1],
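The slim_loss above is the Network Slimming objective: the usual cross-entropy plus an L1 penalty on the BN scale factors, which drives unimportant channels toward zero before pruning. Since loss_1 is a per-example vector and loss_2 a scalar, Keras broadcasts their sum and reduces it to a mean, so the penalty is effectively added once per batch. A self-contained sketch of the same penalty factored into a reusable helper (the name bn_gamma_l1 and its keyword argument are mine; the 1e-4 default mirrors the hard-coded 0.0001 above):

import tensorflow as tf

def bn_gamma_l1(bn_layers, l1_weight=1e-4):
    """Sum of |gamma| over the given BatchNormalization layers, scaled by l1_weight."""
    gammas = [
        [w for w in layer.weights if '/gamma:' in w.name][0]
        for layer in bn_layers
    ]
    return l1_weight * tf.add_n([tf.reduce_sum(tf.abs(g)) for g in gammas])

# The diff's training objective is then equivalent to:
#   sparse_categorical_crossentropy(y_true, y_pred) + bn_gamma_l1([model.bn1, model.bn2])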
@@ -103,13 +124,19 @@ def main(args):
     optimizer_finetune = tf.keras.optimizers.SGD(learning_rate=0.001, momentum=0.9, decay=1e-4)

     # create_pruner
-    prune_config = [{
-        'sparsity': args.sparsity,
-        'op_types': ['default'],
-    }]
-    pruner = LevelPruner(model, prune_config)
-    # pruner = create_pruner(model, args.pruner_name)
+    if args.pruner_name == 'level':
+        prune_config = [{
+            'sparsity': args.sparsity,
+            'op_types': ['default'],
+        }]
+        pruner = LevelPruner(model, prune_config)
+    elif args.pruner_name == 'slim':
+        prune_config = [{
+            'sparsity': args.sparsity,
+            'op_types': ['BatchNormalization'],
+        }]
+        pruner = SlimPruner(model, prune_config)
     model = pruner.compress()

     model.compile(
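Note that the branch above only changes how the pruner is constructed; the downstream flow is unchanged: compress, recompile at a lower learning rate, fine-tune. A condensed sketch of that flow for the slim case, using only the calls that appear in this diff (the 0.5 sparsity is an arbitrary illustration; model and optimizer_finetune are the objects defined earlier in the example):

from nni.algorithms.compression.tensorflow.pruning import SlimPruner

prune_config = [{
    'sparsity': 0.5,                      # target fraction of channels to mask
    'op_types': ['BatchNormalization'],   # SlimPruner ranks channels by BN gamma
}]
pruner = SlimPruner(model, prune_config)
model = pruner.compress()                 # returns the model with pruning masks applied

# Fine-tune to recover accuracy after masking.
model.compile(
    optimizer=optimizer_finetune,
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)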
@@ -131,7 +158,7 @@ def main(args):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('--pruner_name', type=str, default='level')
+    parser.add_argument('--pruner_name', type=str, default='level', choices=['level', 'slim'])
     parser.add_argument('--batch-size', type=int, default=256)
     parser.add_argument('--pretrain_epochs', type=int, default=10)
     parser.add_argument('--prune_epochs', type=int, default=10)
...
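With the choices constraint added to --pruner_name, the example can presumably be run either way; the script path below is a placeholder (the file name is not visible in this view), and the --sparsity flag is inferred from the args.sparsity usage above:

# script path is a placeholder; see the commit's file list for the real one
python <example_script>.py --pruner_name level --sparsity 0.5
python <example_script>.py --pruner_name slim --sparsity 0.5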