model_prune_tf.py 3.34 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import argparse

import tensorflow as tf

import nni.compression.tensorflow

# Registry of pruning experiments, keyed by the value passed as --pruner_name.
# Each entry names the dataset and model to build, the NNI pruner class to
# instantiate, and the config_list handed to that pruner's constructor.
# NOTE(review): 'sparsity' and 'op_types' are interpreted by the NNI pruner;
# presumably 0.9 is the fraction of weights to mask -- confirm in the NNI docs.
prune_config = {
    'level': dict(
        dataset_name='mnist',
        model_name='naive',
        pruner_class=nni.compression.tensorflow.LevelPruner,
        config_list=[
            dict(sparsity=0.9, op_types=['default']),
        ],
    ),
}


def get_dataset(dataset_name='mnist'):
    """Load a dataset and return it as ``(train, test)`` pairs.

    Only ``'mnist'`` is supported. The images returned by
    ``tf.keras.datasets.mnist.load_data()`` get one trailing axis appended
    and are divided by 255.0; labels are passed through untouched.

    Args:
        dataset_name: Name of the dataset to load; must be ``'mnist'``.

    Returns:
        ``((x_train, y_train), (x_test, y_test))``.

    Raises:
        ValueError: If ``dataset_name`` is anything other than ``'mnist'``.
    """
    # Explicit check instead of `assert`: assertions are removed under
    # `python -O`, which would let an unsupported name silently load MNIST.
    if dataset_name != 'mnist':
        raise ValueError(
            "unsupported dataset_name {!r}; only 'mnist' is available".format(dataset_name)
        )

    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    # Append a trailing axis and rescale the image arrays by 1/255.
    x_train = x_train[..., tf.newaxis] / 255.0
    x_test = x_test[..., tf.newaxis] / 255.0
    return (x_train, y_train), (x_test, y_test)


def create_model(model_name='naive'):
    """Build the model identified by ``model_name``.

    Only ``'naive'`` is supported, which yields a fresh ``NaiveModel``.

    Args:
        model_name: Name of the model architecture; must be ``'naive'``.

    Returns:
        A new ``NaiveModel`` instance.

    Raises:
        ValueError: If ``model_name`` is anything other than ``'naive'``.
    """
    # Explicit check instead of `assert`: assertions are removed under
    # `python -O`, which would let an unsupported name silently build NaiveModel.
    if model_name != 'naive':
        raise ValueError(
            "unsupported model_name {!r}; only 'naive' is available".format(model_name)
        )
    return NaiveModel()

class NaiveModel(tf.keras.Model):
    """Small convolutional classifier built from a flat list of Keras layers.

    The layers are, in order: two stages of Conv2D(20 filters, kernel 5) ->
    BatchNormalization -> ReLU -> MaxPool2D(2), followed by Flatten,
    Dense(500) -> ReLU, and Dense(10) -> Softmax.
    """

    def __init__(self):
        super().__init__()
        layers = tf.keras.layers

        def conv_stage():
            # One convolution stage; called twice so both stages are identical
            # in configuration but own separate layer instances.
            return [
                layers.Conv2D(filters=20, kernel_size=5),
                layers.BatchNormalization(),
                layers.ReLU(),
                layers.MaxPool2D(pool_size=2),
            ]

        # List elements are evaluated left to right, so layers are created in
        # the same order in which `call` applies them.
        self.seq_layers = [
            *conv_stage(),
            *conv_stage(),
            layers.Flatten(),
            layers.Dense(units=500),
            layers.ReLU(),
            layers.Dense(units=10),
            layers.Softmax(),
        ]

    def call(self, x):
        """Apply every layer in ``self.seq_layers`` to ``x`` in sequence."""
        out = x
        for step in self.seq_layers:
            out = step(out)
        return out
liuzhe-lz's avatar
liuzhe-lz committed
56
57
58
59
60
61
62
63
64
65
66
67


def create_pruner(model, pruner_name):
    """Instantiate the pruner registered under ``pruner_name`` for ``model``.

    Looks up the ``prune_config`` entry for ``pruner_name`` and calls its
    ``'pruner_class'`` with ``model`` and the entry's ``'config_list'``.

    Raises:
        KeyError: If ``pruner_name`` is not a key of ``prune_config``.
    """
    entry = prune_config[pruner_name]
    build = entry['pruner_class']
    return build(model, entry['config_list'])


def main(args):
    """Pretrain a model, prune it, then fine-tune the pruned model.

    Args:
        args: Parsed command-line namespace; reads ``pruner_name``,
            ``batch_size``, ``pretrain_epochs`` and ``prune_epochs``.
    """
    settings = prune_config[args.pruner_name]
    net_name = settings['model_name']
    data_name = settings['dataset_name']
    (train_x, train_y), test_set = get_dataset(data_name)
    model = create_model(net_name)

    # Phase 1: ordinary training of the unpruned model.
    print('start training')
    model.compile(
        optimizer=tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9, decay=1e-4),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    model.fit(
        train_x,
        train_y,
        batch_size=args.batch_size,
        epochs=args.pretrain_epochs,
        validation_data=test_set,
    )

    # Phase 2: wrap the trained model with the pruner and fine-tune the result
    # with a smaller learning rate.
    print('start model pruning')
    finetune_sgd = tf.keras.optimizers.SGD(learning_rate=0.001, momentum=0.9, decay=1e-4)
    model = create_pruner(model, args.pruner_name).compress()
    model.compile(
        optimizer=finetune_sgd,
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
        # NOTE: Important, model compression does not work in graph mode!
        run_eagerly=True,
    )
    model.fit(
        train_x,
        train_y,
        batch_size=args.batch_size,
        epochs=args.prune_epochs,
        validation_data=test_set,
    )
liuzhe-lz's avatar
liuzhe-lz committed
102
103
104
105
106
107
108
109
110
111
112


if __name__ == '__main__':
    # Command-line entry point: every option has a default, so running the
    # script with no arguments executes the 'level' pruning experiment.
    parser = argparse.ArgumentParser()
    parser.add_argument('--pruner_name', type=str, default='level')
    for flag, default in (
        ('--batch_size', 256),
        ('--pretrain_epochs', 10),
        ('--prune_epochs', 10),
    ):
        parser.add_argument(flag, type=int, default=default)

    args = parser.parse_args()
    main(args)