model_prune_tf.py 2.85 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import argparse

import tensorflow as tf

import nni.compression.tensorflow

# Registry of pruning experiments, keyed by the value of ``--pruner_name``.
# Each entry names the dataset and model to build (passed to ``get_dataset``
# and ``create_model``), the NNI pruner class to instantiate, and the
# ``config_list`` handed to that pruner's constructor.
prune_config = dict(
    level=dict(
        dataset_name='mnist',
        model_name='naive',
        pruner_class=nni.compression.tensorflow.LevelPruner,
        # sparsity 0.9 on op_types ['default']; presumably 'default' means
        # NNI's built-in set of prunable layer types -- confirm against the
        # installed NNI version.
        config_list=[dict(sparsity=0.9, op_types=['default'])],
    ),
)


def get_dataset(dataset_name='mnist'):
    """Load a dataset as ``((x_train, y_train), (x_test, y_test))``.

    Only MNIST is supported. Images get a trailing channel axis and are
    divided by 255 after conversion to float32.

    Args:
        dataset_name: Name of the dataset to load; must be ``'mnist'``.

    Returns:
        Two ``(images, labels)`` pairs: the training split and the test split.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
    """
    # An explicit raise (unlike ``assert``) still fires under ``python -O``,
    # so an unknown name can never silently fall back to loading MNIST.
    if dataset_name != 'mnist':
        raise ValueError(
            "unsupported dataset: {!r} (only 'mnist' is available)".format(dataset_name))

    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    # Cast before scaling: dividing the raw integer pixels by 255.0 would yield
    # float64 arrays (twice the memory), while Keras trains in float32 anyway.
    x_train = x_train[..., tf.newaxis].astype('float32') / 255.0
    x_test = x_test[..., tf.newaxis].astype('float32') / 255.0
    return (x_train, y_train), (x_test, y_test)


def create_model(model_name='naive'):
    """Build the uncompiled Keras ``Sequential`` classifier named ``model_name``.

    The only supported model, ``'naive'``, is two convolutional stages
    (5x5 Conv2D with 20 filters -> BatchNorm -> ReLU -> 2x2 MaxPool),
    followed by Flatten -> Dense(500) -> ReLU -> Dense(10) -> Softmax.
    No input shape is declared, so Keras infers it the first time the
    model is called or fit.

    Args:
        model_name: Name of the architecture; must be ``'naive'``.

    Returns:
        A ``tf.keras.Sequential`` model whose output is a softmax over 10 units.

    Raises:
        ValueError: If ``model_name`` is not supported.
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if model_name != 'naive':
        raise ValueError(
            "unsupported model: {!r} (only 'naive' is available)".format(model_name))

    return tf.keras.Sequential([
        tf.keras.layers.Conv2D(filters=20, kernel_size=5),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.ReLU(),
        tf.keras.layers.MaxPool2D(pool_size=2),
        tf.keras.layers.Conv2D(filters=20, kernel_size=5),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.ReLU(),
        tf.keras.layers.MaxPool2D(pool_size=2),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(units=500),
        tf.keras.layers.ReLU(),
        tf.keras.layers.Dense(units=10),
        tf.keras.layers.Softmax()
    ])


def create_pruner(model, pruner_name):
    """Instantiate the pruner registered under ``pruner_name`` in ``prune_config``.

    The entry's ``pruner_class`` is called with ``model`` and the entry's
    ``config_list``, and the resulting pruner is returned.

    Raises:
        KeyError: If ``pruner_name`` is not a key of ``prune_config``.
    """
    entry = prune_config[pruner_name]
    make_pruner, layer_configs = entry['pruner_class'], entry['config_list']
    return make_pruner(model, layer_configs)


def _make_sgd(learning_rate):
    """Return momentum-0.9 SGD whose rate decays as ``learning_rate / (1 + 1e-4 * step)``.

    This reproduces the legacy ``SGD(decay=1e-4)`` behaviour via an
    ``InverseTimeDecay`` schedule. The ``decay`` keyword is rejected by the
    optimizers shipped with TF >= 2.11 / Keras 3.
    """
    schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
        initial_learning_rate=learning_rate, decay_steps=1, decay_rate=1e-4)
    return tf.keras.optimizers.SGD(learning_rate=schedule, momentum=0.9)


def main(args):
    """Train a model, prune it with the configured NNI pruner, then fine-tune it.

    Args:
        args: Parsed command-line namespace providing ``pruner_name``,
            ``batch_size``, ``pretrain_epochs`` and ``prune_epochs``.
    """
    experiment = prune_config[args.pruner_name]
    train_set, test_set = get_dataset(experiment['dataset_name'])
    model = create_model(experiment['model_name'])
    loss = 'sparse_categorical_crossentropy'

    model.compile(optimizer=_make_sgd(0.1), loss=loss, metrics=['accuracy'])

    print('start training')
    model.fit(train_set[0], train_set[1], batch_size=args.batch_size,
              epochs=args.pretrain_epochs, validation_data=test_set)

    print('start model pruning')
    # Fine-tuning uses a fresh optimizer (its decay schedule restarts at step 0)
    # with a learning rate 100x smaller than pretraining.
    optimizer_finetune = _make_sgd(0.001)
    pruner = create_pruner(model, args.pruner_name)
    model = pruner.compress()
    model.compile(optimizer=optimizer_finetune, loss=loss, metrics=['accuracy'])
    model.fit(train_set[0], train_set[1], batch_size=args.batch_size,
              epochs=args.prune_epochs, validation_data=test_set)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Train a Keras model, prune it with NNI, then fine-tune it.')
    # ``choices`` turns an unknown pruner name into a clean usage error instead
    # of a KeyError traceback from the prune_config lookup inside main().
    parser.add_argument('--pruner_name', type=str, default='level',
                        choices=sorted(prune_config),
                        help='key of the pruning experiment in prune_config')
    parser.add_argument('--batch_size', type=int, default=256,
                        help='mini-batch size for both training and fine-tuning')
    parser.add_argument('--pretrain_epochs', type=int, default=10,
                        help='number of training epochs before pruning')
    parser.add_argument('--prune_epochs', type=int, default=10,
                        help='number of fine-tuning epochs after pruning')

    args = parser.parse_args()
    main(args)