Unverified Commit 4e2d8cd8 authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

Fix TF NAS naive example (#2948)

parent cd23bc41
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import tensorflow as tf import tensorflow as tf
from tensorflow.keras import Model from tensorflow.keras import Model
from tensorflow.keras.layers import (AveragePooling2D, BatchNormalization, Conv2D, Dense, MaxPool2D) from tensorflow.keras.layers import (AveragePooling2D, BatchNormalization, Conv2D, Dense, MaxPool2D)
...@@ -7,8 +10,6 @@ from tensorflow.keras.optimizers import SGD ...@@ -7,8 +10,6 @@ from tensorflow.keras.optimizers import SGD
from nni.nas.tensorflow.mutables import LayerChoice, InputChoice from nni.nas.tensorflow.mutables import LayerChoice, InputChoice
from nni.nas.tensorflow.enas import EnasTrainer from nni.nas.tensorflow.enas import EnasTrainer
tf.get_logger().setLevel('ERROR')
class Net(Model): class Net(Model):
def __init__(self): def __init__(self):
...@@ -53,35 +54,36 @@ class Net(Model): ...@@ -53,35 +54,36 @@ class Net(Model):
return x return x
def accuracy(output, target): def accuracy(truth, logits):
bs = target.shape[0] truth = tf.reshape(truth, -1)
predicted = tf.cast(tf.argmax(output, 1), target.dtype) predicted = tf.cast(tf.math.argmax(logits, axis=1), truth.dtype)
target = tf.reshape(target, [-1]) equal = tf.cast(predicted == truth, tf.int32)
return sum(tf.cast(predicted == target, tf.float32)) / bs return tf.math.reduce_sum(equal).numpy() / equal.shape[0]
def accuracy_metrics(truth, logits):
acc = accuracy(truth, logits)
return {'accuracy': acc}
if __name__ == '__main__': if __name__ == '__main__':
cifar10 = tf.keras.datasets.cifar10 cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data() (x_train, y_train), (x_valid, y_valid) = cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0 x_train, x_valid = x_train / 255.0, x_valid / 255.0
split = int(len(x_train) * 0.9) train_set = (x_train, y_train)
dataset_train = tf.data.Dataset.from_tensor_slices((x_train[:split], y_train[:split])).batch(64) valid_set = (x_valid, y_valid)
dataset_valid = tf.data.Dataset.from_tensor_slices((x_train[split:], y_train[split:])).batch(64)
dataset_test = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(64)
net = Net() net = Net()
trainer = EnasTrainer( trainer = EnasTrainer(
net, net,
loss=SparseCategoricalCrossentropy(reduction=Reduction.SUM), loss=SparseCategoricalCrossentropy(from_logits=True, reduction=Reduction.NONE),
metrics=accuracy, metrics=accuracy_metrics,
reward_function=accuracy, reward_function=accuracy,
optimizer=SGD(learning_rate=0.001, momentum=0.9), optimizer=SGD(learning_rate=0.001, momentum=0.9),
batch_size=64, batch_size=64,
num_epochs=2, num_epochs=2,
dataset_train=dataset_train, dataset_train=train_set,
dataset_valid=dataset_valid, dataset_valid=valid_set
dataset_test=dataset_test
) )
trainer.train() trainer.train()
#trainer.export('checkpoint')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment