Unverified Commit 3aae9d06 authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

Fix tensorflow import (#2481)

parent 241b364b
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
# Licensed under the MIT license. # Licensed under the MIT license.
import tensorflow as tf import tensorflow as tf
from tensorflow.data import Dataset
def get_dataset(): def get_dataset():
(x_train, y_train), (x_valid, y_valid) = tf.keras.datasets.cifar10.load_data() (x_train, y_train), (x_valid, y_valid) = tf.keras.datasets.cifar10.load_data()
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
import logging import logging
import tensorflow as tf import tensorflow as tf
from tensorflow.data import Dataset
from tensorflow.keras.optimizers import Adam from tensorflow.keras.optimizers import Adam
from nni.nas.tensorflow.utils import AverageMeterGroup, fill_zero_grads from nni.nas.tensorflow.utils import AverageMeterGroup, fill_zero_grads
...@@ -39,9 +38,9 @@ class EnasTrainer: ...@@ -39,9 +38,9 @@ class EnasTrainer:
x, y = dataset_train x, y = dataset_train
split = int(len(x) * 0.9) split = int(len(x) * 0.9)
self.train_set = Dataset.from_tensor_slices((x[:split], y[:split])) self.train_set = tf.data.Dataset.from_tensor_slices((x[:split], y[:split]))
self.valid_set = Dataset.from_tensor_slices((x[split:], y[split:])) self.valid_set = tf.data.Dataset.from_tensor_slices((x[split:], y[split:]))
self.test_set = Dataset.from_tensor_slices(dataset_valid) self.test_set = tf.data.Dataset.from_tensor_slices(dataset_valid)
self.mutator = EnasMutator(model) self.mutator = EnasMutator(model)
self.mutator_optim = Adam(learning_rate=mutator_lr) self.mutator_optim = Adam(learning_rate=mutator_lr)
...@@ -151,9 +150,9 @@ class EnasTrainer: ...@@ -151,9 +150,9 @@ class EnasTrainer:
def _create_train_loader(self): def _create_train_loader(self):
train_set = self.train_set.shuffle(1000000).batch(self.batch_size) train_set = self.train_set.shuffle(1000000).repeat().batch(self.batch_size)
test_set = self.test_set.shuffle(1000000).batch(self.batch_size) test_set = self.test_set.shuffle(1000000).repeat().batch(self.batch_size)
return iter(train_set), iter(test_set) return iter(train_set), iter(test_set)
def _create_validate_loader(self): def _create_validate_loader(self):
return iter(self.test_set.shuffle(1000000).batch(self.batch_size)) return iter(self.test_set.shuffle(1000000).repeat().batch(self.batch_size))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment