# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

'''
NNI example for quick start of pruning.
In this example, we use a level pruner to prune LeNet on MNIST.
'''

import argparse

import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import (Conv2D, Dense, Dropout, Flatten,
                                     MaxPool2D, BatchNormalization)

from nni.algorithms.compression.tensorflow.pruning import LevelPruner, SlimPruner


class LeNet(Model):
    """
    LeNet-5 model with customizable hyper-parameters.
    """
    def __init__(self, conv_size=3, hidden_size=32, dropout_rate=0.5):
        """
        Initialize hyper-parameters.

        Parameters
        ----------
        conv_size : int
            Kernel size of convolutional layers.
        hidden_size : int
            Dimensionality of the last hidden layer.
        dropout_rate : float
            Dropout rate between the two fully connected (dense) layers,
            to prevent co-adaptation.
        """
        super().__init__()
        self.conv1 = Conv2D(filters=32, kernel_size=conv_size, activation='relu')
        self.pool1 = MaxPool2D(pool_size=2)
        self.bn1 = BatchNormalization()
        self.conv2 = Conv2D(filters=64, kernel_size=conv_size, activation='relu')
        self.pool2 = MaxPool2D(pool_size=2)
        self.bn2 = BatchNormalization()
        self.flatten = Flatten()
        self.fc1 = Dense(units=hidden_size, activation='relu')
        self.dropout = Dropout(rate=dropout_rate)
        self.fc2 = Dense(units=10, activation='softmax')

    def call(self, x):
        """Override ``Model.call`` to build the LeNet-5 model."""
        x = self.conv1(x)
        x = self.pool1(x)
        x = self.bn1(x)
        x = self.conv2(x)
        x = self.pool2(x)
        x = self.bn2(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.dropout(x)
        return self.fc2(x)


def get_dataset(dataset_name='mnist'):
    assert dataset_name == 'mnist'
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    # Add a channel dimension and scale pixel values to [0, 1].
    x_train = x_train[..., tf.newaxis] / 255.0
    x_test = x_test[..., tf.newaxis] / 255.0
    return (x_train, y_train), (x_test, y_test)


# def create_model(model_name='naive'):
#     assert model_name == 'naive'
#     return tf.keras.Sequential([
#         tf.keras.layers.Conv2D(filters=20, kernel_size=5),
#         tf.keras.layers.BatchNormalization(),
#         tf.keras.layers.ReLU(),
#         tf.keras.layers.MaxPool2D(pool_size=2),
#         tf.keras.layers.Conv2D(filters=20, kernel_size=5),
#         tf.keras.layers.BatchNormalization(),
#         tf.keras.layers.ReLU(),
#         tf.keras.layers.MaxPool2D(pool_size=2),
#         tf.keras.layers.Flatten(),
#         tf.keras.layers.Dense(units=500),
#         tf.keras.layers.ReLU(),
#         tf.keras.layers.Dense(units=10),
#         tf.keras.layers.Softmax()
#     ])


def main(args):
    train_set, test_set = get_dataset('mnist')
    model = LeNet()

    print('start training')
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9, decay=1e-4)
    if args.pruner_name == 'slim':
        # Slim pruning needs sparsity regularization during training: an L1
        # penalty on the batch-norm scale factors (gamma) pushes unimportant
        # channels toward zero so they can be pruned afterwards.
        def slim_loss(y_true, y_pred):
            loss_1 = tf.keras.losses.sparse_categorical_crossentropy(y_true=y_true, y_pred=y_pred)
            weight_list = []
            for layer in [model.bn1, model.bn2]:
                weight_list.append([w for w in layer.weights if '/gamma:' in w.name][0].read_value())
            loss_2 = 0.0001 * tf.reduce_sum([tf.reduce_sum(tf.abs(w)) for w in weight_list])
            return loss_1 + loss_2
        model.compile(
            optimizer=optimizer,
            loss=slim_loss,
            metrics=['accuracy']
        )
    else:
        model.compile(
            optimizer=optimizer,
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy']
        )
    model.fit(
        train_set[0],
        train_set[1],
        batch_size=args.batch_size,
        epochs=args.pretrain_epochs,
        validation_data=test_set
    )

    print('start pruning')
    optimizer_finetune = tf.keras.optimizers.SGD(learning_rate=0.001, momentum=0.9, decay=1e-4)

    # create pruner
    if args.pruner_name == 'level':
        prune_config = [{
            'sparsity': args.sparsity,
            'op_types': ['default'],
        }]
        pruner = LevelPruner(model, prune_config)
    elif args.pruner_name == 'slim':
        prune_config = [{
            'sparsity': args.sparsity,
            'op_types': ['BatchNormalization'],
        }]
        pruner = SlimPruner(model, prune_config)

    model = pruner.compress()
    model.compile(
        optimizer=optimizer_finetune,
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
        run_eagerly=True  # NOTE: Important, model compression does not work in graph mode!
    )

    # fine-tuning
    model.fit(
        train_set[0],
        train_set[1],
        batch_size=args.batch_size,
        epochs=args.prune_epochs,
        validation_data=test_set
    )


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--pruner_name', type=str, default='level', choices=['level', 'slim'])
    parser.add_argument('--batch-size', type=int, default=256)
    parser.add_argument('--pretrain_epochs', type=int, default=10)
    parser.add_argument('--prune_epochs', type=int, default=10)
    parser.add_argument('--sparsity', type=float, default=0.5)
    args = parser.parse_args()

    main(args)
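
# Example invocations (a sketch: the filename mnist_pruning.py is hypothetical,
# and nni plus tensorflow are assumed to be installed):
#
#   python mnist_pruning.py                                    # level pruner, 50% sparsity
#   python mnist_pruning.py --pruner_name slim --sparsity 0.7  # slim pruner, 70% sparsity
#
# The level pruner masks individual weights of smallest magnitude until the
# target sparsity is reached, while the slim pruner removes batch-norm channels
# whose gamma scale factors were driven toward zero by the L1 penalty in
# slim_loss above.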