Unverified Commit 3aae9d06 authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

Fix tensorflow import (#2481)

parent 241b364b
......@@ -2,7 +2,6 @@
# Licensed under the MIT license.
import tensorflow as tf
from tensorflow.data import Dataset
def get_dataset():
(x_train, y_train), (x_valid, y_valid) = tf.keras.datasets.cifar10.load_data()
......
......@@ -4,7 +4,6 @@
import logging
import tensorflow as tf
from tensorflow.data import Dataset
from tensorflow.keras.optimizers import Adam
from nni.nas.tensorflow.utils import AverageMeterGroup, fill_zero_grads
......@@ -39,9 +38,9 @@ class EnasTrainer:
x, y = dataset_train
split = int(len(x) * 0.9)
self.train_set = Dataset.from_tensor_slices((x[:split], y[:split]))
self.valid_set = Dataset.from_tensor_slices((x[split:], y[split:]))
self.test_set = Dataset.from_tensor_slices(dataset_valid)
self.train_set = tf.data.Dataset.from_tensor_slices((x[:split], y[:split]))
self.valid_set = tf.data.Dataset.from_tensor_slices((x[split:], y[split:]))
self.test_set = tf.data.Dataset.from_tensor_slices(dataset_valid)
self.mutator = EnasMutator(model)
self.mutator_optim = Adam(learning_rate=mutator_lr)
......@@ -151,9 +150,9 @@ class EnasTrainer:
def _create_train_loader(self):
train_set = self.train_set.shuffle(1000000).batch(self.batch_size)
test_set = self.test_set.shuffle(1000000).batch(self.batch_size)
train_set = self.train_set.shuffle(1000000).repeat().batch(self.batch_size)
test_set = self.test_set.shuffle(1000000).repeat().batch(self.batch_size)
return iter(train_set), iter(test_set)
def _create_validate_loader(self):
return iter(self.test_set.shuffle(1000000).batch(self.batch_size))
return iter(self.test_set.shuffle(1000000).repeat().batch(self.batch_size))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment