"git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "4d19cbacccda71c2748bbd8c41437e7afd7e6f9c"
Unverified Commit 259aee75 authored by J-shang, committed by GitHub

Update tf pruner example (#3708)

parent 42337dc0
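
The diff below extends the TensorFlow LeNet pruning example in two ways: it inserts a BatchNormalization layer after each pooling stage, and it adds a SlimPruner branch that trains with an extra L1 penalty on the BatchNormalization gamma (scale) weights before pruning. As a hedged usage sketch (the script's file name is not shown in this diff, so `model_prune_tf.py` is a placeholder, and the `--sparsity` flag is assumed from the `args.sparsity` references below):

    python model_prune_tf.py --pruner_name slim --sparsity 0.5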
@@ -10,9 +10,9 @@ import argparse
 import tensorflow as tf
 from tensorflow.keras import Model
-from tensorflow.keras.layers import (Conv2D, Dense, Dropout, Flatten, MaxPool2D)
+from tensorflow.keras.layers import (Conv2D, Dense, Dropout, Flatten, MaxPool2D, BatchNormalization)
 
-from nni.algorithms.compression.tensorflow.pruning import LevelPruner
+from nni.algorithms.compression.tensorflow.pruning import LevelPruner, SlimPruner
 
 
 class LeNet(Model):
     """
@@ -34,8 +34,10 @@ class LeNet(Model):
         super().__init__()
         self.conv1 = Conv2D(filters=32, kernel_size=conv_size, activation='relu')
         self.pool1 = MaxPool2D(pool_size=2)
+        self.bn1 = BatchNormalization()
         self.conv2 = Conv2D(filters=64, kernel_size=conv_size, activation='relu')
         self.pool2 = MaxPool2D(pool_size=2)
+        self.bn2 = BatchNormalization()
         self.flatten = Flatten()
         self.fc1 = Dense(units=hidden_size, activation='relu')
         self.dropout = Dropout(rate=dropout_rate)
@@ -45,8 +47,10 @@ class LeNet(Model):
         """Override ``Model.call`` to build LeNet-5 model."""
         x = self.conv1(x)
         x = self.pool1(x)
+        x = self.bn1(x)
         x = self.conv2(x)
         x = self.pool2(x)
+        x = self.bn2(x)
         x = self.flatten(x)
         x = self.fc1(x)
         x = self.dropout(x)
@@ -85,12 +89,29 @@ def main(args):
     model = LeNet()
 
     print('start training')
     optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9, decay=1e-4)
-    model.compile(
-        optimizer=optimizer,
-        loss='sparse_categorical_crossentropy',
-        metrics=['accuracy']
-    )
+    if args.pruner_name == 'slim':
+        def slim_loss(y_true, y_pred):
+            loss_1 = tf.keras.losses.sparse_categorical_crossentropy(y_true=y_true, y_pred=y_pred)
+            weight_list = []
+            for layer in [model.bn1, model.bn2]:
+                weight_list.append([w for w in layer.weights if '/gamma:' in w.name][0].read_value())
+            loss_2 = 0.0001 * tf.reduce_sum([tf.reduce_sum(tf.abs(w)) for w in weight_list])
+            return loss_1 + loss_2
+        model.compile(
+            optimizer=optimizer,
+            loss=slim_loss,
+            metrics=['accuracy']
+        )
+    else:
+        model.compile(
+            optimizer=optimizer,
+            loss='sparse_categorical_crossentropy',
+            metrics=['accuracy']
+        )
     model.fit(
         train_set[0],
         train_set[1],
@@ -103,13 +124,19 @@ def main(args):
     optimizer_finetune = tf.keras.optimizers.SGD(learning_rate=0.001, momentum=0.9, decay=1e-4)
 
     # create_pruner
-    prune_config = [{
-        'sparsity': args.sparsity,
-        'op_types': ['default'],
-    }]
-    pruner = LevelPruner(model, prune_config)
     # pruner = create_pruner(model, args.pruner_name)
+    if args.pruner_name == 'level':
+        prune_config = [{
+            'sparsity': args.sparsity,
+            'op_types': ['default'],
+        }]
+        pruner = LevelPruner(model, prune_config)
+    elif args.pruner_name == 'slim':
+        prune_config = [{
+            'sparsity': args.sparsity,
+            'op_types': ['BatchNormalization'],
+        }]
+        pruner = SlimPruner(model, prune_config)
     model = pruner.compress()
 
     model.compile(
@@ -131,7 +158,7 @@ def main(args):
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('--pruner_name', type=str, default='level')
+    parser.add_argument('--pruner_name', type=str, default='level', choices=['level', 'slim'])
     parser.add_argument('--batch-size', type=int, default=256)
     parser.add_argument('--pretrain_epochs', type=int, default=10)
    parser.add_argument('--prune_epochs', type=int, default=10)
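
For readers following the `slim_loss` change above: the L1 penalty term is what makes the SlimPruner branch effective, since channels whose gamma values are driven toward zero during training are the ones the pruner can later remove with little accuracy loss. A minimal standalone sketch of the same idea, assuming only a list of Keras BatchNormalization layers and the 1e-4 coefficient used in the diff (`l1_gamma_penalty` is a hypothetical name, not part of the commit):

    import tensorflow as tf

    def l1_gamma_penalty(bn_layers, coeff=1e-4):
        # Collect every gamma (scale) weight from the given BatchNormalization
        # layers; '/gamma:' matches Keras' default weight naming, as in the diff.
        gammas = [w for layer in bn_layers for w in layer.weights if '/gamma:' in w.name]
        # L1 penalty: sums absolute channel scales, pushing them toward zero
        # so that low-gamma channels become candidates for pruning.
        return coeff * tf.add_n([tf.reduce_sum(tf.abs(g)) for g in gammas])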