"git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "cae1bfe737cbe2365cb9d8788093eaa758dcfb20"
Commit 2b4a7d1f authored by Lee's avatar Lee Committed by chicm-ms
Browse files

Fix tensorflow MNIST dataset download bug in multiple trials (#815)

* Fix tensorflow MNIST dataset download bug in multiple trials

* Modify the maxTrialNum and trialConcurrency in config
parent 2fa77bcc
...@@ -4,8 +4,9 @@ import argparse ...@@ -4,8 +4,9 @@ import argparse
import logging import logging
import math import math
import tempfile import tempfile
import tensorflow as tf import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data from tensorflow.examples.tutorials.mnist import input_data
FLAGS = None FLAGS = None
...@@ -152,12 +153,21 @@ def bias_variable(shape): ...@@ -152,12 +153,21 @@ def bias_variable(shape):
return tf.Variable(initial) return tf.Variable(initial)
def download_mnist_retry(data_dir, max_num_retries=20):
"""Try to download mnist dataset and avoid errors"""
for _ in range(max_num_retries):
try:
return input_data.read_data_sets(data_dir, one_hot=True)
except tf.errors.AlreadyExistsError:
time.sleep(1)
raise Exception("Failed to download MNIST.")
def main(params): def main(params):
''' '''
Main function, build mnist network, run and send result to NNI. Main function, build mnist network, run and send result to NNI.
''' '''
# Import data # Import data
mnist = input_data.read_data_sets(params['data_dir'], one_hot=True) mnist = download_mnist_retry(params['data_dir'])
print('Mnist download data done.') print('Mnist download data done.')
logger.debug('Mnist download data done.') logger.debug('Mnist download data done.')
......
''' '''
mnist.py is an example to show: how to use iterative search space to tune architecture network for mnist. mnist.py is an example to show: how to use iterative search space to tune architecture network for mnist.
''' '''
from __future__ import absolute_import from __future__ import absolute_import, division, print_function
from __future__ import division
from __future__ import print_function
import argparse
import codecs
import json
import logging import logging
import math import math
import sys
import tempfile import tempfile
import tensorflow as tf import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data from tensorflow.examples.tutorials.mnist import input_data
import nni import nni
logger = logging.getLogger('mnist_cascading_search_space') logger = logging.getLogger('mnist_cascading_search_space')
FLAGS = None FLAGS = None
...@@ -95,9 +89,18 @@ class MnistNetwork(object): ...@@ -95,9 +89,18 @@ class MnistNetwork(object):
child_accuracy = tf.equal(tf.argmax(output_layer, 1), tf.argmax(self.y, 1)) child_accuracy = tf.equal(tf.argmax(output_layer, 1), tf.argmax(self.y, 1))
self.accuracy = tf.reduce_mean(tf.cast(child_accuracy, "float")) # add a reduce_mean self.accuracy = tf.reduce_mean(tf.cast(child_accuracy, "float")) # add a reduce_mean
def download_mnist_retry(data_dir, max_num_retries=20):
"""Try to download mnist dataset and avoid errors"""
for _ in range(max_num_retries):
try:
return input_data.read_data_sets(data_dir, one_hot=True)
except tf.errors.AlreadyExistsError:
time.sleep(1)
raise Exception("Failed to download MNIST.")
def main(params): def main(params):
# Import data # Import data
mnist = input_data.read_data_sets(params['data_dir'], one_hot=True) mnist = download_mnist_retry(params['data_dir'])
# Create the model # Create the model
# Build the graph for the deep net # Build the graph for the deep net
......
...@@ -54,16 +54,19 @@ import nni ...@@ -54,16 +54,19 @@ import nni
flags = tf.app.flags flags = tf.app.flags
flags.DEFINE_string("data_dir", "/tmp/mnist-data", flags.DEFINE_string("data_dir", "/tmp/mnist-data",
"Directory for storing mnist data") "Directory for storing mnist data")
flags.DEFINE_boolean("download_only", False, flags.DEFINE_boolean(
"download_only", False,
"Only perform downloading of data; Do not proceed to " "Only perform downloading of data; Do not proceed to "
"session preparation, model definition or training") "session preparation, model definition or training")
flags.DEFINE_integer("task_index", None, flags.DEFINE_integer(
"Worker task index, should be >= 0. task_index=0 is " "task_index", None, "Worker task index, should be >= 0. task_index=0 is "
"the master worker task the performs the variable " "the master worker task the performs the variable "
"initialization ") "initialization ")
flags.DEFINE_integer("num_gpus", 1, "Total number of gpus for each machine." flags.DEFINE_integer(
"num_gpus", 1, "Total number of gpus for each machine."
"If you don't use GPU, please set it to '0'") "If you don't use GPU, please set it to '0'")
flags.DEFINE_integer("replicas_to_aggregate", None, flags.DEFINE_integer(
"replicas_to_aggregate", None,
"Number of replicas to aggregate before parameter update" "Number of replicas to aggregate before parameter update"
"is applied (For sync_replicas mode only; default: " "is applied (For sync_replicas mode only; default: "
"num_workers)") "num_workers)")
...@@ -96,6 +99,7 @@ IMAGE_PIXELS = 28 ...@@ -96,6 +99,7 @@ IMAGE_PIXELS = 28
# {'cluster': cluster, # {'cluster': cluster,
# 'task': {'type': 'worker', 'index': 1}}) # 'task': {'type': 'worker', 'index': 1}})
def generate_default_params(): def generate_default_params():
''' '''
Generate default hyper parameters Generate default hyper parameters
...@@ -106,6 +110,15 @@ def generate_default_params(): ...@@ -106,6 +110,15 @@ def generate_default_params():
'hidden_units': 100, 'hidden_units': 100,
} }
def download_mnist_retry(data_dir, max_num_retries=20):
"""Try to download mnist dataset and avoid errors"""
for _ in range(max_num_retries):
try:
return input_data.read_data_sets(data_dir, one_hot=True)
except tf.errors.AlreadyExistsError:
time.sleep(1)
raise Exception("Failed to download MNIST.")
def main(unused_argv): def main(unused_argv):
# Receive NNI hyper parameter and update it onto default params # Receive NNI hyper parameter and update it onto default params
RECEIVED_PARAMS = nni.get_next_parameter() RECEIVED_PARAMS = nni.get_next_parameter()
...@@ -124,7 +137,7 @@ def main(unused_argv): ...@@ -124,7 +137,7 @@ def main(unused_argv):
FLAGS.job_name = task_type FLAGS.job_name = task_type
FLAGS.task_index = task_index FLAGS.task_index = task_index
mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True) mnist = download_mnist_retry(FLAGS.data_dir)
if FLAGS.download_only: if FLAGS.download_only:
sys.exit(0) sys.exit(0)
...@@ -206,7 +219,8 @@ def main(unused_argv): ...@@ -206,7 +219,8 @@ def main(unused_argv):
hid = tf.nn.relu(hid_lin) hid = tf.nn.relu(hid_lin)
y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b)) y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
cross_entropy = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0))) cross_entropy = -tf.reduce_sum(
y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
opt = tf.train.AdamOptimizer(PARAMS['learning_rate']) opt = tf.train.AdamOptimizer(PARAMS['learning_rate'])
...@@ -258,8 +272,9 @@ def main(unused_argv): ...@@ -258,8 +272,9 @@ def main(unused_argv):
sess_config = tf.ConfigProto( sess_config = tf.ConfigProto(
allow_soft_placement=True, allow_soft_placement=True,
log_device_placement=False, log_device_placement=False,
device_filters=["/job:ps", device_filters=[
"/job:worker/task:%d" % FLAGS.task_index]) "/job:ps", "/job:worker/task:%d" % FLAGS.task_index
])
# The chief worker (task_index==0) session will prepare the session, # The chief worker (task_index==0) session will prepare the session,
# while the remaining workers will wait for the preparation to complete. # while the remaining workers will wait for the preparation to complete.
...@@ -273,9 +288,11 @@ def main(unused_argv): ...@@ -273,9 +288,11 @@ def main(unused_argv):
server_grpc_url = "grpc://" + worker_spec[FLAGS.task_index] server_grpc_url = "grpc://" + worker_spec[FLAGS.task_index]
print("Using existing server at: %s" % server_grpc_url) print("Using existing server at: %s" % server_grpc_url)
sess = sv.prepare_or_wait_for_session(server_grpc_url, config=sess_config) sess = sv.prepare_or_wait_for_session(
server_grpc_url, config=sess_config)
else: else:
sess = sv.prepare_or_wait_for_session(server.target, config=sess_config) sess = sv.prepare_or_wait_for_session(
server.target, config=sess_config)
print("Worker %d: Session initialization complete." % FLAGS.task_index) print("Worker %d: Session initialization complete." % FLAGS.task_index)
...@@ -302,9 +319,14 @@ def main(unused_argv): ...@@ -302,9 +319,14 @@ def main(unused_argv):
(now, FLAGS.task_index, local_step, step)) (now, FLAGS.task_index, local_step, step))
if step > 0 and step % 5000 == 0 and is_chief: if step > 0 and step % 5000 == 0 and is_chief:
val_feed = {x: mnist.validation.images, y_: mnist.validation.labels} val_feed = {
x: mnist.validation.images,
y_: mnist.validation.labels
}
interim_val_xent = sess.run(cross_entropy, feed_dict=val_feed) interim_val_xent = sess.run(cross_entropy, feed_dict=val_feed)
print("After %d training step(s), validation cross entropy = %g" % (step, interim_val_xent)) print(
"After %d training step(s), validation cross entropy = %g"
% (step, interim_val_xent))
# Only chief worker can report intermediate metrics # Only chief worker can report intermediate metrics
nni.report_intermediate_result(interim_val_xent) nni.report_intermediate_result(interim_val_xent)
......
...@@ -3,8 +3,9 @@ ...@@ -3,8 +3,9 @@
import logging import logging
import math import math
import tempfile import tempfile
import tensorflow as tf import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data from tensorflow.examples.tutorials.mnist import input_data
import nni import nni
...@@ -142,13 +143,21 @@ def bias_variable(shape): ...@@ -142,13 +143,21 @@ def bias_variable(shape):
initial = tf.constant(0.1, shape=shape) initial = tf.constant(0.1, shape=shape)
return tf.Variable(initial) return tf.Variable(initial)
def download_mnist_retry(data_dir, max_num_retries=20):
"""Try to download mnist dataset and avoid errors"""
for _ in range(max_num_retries):
try:
return input_data.read_data_sets(data_dir, one_hot=True)
except tf.errors.AlreadyExistsError:
time.sleep(1)
raise Exception("Failed to download MNIST.")
def main(params): def main(params):
''' '''
Main function, build mnist network, run and send result to NNI. Main function, build mnist network, run and send result to NNI.
''' '''
# Import data # Import data
mnist = input_data.read_data_sets(params['data_dir'], one_hot=True) mnist = download_mnist_retry(params['data_dir'])
print('Mnist download data done.') print('Mnist download data done.')
logger.debug('Mnist download data done.') logger.debug('Mnist download data done.')
......
...@@ -4,8 +4,9 @@ import argparse ...@@ -4,8 +4,9 @@ import argparse
import logging import logging
import math import math
import tempfile import tempfile
import tensorflow as tf import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data from tensorflow.examples.tutorials.mnist import input_data
import nni import nni
...@@ -143,13 +144,21 @@ def bias_variable(shape): ...@@ -143,13 +144,21 @@ def bias_variable(shape):
initial = tf.constant(0.1, shape=shape) initial = tf.constant(0.1, shape=shape)
return tf.Variable(initial) return tf.Variable(initial)
def download_mnist_retry(data_dir, max_num_retries=20):
"""Try to download mnist dataset and avoid errors"""
for _ in range(max_num_retries):
try:
return input_data.read_data_sets(data_dir, one_hot=True)
except tf.errors.AlreadyExistsError:
time.sleep(1)
raise Exception("Failed to download MNIST.")
def main(params): def main(params):
''' '''
Main function, build mnist network, run and send result to NNI. Main function, build mnist network, run and send result to NNI.
''' '''
# Import data # Import data
mnist = input_data.read_data_sets(params['data_dir'], one_hot=True) mnist = download_mnist_retry(params['data_dir'])
print('Mnist download data done.') print('Mnist download data done.')
logger.debug('Mnist download data done.') logger.debug('Mnist download data done.')
......
...@@ -3,8 +3,9 @@ import argparse ...@@ -3,8 +3,9 @@ import argparse
import logging import logging
import math import math
import tempfile import tempfile
import tensorflow as tf import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data from tensorflow.examples.tutorials.mnist import input_data
FLAGS = None FLAGS = None
...@@ -143,13 +144,21 @@ def bias_variable(shape): ...@@ -143,13 +144,21 @@ def bias_variable(shape):
initial = tf.constant(0.1, shape=shape) initial = tf.constant(0.1, shape=shape)
return tf.Variable(initial) return tf.Variable(initial)
def download_mnist_retry(data_dir, max_num_retries=20):
"""Try to download mnist dataset and avoid errors"""
for _ in range(max_num_retries):
try:
return input_data.read_data_sets(data_dir, one_hot=True)
except tf.errors.AlreadyExistsError:
time.sleep(1)
raise Exception("Failed to download MNIST.")
def main(params): def main(params):
''' '''
Main function, build mnist network, run and send result to NNI. Main function, build mnist network, run and send result to NNI.
''' '''
# Import data # Import data
mnist = input_data.read_data_sets(params['data_dir'], one_hot=True) mnist = download_mnist_retry(params['data_dir'])
print('Mnist download data done.') print('Mnist download data done.')
logger.debug('Mnist download data done.') logger.debug('Mnist download data done.')
......
authorName: nni authorName: nni
experimentName: default_test experimentName: default_test
maxExecDuration: 5m maxExecDuration: 5m
maxTrialNum: 2 maxTrialNum: 4
trialConcurrency: 1 trialConcurrency: 2
tuner: tuner:
builtinTunerName: Random builtinTunerName: Random
......
authorName: nni authorName: nni
experimentName: default_test experimentName: default_test
maxExecDuration: 5m maxExecDuration: 5m
maxTrialNum: 2 maxTrialNum: 4
trialConcurrency: 1 trialConcurrency: 2
searchSpacePath: ./mnist_search_space.json searchSpacePath: ./mnist_search_space.json
tuner: tuner:
......
import nni
"""A deep MNIST classifier using convolutional layers.""" """A deep MNIST classifier using convolutional layers."""
import logging import logging
import math import math
import tempfile import tempfile
import time
import tensorflow as tf import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data from tensorflow.examples.tutorials.mnist import input_data
import nni
FLAGS = None FLAGS = None
logger = logging.getLogger('mnist_AutoML') logger = logging.getLogger('mnist_AutoML')
...@@ -123,12 +127,23 @@ def bias_variable(shape): ...@@ -123,12 +127,23 @@ def bias_variable(shape):
initial = tf.constant(0.1, shape=shape) initial = tf.constant(0.1, shape=shape)
return tf.Variable(initial) return tf.Variable(initial)
def download_mnist_retry(data_dir, max_num_retries=20):
"""Try to download mnist dataset and avoid errors"""
for _ in range(max_num_retries):
try:
return input_data.read_data_sets(data_dir, one_hot=True)
except tf.errors.AlreadyExistsError:
time.sleep(1)
raise Exception("Failed to download MNIST.")
def main(params): def main(params):
""" """
Main function, build mnist network, run and send result to NNI. Main function, build mnist network, run and send result to NNI.
""" """
mnist = input_data.read_data_sets(params['data_dir'], one_hot=True)
def main(params):
# Import data
mnist = download_mnist_retry(params['data_dir'])
print('Mnist download data done.') print('Mnist download data done.')
logger.debug('Mnist download data done.') logger.debug('Mnist download data done.')
mnist_network = MnistNetwork(channel_1_num=params['channel_1_num'], mnist_network = MnistNetwork(channel_1_num=params['channel_1_num'],
......
...@@ -21,8 +21,9 @@ ...@@ -21,8 +21,9 @@
import logging import logging
import math import math
import tempfile import tempfile
import tensorflow as tf import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data from tensorflow.examples.tutorials.mnist import input_data
FLAGS = None FLAGS = None
...@@ -168,13 +169,21 @@ def bias_variable(shape): ...@@ -168,13 +169,21 @@ def bias_variable(shape):
initial = tf.constant(0.1, shape=shape) initial = tf.constant(0.1, shape=shape)
return tf.Variable(initial) return tf.Variable(initial)
def download_mnist_retry(data_dir, max_num_retries=20):
"""Try to download mnist dataset and avoid errors"""
for _ in range(max_num_retries):
try:
return input_data.read_data_sets(data_dir, one_hot=True)
except tf.errors.AlreadyExistsError:
time.sleep(1)
raise Exception("Failed to download MNIST.")
def main(params): def main(params):
''' '''
Main function, build mnist network, run and send result to NNI. Main function, build mnist network, run and send result to NNI.
''' '''
# Import data # Import data
mnist = input_data.read_data_sets(params['data_dir'], one_hot=True) mnist = download_mnist_retry(params['data_dir'])
print('Mnist download data done.') print('Mnist download data done.')
logger.debug('Mnist download data done.') logger.debug('Mnist download data done.')
......
...@@ -21,8 +21,9 @@ ...@@ -21,8 +21,9 @@
import logging import logging
import math import math
import tempfile import tempfile
import tensorflow as tf import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data from tensorflow.examples.tutorials.mnist import input_data
import nni import nni
...@@ -172,13 +173,21 @@ def bias_variable(shape): ...@@ -172,13 +173,21 @@ def bias_variable(shape):
initial = tf.constant(0.1, shape=shape) initial = tf.constant(0.1, shape=shape)
return tf.Variable(initial) return tf.Variable(initial)
def download_mnist_retry(data_dir, max_num_retries=20):
"""Try to download mnist dataset and avoid errors"""
for _ in range(max_num_retries):
try:
return input_data.read_data_sets(data_dir, one_hot=True)
except tf.errors.AlreadyExistsError:
time.sleep(1)
raise Exception("Failed to download MNIST.")
def main(params): def main(params):
''' '''
Main function, build mnist network, run and send result to NNI. Main function, build mnist network, run and send result to NNI.
''' '''
# Import data # Import data
mnist = input_data.read_data_sets(params['data_dir'], one_hot=True) mnist = download_mnist_retry(params['data_dir'])
print('Mnist download data done.') print('Mnist download data done.')
logger.debug('Mnist download data done.') logger.debug('Mnist download data done.')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment