'''
mnist.py is an example that shows how to use a nested search space to tune the network architecture for MNIST.
'''
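# A parameter configuration delivered by the NNI tuner is expected to look
# roughly like the following (a hypothetical example; the exact choices come
# from this example's search space definition):
#
#     {
#         "layer0": {"_name": "Conv", "kernel_size": 3},
#         "layer1": {"_name": "Max_pool", "pooling_size": 2},
#         "layer2": {"_name": "Empty"},
#         "layer3": {"_name": "Avg_pool", "pooling_size": 5}
#     }
#
# parse_init_json() below flattens each entry into a [name, height, width] list.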
from __future__ import absolute_import, division, print_function

import logging
import math
import tempfile
import time
import argparse

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

import nni

logger = logging.getLogger('mnist_nested_search_space')
FLAGS = None

class MnistNetwork(object):
    def __init__(self, params, feature_size=784):
        config = []

        for i in range(4):
            config.append(params['layer'+str(i)])
        self.config = config
        self.feature_size = feature_size
        self.label_size = 10


    def is_expand_dim(self, input):
        # True if the tensor is not yet 4-D (NHWC) and must be reshaped for conv/pool ops
        return len(input.get_shape().as_list()) < 4


    def is_flatten(self, input):
        # True if the tensor is more than 2-D and must be flattened before the dense layer
        return len(input.get_shape().as_list()) > 2


    def get_layer(self, layer_config, input, in_height, in_width, id):
        '''
        Build one layer from its config, a [name, height, width] list
        produced by parse_init_json().
        '''
        if layer_config[0] == 'Empty':
            return input

        if self.is_expand_dim(input):
            input = tf.reshape(input, [-1, in_height, in_width, 1])
        h, w = layer_config[1], layer_config[2]

        if layer_config[0] == 'Conv':
            conv_filter = tf.Variable(tf.random_uniform([h, w, 1, 1]), name='id_%d_conv_%d_%d' % (id, h, w))
            return tf.nn.conv2d(input, filter=conv_filter, strides=[1, 1, 1, 1], padding='SAME')
        if layer_config[0] == 'Max_pool':
            return tf.nn.max_pool(input, ksize=[1, h, w, 1], strides=[1, 1, 1, 1], padding='SAME')
        if layer_config[0] == 'Avg_pool':
            return tf.nn.avg_pool(input, ksize=[1, h, w, 1], strides=[1, 1, 1, 1], padding='SAME')

        raise ValueError('illegal layer type: %s' % layer_config[0])


    def build_network(self):
        '''
        Assemble the graph: placeholders, the tuner-chosen layers, a dense
        output layer, the loss, the train step and the accuracy metric.
        '''
        layer_configs = self.config
        feature_size = self.feature_size

        # define placeholders; labels arrive as one-hot float vectors
        self.x = tf.placeholder(tf.float32, [None, feature_size], name="input_x")
        self.y = tf.placeholder(tf.float32, [None, self.label_size], name="input_y")

        # define network
        input_layer = self.x
        in_height = in_width = int(math.sqrt(feature_size))
        for i, layer_config in enumerate(layer_configs):
            input_layer = tf.nn.relu(self.get_layer(layer_config, input_layer, in_height, in_width, i))

        output_layer = input_layer
        if self.is_flatten(output_layer):
            output_layer = tf.contrib.layers.flatten(output_layer)
        output_layer = tf.layers.dense(output_layer, self.label_size)
        # softmax_cross_entropy_with_logits requires float labels, hence the float32 input_y
        cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=output_layer, labels=self.y)
        child_loss = tf.reduce_mean(cross_entropy)

        self.train_step = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(child_loss)
        correct_prediction = tf.equal(tf.argmax(output_layer, 1), tf.argmax(self.y, 1))
        self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
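
# The graph built by MnistNetwork.build_network, end to end:
#   input_x [batch, 784] -> reshaped to [batch, 28, 28, 1] whenever a layer needs NHWC
#   -> up to 4 tuner-chosen layers (Conv / Max_pool / Avg_pool / Empty), each followed by ReLU
#   -> flatten -> dense(label_size) -> softmax cross entropy, minimized with Adam (lr=0.0001)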

def download_mnist_retry(data_dir, max_num_retries=20):
    """Try to download mnist dataset and avoid errors"""
    for _ in range(max_num_retries):
        try:
            return input_data.read_data_sets(data_dir, one_hot=True)
        except tf.errors.AlreadyExistsError:
            time.sleep(1)
    raise Exception("Failed to download MNIST.")

def main(params):
    # Import data
    mnist = download_mnist_retry(params['data_dir'])

    # Build the graph for the chosen architecture
    mnist_network = MnistNetwork(params)
    mnist_network.build_network()
    print('build network done.')

    # Write log
    graph_location = tempfile.mkdtemp()
    # print('Saving graph to: %s' % graph_location)
    train_writer = tf.summary.FileWriter(graph_location)
    train_writer.add_graph(tf.get_default_graph())

    test_acc = 0.0
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for i in range(params['batch_num']):
            batch = mnist.train.next_batch(params['batch_size'])
            mnist_network.train_step.run(feed_dict={mnist_network.x: batch[0], mnist_network.y: batch[1]})

            if i % 100 == 0:
                train_accuracy = mnist_network.accuracy.eval(feed_dict={
                    mnist_network.x: batch[0], mnist_network.y: batch[1]})
                print('step %d, training accuracy %g' % (i, train_accuracy))

        test_acc = mnist_network.accuracy.eval(feed_dict={
            mnist_network.x: mnist.test.images, mnist_network.y: mnist.test.labels})

        nni.report_final_result(test_acc)

def get_params():
    ''' Get parameters from command line '''
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, default='/tmp/tensorflow/mnist/input_data', help="data directory")
    parser.add_argument("--batch_num", type=int, default=1000)
    parser.add_argument("--batch_size", type=int, default=200)
    args, _ = parser.parse_known_args()
    return args
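
# Example standalone invocation (values shown are just the defaults above):
#   python mnist.py --data_dir /tmp/tensorflow/mnist/input_data --batch_num 1000 --batch_size 200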

def parse_init_json(data):
    ''' Flatten each nested layer config from the tuner into a [name, height, width] list '''
    params = {}
    for key, value in data.items():
        layer_name = value["_name"]
        if layer_name == 'Empty':
            # Empty Layer
            params[key] = ['Empty']
        elif layer_name == 'Conv':
            # Conv layer
            params[key] = [layer_name, value['kernel_size'], value['kernel_size']]
        else:
            # Pooling Layer
            params[key] = [layer_name, value['pooling_size'], value['pooling_size']]
    return params
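
# For the hypothetical configuration shown at the top of this file,
# parse_init_json would return:
#   {'layer0': ['Conv', 3, 3], 'layer1': ['Max_pool', 2, 2],
#    'layer2': ['Empty'], 'layer3': ['Avg_pool', 5, 5]}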


if __name__ == '__main__':
    try:
        # get parameters from the tuner
        data = nni.get_next_parameter()
        logger.debug(data)

        RCV_PARAMS = parse_init_json(data)
        logger.debug(RCV_PARAMS)
        params = vars(get_params())
        params.update(RCV_PARAMS)
        print(RCV_PARAMS)

        main(params)
    except Exception as exception:
        logger.exception(exception)
        raise