Unverified Commit 4e2d8cd8 authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

Fix TF NAS naive example (#2948)

parent cd23bc41
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import (AveragePooling2D, BatchNormalization, Conv2D, Dense, MaxPool2D)
......@@ -7,8 +10,6 @@ from tensorflow.keras.optimizers import SGD
from nni.nas.tensorflow.mutables import LayerChoice, InputChoice
from nni.nas.tensorflow.enas import EnasTrainer
tf.get_logger().setLevel('ERROR')
class Net(Model):
def __init__(self):
......@@ -53,35 +54,36 @@ class Net(Model):
return x
def accuracy(output, target):
bs = target.shape[0]
predicted = tf.cast(tf.argmax(output, 1), target.dtype)
target = tf.reshape(target, [-1])
return sum(tf.cast(predicted == target, tf.float32)) / bs
def accuracy(truth, logits):
truth = tf.reshape(truth, -1)
predicted = tf.cast(tf.math.argmax(logits, axis=1), truth.dtype)
equal = tf.cast(predicted == truth, tf.int32)
return tf.math.reduce_sum(equal).numpy() / equal.shape[0]
def accuracy_metrics(truth, logits):
acc = accuracy(truth, logits)
return {'accuracy': acc}
if __name__ == '__main__':
cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
split = int(len(x_train) * 0.9)
dataset_train = tf.data.Dataset.from_tensor_slices((x_train[:split], y_train[:split])).batch(64)
dataset_valid = tf.data.Dataset.from_tensor_slices((x_train[split:], y_train[split:])).batch(64)
dataset_test = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(64)
(x_train, y_train), (x_valid, y_valid) = cifar10.load_data()
x_train, x_valid = x_train / 255.0, x_valid / 255.0
train_set = (x_train, y_train)
valid_set = (x_valid, y_valid)
net = Net()
trainer = EnasTrainer(
net,
loss=SparseCategoricalCrossentropy(reduction=Reduction.SUM),
metrics=accuracy,
loss=SparseCategoricalCrossentropy(from_logits=True, reduction=Reduction.NONE),
metrics=accuracy_metrics,
reward_function=accuracy,
optimizer=SGD(learning_rate=0.001, momentum=0.9),
batch_size=64,
num_epochs=2,
dataset_train=dataset_train,
dataset_valid=dataset_valid,
dataset_test=dataset_test
dataset_train=train_set,
dataset_valid=valid_set
)
trainer.train()
#trainer.export('checkpoint')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment