train.py 2.59 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
3
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

liuzhe-lz's avatar
liuzhe-lz committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import (AveragePooling2D, BatchNormalization, Conv2D, Dense, MaxPool2D)
from tensorflow.keras.losses import Reduction, SparseCategoricalCrossentropy
from tensorflow.keras.optimizers import SGD

from nni.nas.tensorflow.mutables import LayerChoice, InputChoice
from nni.nas.tensorflow.enas import EnasTrainer


class Net(Model):
    """Small convolutional search space for 10-way image classification.

    Searchable parts:
      * ``conv1`` / ``conv2`` are ``LayerChoice``s between a 3x3 and a 5x5
        ``same``-padded ReLU convolution (6 and 16 filters respectively).
      * ``skipconnect`` is an ``InputChoice`` over a single candidate (the
        output of ``conv2``); ``call`` adds it onto the output of the 1x1
        ``conv3`` only when the choice returns something other than ``None``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = LayerChoice([
            Conv2D(6, 3, padding='same', activation='relu'),
            Conv2D(6, 5, padding='same', activation='relu'),
        ])
        # MaxPool2D has no weights, so one instance is shared by both pooling steps.
        self.pool = MaxPool2D(2)
        self.conv2 = LayerChoice([
            Conv2D(16, 3, padding='same', activation='relu'),
            Conv2D(16, 5, padding='same', activation='relu'),
        ])
        self.conv3 = Conv2D(16, 1)

        # One candidate: the search decides whether the conv2 output is used as a skip.
        self.skipconnect = InputChoice(n_candidates=1)
        self.bn = BatchNormalization()

        self.gap = AveragePooling2D(2)
        self.fc1 = Dense(120, activation='relu')
        self.fc2 = Dense(84, activation='relu')
        self.fc3 = Dense(10)

    def call(self, x):
        """Run the forward pass and return unnormalised logits of width 10.

        Args:
            x: batch of images; assumed NHWC (Keras default) — TODO confirm.

        Returns:
            Output of ``fc3`` (no activation), one row per input sample.
        """
        # Dynamic batch size: the static ``x.shape[0]`` is ``None`` when the
        # model is traced symbolically (``tf.function``, ``Model.build``), and
        # ``tf.reshape`` cannot take ``None`` in its target shape.
        bs = tf.shape(x)[0]

        t = self.conv1(x)
        x = self.pool(t)
        x0 = self.conv2(x)
        x1 = self.conv3(x0)

        # The choice may yield None, meaning "no skip connection" for this sample.
        x0 = self.skipconnect([x0])
        if x0 is not None:
            x1 += x0
        x = self.pool(self.bn(x1))

        x = self.gap(x)
        # Flatten every non-batch dimension before the dense head.
        x = tf.reshape(x, [bs, -1])
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x


liuzhe-lz's avatar
liuzhe-lz committed
57
58
59
60
61
62
63
64
65
def accuracy(truth, logits):
    """Return the fraction of samples whose arg-max logit equals the label.

    Args:
        truth: integer class labels; reshaped to 1-D before comparison.
        logits: per-class scores; the predicted class is the arg-max along axis 1.

    Returns:
        Count of correct predictions divided by the number of samples. The
        count is pulled out with ``.numpy()``, so this must run eagerly.
    """
    labels = tf.reshape(truth, -1)
    guesses = tf.math.argmax(logits, axis=1)
    # Match the label dtype so the element-wise comparison is well-typed.
    guesses = tf.cast(guesses, labels.dtype)
    hits = tf.cast(guesses == labels, tf.int32)
    n_correct = tf.math.reduce_sum(hits).numpy()
    n_total = hits.shape[0]
    return n_correct / n_total

def accuracy_metrics(truth, logits):
    """Wrap ``accuracy`` in the ``{'accuracy': value}`` dict passed to the trainer as ``metrics``."""
    metrics = {}
    metrics['accuracy'] = accuracy(truth, logits)
    return metrics
liuzhe-lz's avatar
liuzhe-lz committed
66
67
68
69


if __name__ == '__main__':
    # Load CIFAR-10, scale pixels to [0, 1], and run an ENAS search over ``Net``.
    cifar10 = tf.keras.datasets.cifar10
    (x_train, y_train), (x_valid, y_valid) = cifar10.load_data()
    # Convert to float32 before scaling: dividing an integer NumPy array by
    # 255.0 directly yields float64, doubling memory for no benefit since the
    # Keras layers compute in float32 by default.
    x_train = x_train.astype('float32') / 255.0
    x_valid = x_valid.astype('float32') / 255.0
    train_set = (x_train, y_train)
    valid_set = (x_valid, y_valid)

    net = Net()

    trainer = EnasTrainer(
        net,
        # Reduction.NONE yields per-sample losses; aggregation is presumably
        # left to the trainer — verify against EnasTrainer.
        loss=SparseCategoricalCrossentropy(from_logits=True, reduction=Reduction.NONE),
        metrics=accuracy_metrics,
        reward_function=accuracy,
        optimizer=SGD(learning_rate=0.001, momentum=0.9),
        batch_size=64,
        num_epochs=2,
        dataset_train=train_set,
        dataset_valid=valid_set,
    )

    trainer.train()