search.py 1.03 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.


from tensorflow.keras.losses import Reduction, SparseCategoricalCrossentropy
from tensorflow.keras.optimizers import SGD

from nni.nas.tensorflow import enas

import datasets
from macro import GeneralNetwork
from micro import MicroNetwork
from utils import accuracy, accuracy_metrics


# TODO: replace the hard-coded settings below with argparse command-line options.


# Fetch the training and validation datasets from the `datasets` module.
dataset_train, dataset_valid = datasets.get_dataset()

# Network handed to the trainer. `GeneralNetwork` (imported from `macro`) is
# the alternative: assign `GeneralNetwork()` here instead to use it.
model = MicroNetwork()

# from_logits=True: the loss receives raw logits rather than probabilities.
# Reduction.NONE: Keras applies no reduction to the loss values; presumably
# the trainer reduces them itself -- TODO confirm against EnasTrainer.
loss = SparseCategoricalCrossentropy(reduction=Reduction.NONE, from_logits=True)
optimizer = SGD(momentum=0.9, learning_rate=0.05)

# Collect the trainer's keyword settings in one place, then run the training.
trainer_options = dict(
    loss=loss,
    metrics=accuracy_metrics,
    reward_function=accuracy,
    optimizer=optimizer,
    batch_size=64,
    num_epochs=310,
    dataset_train=dataset_train,
    dataset_valid=dataset_valid,
)
trainer = enas.EnasTrainer(model, **trainer_options)
trainer.train()