Commit 2b4a7d1f authored by Lee's avatar Lee Committed by chicm-ms
Browse files

Fix tensorflow MNIST dataset download bug in multiple trials (#815)

* Fix tensorflow MNIST dataset download bug in multiple trials

* Modify the maxTrialNum and trialConcurrency in config
parent 2fa77bcc
...@@ -4,8 +4,9 @@ import argparse ...@@ -4,8 +4,9 @@ import argparse
import logging import logging
import math import math
import tempfile import tempfile
import tensorflow as tf import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data from tensorflow.examples.tutorials.mnist import input_data
FLAGS = None FLAGS = None
...@@ -152,12 +153,21 @@ def bias_variable(shape): ...@@ -152,12 +153,21 @@ def bias_variable(shape):
return tf.Variable(initial) return tf.Variable(initial)
def download_mnist_retry(data_dir, max_num_retries=20):
    """Read the MNIST data sets, retrying while a concurrent trial is
    still downloading into the same directory.

    Args:
        data_dir: directory where the MNIST files are cached.
        max_num_retries: number of attempts before giving up.

    Returns:
        The datasets object from ``input_data.read_data_sets``.

    Raises:
        Exception: if every attempt raised ``AlreadyExistsError``.
    """
    attempts = 0
    while attempts < max_num_retries:
        try:
            return input_data.read_data_sets(data_dir, one_hot=True)
        except tf.errors.AlreadyExistsError:
            # Another trial is downloading concurrently; wait and retry.
            time.sleep(1)
        attempts += 1
    raise Exception("Failed to download MNIST.")
def main(params): def main(params):
''' '''
Main function, build mnist network, run and send result to NNI. Main function, build mnist network, run and send result to NNI.
''' '''
# Import data # Import data
mnist = input_data.read_data_sets(params['data_dir'], one_hot=True) mnist = download_mnist_retry(params['data_dir'])
print('Mnist download data done.') print('Mnist download data done.')
logger.debug('Mnist download data done.') logger.debug('Mnist download data done.')
......
''' '''
mnist.py is an example to show: how to use iterative search space to tune architecture network for mnist. mnist.py is an example to show: how to use iterative search space to tune architecture network for mnist.
''' '''
from __future__ import absolute_import from __future__ import absolute_import, division, print_function
from __future__ import division
from __future__ import print_function
import argparse
import codecs
import json
import logging import logging
import math import math
import sys
import tempfile import tempfile
import tensorflow as tf import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data from tensorflow.examples.tutorials.mnist import input_data
import nni import nni
logger = logging.getLogger('mnist_cascading_search_space') logger = logging.getLogger('mnist_cascading_search_space')
FLAGS = None FLAGS = None
...@@ -95,10 +89,19 @@ class MnistNetwork(object): ...@@ -95,10 +89,19 @@ class MnistNetwork(object):
child_accuracy = tf.equal(tf.argmax(output_layer, 1), tf.argmax(self.y, 1)) child_accuracy = tf.equal(tf.argmax(output_layer, 1), tf.argmax(self.y, 1))
self.accuracy = tf.reduce_mean(tf.cast(child_accuracy, "float")) # add a reduce_mean self.accuracy = tf.reduce_mean(tf.cast(child_accuracy, "float")) # add a reduce_mean
def download_mnist_retry(data_dir, max_num_retries=20):
    """Fetch the MNIST data sets, tolerating download races between
    concurrently running trials.

    Args:
        data_dir: directory where the MNIST files are cached.
        max_num_retries: number of attempts before giving up.

    Returns:
        The datasets object from ``input_data.read_data_sets``.

    Raises:
        Exception: if all retries were exhausted.
    """
    remaining = max_num_retries
    while remaining:
        try:
            return input_data.read_data_sets(data_dir, one_hot=True)
        except tf.errors.AlreadyExistsError:
            # A sibling trial holds the download; back off briefly.
            time.sleep(1)
            remaining -= 1
    raise Exception("Failed to download MNIST.")
def main(params): def main(params):
# Import data # Import data
mnist = input_data.read_data_sets(params['data_dir'], one_hot=True) mnist = download_mnist_retry(params['data_dir'])
# Create the model # Create the model
# Build the graph for the deep net # Build the graph for the deep net
mnist_network = MnistNetwork(params) mnist_network = MnistNetwork(params)
...@@ -117,15 +120,15 @@ def main(params): ...@@ -117,15 +120,15 @@ def main(params):
for i in range(params['batch_num']): for i in range(params['batch_num']):
batch = mnist.train.next_batch(params['batch_size']) batch = mnist.train.next_batch(params['batch_size'])
mnist_network.train_step.run(feed_dict={mnist_network.x: batch[0], mnist_network.y: batch[1]}) mnist_network.train_step.run(feed_dict={mnist_network.x: batch[0], mnist_network.y: batch[1]})
if i % 100 == 0: if i % 100 == 0:
train_accuracy = mnist_network.accuracy.eval(feed_dict={ train_accuracy = mnist_network.accuracy.eval(feed_dict={
mnist_network.x: batch[0], mnist_network.y: batch[1]}) mnist_network.x: batch[0], mnist_network.y: batch[1]})
print('step %d, training accuracy %g' % (i, train_accuracy)) print('step %d, training accuracy %g' % (i, train_accuracy))
test_acc = mnist_network.accuracy.eval(feed_dict={ test_acc = mnist_network.accuracy.eval(feed_dict={
mnist_network.x: mnist.test.images, mnist_network.y: mnist.test.labels}) mnist_network.x: mnist.test.images, mnist_network.y: mnist.test.labels})
nni.report_final_result(test_acc) nni.report_final_result(test_acc)
def generate_defualt_params(): def generate_defualt_params():
......
...@@ -13,10 +13,10 @@ ...@@ -13,10 +13,10 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
# #
# NNI (https://github.com/Microsoft/nni) modified this code to show how to # NNI (https://github.com/Microsoft/nni) modified this code to show how to
# integrate distributed tensorflow training with NNI SDK # integrate distributed tensorflow training with NNI SDK
# #
"""Distributed MNIST training and validation, with model replicas. """Distributed MNIST training and validation, with model replicas.
A simple softmax model with one hidden layer is defined. The parameters A simple softmax model with one hidden layer is defined. The parameters
...@@ -54,19 +54,22 @@ import nni ...@@ -54,19 +54,22 @@ import nni
flags = tf.app.flags flags = tf.app.flags
flags.DEFINE_string("data_dir", "/tmp/mnist-data", flags.DEFINE_string("data_dir", "/tmp/mnist-data",
"Directory for storing mnist data") "Directory for storing mnist data")
flags.DEFINE_boolean("download_only", False, flags.DEFINE_boolean(
"Only perform downloading of data; Do not proceed to " "download_only", False,
"session preparation, model definition or training") "Only perform downloading of data; Do not proceed to "
flags.DEFINE_integer("task_index", None, "session preparation, model definition or training")
"Worker task index, should be >= 0. task_index=0 is " flags.DEFINE_integer(
"the master worker task the performs the variable " "task_index", None, "Worker task index, should be >= 0. task_index=0 is "
"initialization ") "the master worker task the performs the variable "
flags.DEFINE_integer("num_gpus", 1, "Total number of gpus for each machine." "initialization ")
"If you don't use GPU, please set it to '0'") flags.DEFINE_integer(
flags.DEFINE_integer("replicas_to_aggregate", None, "num_gpus", 1, "Total number of gpus for each machine."
"Number of replicas to aggregate before parameter update" "If you don't use GPU, please set it to '0'")
"is applied (For sync_replicas mode only; default: " flags.DEFINE_integer(
"num_workers)") "replicas_to_aggregate", None,
"Number of replicas to aggregate before parameter update"
"is applied (For sync_replicas mode only; default: "
"num_workers)")
flags.DEFINE_integer("train_steps", 20000, flags.DEFINE_integer("train_steps", 20000,
"Number of (global) training steps to perform") "Number of (global) training steps to perform")
flags.DEFINE_boolean( flags.DEFINE_boolean(
...@@ -96,237 +99,256 @@ IMAGE_PIXELS = 28 ...@@ -96,237 +99,256 @@ IMAGE_PIXELS = 28
# {'cluster': cluster, # {'cluster': cluster,
# 'task': {'type': 'worker', 'index': 1}}) # 'task': {'type': 'worker', 'index': 1}})
def generate_default_params(): def generate_default_params():
''' '''
Generate default hyper parameters Generate default hyper parameters
''' '''
return { return {
'learning_rate': 0.01, 'learning_rate': 0.01,
'batch_size': 100, 'batch_size': 100,
'hidden_units': 100, 'hidden_units': 100,
} }
def download_mnist_retry(data_dir, max_num_retries=20):
    """Read MNIST into *data_dir*, retrying when another process is
    downloading the same files at the same time.

    Args:
        data_dir: directory where the MNIST files are cached.
        max_num_retries: number of attempts before giving up.

    Returns:
        The datasets object from ``input_data.read_data_sets``.

    Raises:
        Exception: when no attempt succeeded.
    """
    for attempt in range(max_num_retries):
        try:
            return input_data.read_data_sets(data_dir, one_hot=True)
        except tf.errors.AlreadyExistsError:
            # File already being written by a concurrent trial — retry.
            time.sleep(1)
    raise Exception("Failed to download MNIST.")
def main(unused_argv): def main(unused_argv):
# Receive NNI hyper parameter and update it onto default params # Receive NNI hyper parameter and update it onto default params
RECEIVED_PARAMS = nni.get_next_parameter() RECEIVED_PARAMS = nni.get_next_parameter()
PARAMS = generate_default_params() PARAMS = generate_default_params()
PARAMS.update(RECEIVED_PARAMS) PARAMS.update(RECEIVED_PARAMS)
# Parse environment variable TF_CONFIG to get job_name and task_index # Parse environment variable TF_CONFIG to get job_name and task_index
# If not explicitly specified in the constructor and the TF_CONFIG # If not explicitly specified in the constructor and the TF_CONFIG
# environment variable is present, load cluster_spec from TF_CONFIG. # environment variable is present, load cluster_spec from TF_CONFIG.
tf_config = json.loads(os.environ.get('TF_CONFIG') or '{}') tf_config = json.loads(os.environ.get('TF_CONFIG') or '{}')
task_config = tf_config.get('task', {}) task_config = tf_config.get('task', {})
task_type = task_config.get('type') task_type = task_config.get('type')
task_index = task_config.get('index') task_index = task_config.get('index')
FLAGS.job_name = task_type FLAGS.job_name = task_type
FLAGS.task_index = task_index FLAGS.task_index = task_index
mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True) mnist = download_mnist_retry(FLAGS.data_dir)
if FLAGS.download_only: if FLAGS.download_only:
sys.exit(0) sys.exit(0)
if FLAGS.job_name is None or FLAGS.job_name == "": if FLAGS.job_name is None or FLAGS.job_name == "":
raise ValueError("Must specify an explicit `job_name`") raise ValueError("Must specify an explicit `job_name`")
if FLAGS.task_index is None or FLAGS.task_index == "": if FLAGS.task_index is None or FLAGS.task_index == "":
raise ValueError("Must specify an explicit `task_index`") raise ValueError("Must specify an explicit `task_index`")
print("job name = %s" % FLAGS.job_name) print("job name = %s" % FLAGS.job_name)
print("task index = %d" % FLAGS.task_index) print("task index = %d" % FLAGS.task_index)
cluster_config = tf_config.get('cluster', {}) cluster_config = tf_config.get('cluster', {})
ps_hosts = cluster_config.get('ps') ps_hosts = cluster_config.get('ps')
worker_hosts = cluster_config.get('worker') worker_hosts = cluster_config.get('worker')
ps_hosts_str = ','.join(ps_hosts) ps_hosts_str = ','.join(ps_hosts)
worker_hosts_str = ','.join(worker_hosts) worker_hosts_str = ','.join(worker_hosts)
FLAGS.ps_hosts = ps_hosts_str FLAGS.ps_hosts = ps_hosts_str
FLAGS.worker_hosts = worker_hosts_str FLAGS.worker_hosts = worker_hosts_str
# Construct the cluster and start the server # Construct the cluster and start the server
ps_spec = FLAGS.ps_hosts.split(",") ps_spec = FLAGS.ps_hosts.split(",")
worker_spec = FLAGS.worker_hosts.split(",") worker_spec = FLAGS.worker_hosts.split(",")
# Get the number of workers. # Get the number of workers.
num_workers = len(worker_spec) num_workers = len(worker_spec)
cluster = tf.train.ClusterSpec({"ps": ps_spec, "worker": worker_spec}) cluster = tf.train.ClusterSpec({"ps": ps_spec, "worker": worker_spec})
if not FLAGS.existing_servers: if not FLAGS.existing_servers:
# Not using existing servers. Create an in-process server. # Not using existing servers. Create an in-process server.
server = tf.train.Server( server = tf.train.Server(
cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index) cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
if FLAGS.job_name == "ps": if FLAGS.job_name == "ps":
server.join() server.join()
is_chief = (FLAGS.task_index == 0) is_chief = (FLAGS.task_index == 0)
if FLAGS.num_gpus > 0: if FLAGS.num_gpus > 0:
# Avoid gpu allocation conflict: now allocate task_num -> #gpu # Avoid gpu allocation conflict: now allocate task_num -> #gpu
# for each worker in the corresponding machine # for each worker in the corresponding machine
gpu = (FLAGS.task_index % FLAGS.num_gpus) gpu = (FLAGS.task_index % FLAGS.num_gpus)
worker_device = "/job:worker/task:%d/gpu:%d" % (FLAGS.task_index, gpu) worker_device = "/job:worker/task:%d/gpu:%d" % (FLAGS.task_index, gpu)
elif FLAGS.num_gpus == 0: elif FLAGS.num_gpus == 0:
# Just allocate the CPU to worker server # Just allocate the CPU to worker server
cpu = 0 cpu = 0
worker_device = "/job:worker/task:%d/cpu:%d" % (FLAGS.task_index, cpu) worker_device = "/job:worker/task:%d/cpu:%d" % (FLAGS.task_index, cpu)
# The device setter will automatically place Variables ops on separate # The device setter will automatically place Variables ops on separate
# parameter servers (ps). The non-Variable ops will be placed on the workers. # parameter servers (ps). The non-Variable ops will be placed on the workers.
# The ps use CPU and workers use corresponding GPU # The ps use CPU and workers use corresponding GPU
with tf.device( with tf.device(
tf.train.replica_device_setter( tf.train.replica_device_setter(
worker_device=worker_device, worker_device=worker_device,
ps_device="/job:ps/cpu:0", ps_device="/job:ps/cpu:0",
cluster=cluster)): cluster=cluster)):
global_step = tf.Variable(0, name="global_step", trainable=False) global_step = tf.Variable(0, name="global_step", trainable=False)
# Variables of the hidden layer # Variables of the hidden layer
hid_w = tf.Variable( hid_w = tf.Variable(
tf.truncated_normal( tf.truncated_normal(
[IMAGE_PIXELS * IMAGE_PIXELS, PARAMS['hidden_units']], [IMAGE_PIXELS * IMAGE_PIXELS, PARAMS['hidden_units']],
stddev=1.0 / IMAGE_PIXELS), stddev=1.0 / IMAGE_PIXELS),
name="hid_w") name="hid_w")
hid_b = tf.Variable(tf.zeros([PARAMS['hidden_units']]), name="hid_b") hid_b = tf.Variable(tf.zeros([PARAMS['hidden_units']]), name="hid_b")
# Variables of the softmax layer # Variables of the softmax layer
sm_w = tf.Variable( sm_w = tf.Variable(
tf.truncated_normal( tf.truncated_normal(
[PARAMS['hidden_units'], 10], [PARAMS['hidden_units'], 10],
stddev=1.0 / math.sqrt(PARAMS['hidden_units'])), stddev=1.0 / math.sqrt(PARAMS['hidden_units'])),
name="sm_w") name="sm_w")
sm_b = tf.Variable(tf.zeros([10]), name="sm_b") sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
# Ops: located on the worker specified with FLAGS.task_index # Ops: located on the worker specified with FLAGS.task_index
x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS]) x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS])
y_ = tf.placeholder(tf.float32, [None, 10]) y_ = tf.placeholder(tf.float32, [None, 10])
hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b) hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
hid = tf.nn.relu(hid_lin) hid = tf.nn.relu(hid_lin)
y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b)) y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
cross_entropy = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0))) cross_entropy = -tf.reduce_sum(
y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
opt = tf.train.AdamOptimizer(PARAMS['learning_rate'])
opt = tf.train.AdamOptimizer(PARAMS['learning_rate'])
if FLAGS.sync_replicas:
if FLAGS.replicas_to_aggregate is None: if FLAGS.sync_replicas:
replicas_to_aggregate = num_workers if FLAGS.replicas_to_aggregate is None:
else: replicas_to_aggregate = num_workers
replicas_to_aggregate = FLAGS.replicas_to_aggregate else:
replicas_to_aggregate = FLAGS.replicas_to_aggregate
opt = tf.train.SyncReplicasOptimizer(
opt, opt = tf.train.SyncReplicasOptimizer(
replicas_to_aggregate=replicas_to_aggregate, opt,
total_num_replicas=num_workers, replicas_to_aggregate=replicas_to_aggregate,
name="mnist_sync_replicas") total_num_replicas=num_workers,
name="mnist_sync_replicas")
train_step = opt.minimize(cross_entropy, global_step=global_step)
train_step = opt.minimize(cross_entropy, global_step=global_step)
if FLAGS.sync_replicas:
local_init_op = opt.local_step_init_op if FLAGS.sync_replicas:
if is_chief: local_init_op = opt.local_step_init_op
local_init_op = opt.chief_init_op if is_chief:
local_init_op = opt.chief_init_op
ready_for_local_init_op = opt.ready_for_local_init_op
ready_for_local_init_op = opt.ready_for_local_init_op
# Initial token and chief queue runners required by the sync_replicas mode
chief_queue_runner = opt.get_chief_queue_runner() # Initial token and chief queue runners required by the sync_replicas mode
sync_init_op = opt.get_init_tokens_op() chief_queue_runner = opt.get_chief_queue_runner()
sync_init_op = opt.get_init_tokens_op()
init_op = tf.global_variables_initializer()
train_dir = tempfile.mkdtemp() init_op = tf.global_variables_initializer()
train_dir = tempfile.mkdtemp()
if FLAGS.sync_replicas:
sv = tf.train.Supervisor( if FLAGS.sync_replicas:
is_chief=is_chief, sv = tf.train.Supervisor(
logdir=train_dir, is_chief=is_chief,
init_op=init_op, logdir=train_dir,
local_init_op=local_init_op, init_op=init_op,
ready_for_local_init_op=ready_for_local_init_op, local_init_op=local_init_op,
recovery_wait_secs=1, ready_for_local_init_op=ready_for_local_init_op,
global_step=global_step) recovery_wait_secs=1,
else: global_step=global_step)
sv = tf.train.Supervisor( else:
is_chief=is_chief, sv = tf.train.Supervisor(
logdir=train_dir, is_chief=is_chief,
init_op=init_op, logdir=train_dir,
recovery_wait_secs=1, init_op=init_op,
global_step=global_step) recovery_wait_secs=1,
global_step=global_step)
sess_config = tf.ConfigProto(
allow_soft_placement=True, sess_config = tf.ConfigProto(
log_device_placement=False, allow_soft_placement=True,
device_filters=["/job:ps", log_device_placement=False,
"/job:worker/task:%d" % FLAGS.task_index]) device_filters=[
"/job:ps", "/job:worker/task:%d" % FLAGS.task_index
# The chief worker (task_index==0) session will prepare the session, ])
# while the remaining workers will wait for the preparation to complete.
if is_chief: # The chief worker (task_index==0) session will prepare the session,
print("Worker %d: Initializing session..." % FLAGS.task_index) # while the remaining workers will wait for the preparation to complete.
else: if is_chief:
print("Worker %d: Waiting for session to be initialized..." % print("Worker %d: Initializing session..." % FLAGS.task_index)
FLAGS.task_index) else:
print("Worker %d: Waiting for session to be initialized..." %
if FLAGS.existing_servers: FLAGS.task_index)
server_grpc_url = "grpc://" + worker_spec[FLAGS.task_index]
print("Using existing server at: %s" % server_grpc_url) if FLAGS.existing_servers:
server_grpc_url = "grpc://" + worker_spec[FLAGS.task_index]
sess = sv.prepare_or_wait_for_session(server_grpc_url, config=sess_config) print("Using existing server at: %s" % server_grpc_url)
else:
sess = sv.prepare_or_wait_for_session(server.target, config=sess_config) sess = sv.prepare_or_wait_for_session(
server_grpc_url, config=sess_config)
print("Worker %d: Session initialization complete." % FLAGS.task_index) else:
sess = sv.prepare_or_wait_for_session(
if FLAGS.sync_replicas and is_chief: server.target, config=sess_config)
# Chief worker will start the chief queue runner and call the init op.
sess.run(sync_init_op) print("Worker %d: Session initialization complete." % FLAGS.task_index)
sv.start_queue_runners(sess, [chief_queue_runner])
if FLAGS.sync_replicas and is_chief:
# Perform training # Chief worker will start the chief queue runner and call the init op.
time_begin = time.time() sess.run(sync_init_op)
print("Training begins @ %f" % time_begin) sv.start_queue_runners(sess, [chief_queue_runner])
local_step = 0 # Perform training
while True: time_begin = time.time()
# Training feed print("Training begins @ %f" % time_begin)
batch_xs, batch_ys = mnist.train.next_batch(PARAMS['batch_size'])
train_feed = {x: batch_xs, y_: batch_ys} local_step = 0
while True:
_, step = sess.run([train_step, global_step], feed_dict=train_feed) # Training feed
local_step += 1 batch_xs, batch_ys = mnist.train.next_batch(PARAMS['batch_size'])
train_feed = {x: batch_xs, y_: batch_ys}
now = time.time()
print("%f: Worker %d: training step %d done (global step: %d)" % _, step = sess.run([train_step, global_step], feed_dict=train_feed)
(now, FLAGS.task_index, local_step, step)) local_step += 1
if step > 0 and step % 5000 == 0 and is_chief: now = time.time()
print("%f: Worker %d: training step %d done (global step: %d)" %
(now, FLAGS.task_index, local_step, step))
if step > 0 and step % 5000 == 0 and is_chief:
val_feed = {
x: mnist.validation.images,
y_: mnist.validation.labels
}
interim_val_xent = sess.run(cross_entropy, feed_dict=val_feed)
print(
"After %d training step(s), validation cross entropy = %g"
% (step, interim_val_xent))
# Only chief worker can report intermediate metrics
nni.report_intermediate_result(interim_val_xent)
if step >= FLAGS.train_steps:
break
time_end = time.time()
print("Training ends @ %f" % time_end)
training_time = time_end - time_begin
print("Training elapsed time: %f s" % training_time)
# Validation feed
val_feed = {x: mnist.validation.images, y_: mnist.validation.labels} val_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
interim_val_xent = sess.run(cross_entropy, feed_dict=val_feed) val_xent = sess.run(cross_entropy, feed_dict=val_feed)
print("After %d training step(s), validation cross entropy = %g" % (step, interim_val_xent)) print("After %d training step(s), validation cross entropy = %g" %
(FLAGS.train_steps, val_xent))
# Only chief worker can report intermediate metrics
nni.report_intermediate_result(interim_val_xent)
if step >= FLAGS.train_steps:
break
time_end = time.time()
print("Training ends @ %f" % time_end)
training_time = time_end - time_begin
print("Training elapsed time: %f s" % training_time)
# Validation feed
val_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
val_xent = sess.run(cross_entropy, feed_dict=val_feed)
print("After %d training step(s), validation cross entropy = %g" %
(FLAGS.train_steps, val_xent))
# Only chief worker can report final metrics # Only chief worker can report final metrics
if is_chief: if is_chief:
nni.report_final_result(val_xent) nni.report_final_result(val_xent)
if __name__ == "__main__": if __name__ == "__main__":
tf.app.run() tf.app.run()
...@@ -3,8 +3,9 @@ ...@@ -3,8 +3,9 @@
import logging import logging
import math import math
import tempfile import tempfile
import tensorflow as tf import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data from tensorflow.examples.tutorials.mnist import input_data
import nni import nni
...@@ -142,13 +143,21 @@ def bias_variable(shape): ...@@ -142,13 +143,21 @@ def bias_variable(shape):
initial = tf.constant(0.1, shape=shape) initial = tf.constant(0.1, shape=shape)
return tf.Variable(initial) return tf.Variable(initial)
def download_mnist_retry(data_dir, max_num_retries=20):
    """Load the MNIST data sets, retrying if a concurrent download into
    the shared cache directory collides with this one.

    Args:
        data_dir: directory where the MNIST files are cached.
        max_num_retries: number of attempts before giving up.

    Returns:
        The datasets object from ``input_data.read_data_sets``.

    Raises:
        Exception: if every attempt failed.
    """
    attempts = 0
    while attempts < max_num_retries:
        attempts += 1
        try:
            return input_data.read_data_sets(data_dir, one_hot=True)
        except tf.errors.AlreadyExistsError:
            # Concurrent trial is writing the same files; pause and retry.
            time.sleep(1)
    raise Exception("Failed to download MNIST.")
def main(params): def main(params):
''' '''
Main function, build mnist network, run and send result to NNI. Main function, build mnist network, run and send result to NNI.
''' '''
# Import data # Import data
mnist = input_data.read_data_sets(params['data_dir'], one_hot=True) mnist = download_mnist_retry(params['data_dir'])
print('Mnist download data done.') print('Mnist download data done.')
logger.debug('Mnist download data done.') logger.debug('Mnist download data done.')
......
...@@ -4,8 +4,9 @@ import argparse ...@@ -4,8 +4,9 @@ import argparse
import logging import logging
import math import math
import tempfile import tempfile
import tensorflow as tf import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data from tensorflow.examples.tutorials.mnist import input_data
import nni import nni
...@@ -143,13 +144,21 @@ def bias_variable(shape): ...@@ -143,13 +144,21 @@ def bias_variable(shape):
initial = tf.constant(0.1, shape=shape) initial = tf.constant(0.1, shape=shape)
return tf.Variable(initial) return tf.Variable(initial)
def download_mnist_retry(data_dir, max_num_retries=20):
    """Read the MNIST data sets with retries, so parallel trials do not
    fail when they race on the download cache.

    Args:
        data_dir: directory where the MNIST files are cached.
        max_num_retries: number of attempts before giving up.

    Returns:
        The datasets object from ``input_data.read_data_sets``.

    Raises:
        Exception: if all retries were exhausted.
    """
    for _attempt in range(max_num_retries):
        try:
            return input_data.read_data_sets(data_dir, one_hot=True)
        except tf.errors.AlreadyExistsError:
            # Download in progress elsewhere; wait a second and retry.
            time.sleep(1)
    raise Exception("Failed to download MNIST.")
def main(params): def main(params):
''' '''
Main function, build mnist network, run and send result to NNI. Main function, build mnist network, run and send result to NNI.
''' '''
# Import data # Import data
mnist = input_data.read_data_sets(params['data_dir'], one_hot=True) mnist = download_mnist_retry(params['data_dir'])
print('Mnist download data done.') print('Mnist download data done.')
logger.debug('Mnist download data done.') logger.debug('Mnist download data done.')
......
...@@ -3,8 +3,9 @@ import argparse ...@@ -3,8 +3,9 @@ import argparse
import logging import logging
import math import math
import tempfile import tempfile
import tensorflow as tf import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data from tensorflow.examples.tutorials.mnist import input_data
FLAGS = None FLAGS = None
...@@ -143,13 +144,21 @@ def bias_variable(shape): ...@@ -143,13 +144,21 @@ def bias_variable(shape):
initial = tf.constant(0.1, shape=shape) initial = tf.constant(0.1, shape=shape)
return tf.Variable(initial) return tf.Variable(initial)
def download_mnist_retry(data_dir, max_num_retries=20):
    """Obtain the MNIST data sets, absorbing transient
    ``AlreadyExistsError`` collisions from concurrent downloads.

    Args:
        data_dir: directory where the MNIST files are cached.
        max_num_retries: number of attempts before giving up.

    Returns:
        The datasets object from ``input_data.read_data_sets``.

    Raises:
        Exception: when every attempt collided.
    """
    tries_left = max_num_retries
    while tries_left > 0:
        tries_left -= 1
        try:
            return input_data.read_data_sets(data_dir, one_hot=True)
        except tf.errors.AlreadyExistsError:
            # Another trial owns the download right now; retry shortly.
            time.sleep(1)
    raise Exception("Failed to download MNIST.")
def main(params): def main(params):
''' '''
Main function, build mnist network, run and send result to NNI. Main function, build mnist network, run and send result to NNI.
''' '''
# Import data # Import data
mnist = input_data.read_data_sets(params['data_dir'], one_hot=True) mnist = download_mnist_retry(params['data_dir'])
print('Mnist download data done.') print('Mnist download data done.')
logger.debug('Mnist download data done.') logger.debug('Mnist download data done.')
......
authorName: nni authorName: nni
experimentName: default_test experimentName: default_test
maxExecDuration: 5m maxExecDuration: 5m
maxTrialNum: 2 maxTrialNum: 4
trialConcurrency: 1 trialConcurrency: 2
tuner: tuner:
builtinTunerName: Random builtinTunerName: Random
......
authorName: nni authorName: nni
experimentName: default_test experimentName: default_test
maxExecDuration: 5m maxExecDuration: 5m
maxTrialNum: 2 maxTrialNum: 4
trialConcurrency: 1 trialConcurrency: 2
searchSpacePath: ./mnist_search_space.json searchSpacePath: ./mnist_search_space.json
tuner: tuner:
......
import nni
"""A deep MNIST classifier using convolutional layers.""" """A deep MNIST classifier using convolutional layers."""
import logging import logging
import math import math
import tempfile import tempfile
import time
import tensorflow as tf import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data from tensorflow.examples.tutorials.mnist import input_data
import nni
FLAGS = None FLAGS = None
logger = logging.getLogger('mnist_AutoML') logger = logging.getLogger('mnist_AutoML')
...@@ -123,12 +127,23 @@ def bias_variable(shape): ...@@ -123,12 +127,23 @@ def bias_variable(shape):
initial = tf.constant(0.1, shape=shape) initial = tf.constant(0.1, shape=shape)
return tf.Variable(initial) return tf.Variable(initial)
def download_mnist_retry(data_dir, max_num_retries=20):
    """Read the MNIST data sets, retrying around download races caused
    by multiple trials sharing one cache directory.

    Args:
        data_dir: directory where the MNIST files are cached.
        max_num_retries: number of attempts before giving up.

    Returns:
        The datasets object from ``input_data.read_data_sets``.

    Raises:
        Exception: if no attempt succeeded.
    """
    for _ in range(max_num_retries):
        try:
            datasets = input_data.read_data_sets(data_dir, one_hot=True)
        except tf.errors.AlreadyExistsError:
            # A concurrent trial is mid-download; back off for a second.
            time.sleep(1)
        else:
            return datasets
    raise Exception("Failed to download MNIST.")
def main(params): def main(params):
""" """
Main function, build mnist network, run and send result to NNI. Main function, build mnist network, run and send result to NNI.
""" """
mnist = input_data.read_data_sets(params['data_dir'], one_hot=True)
def main(params):
# Import data
mnist = download_mnist_retry(params['data_dir'])
print('Mnist download data done.') print('Mnist download data done.')
logger.debug('Mnist download data done.') logger.debug('Mnist download data done.')
mnist_network = MnistNetwork(channel_1_num=params['channel_1_num'], mnist_network = MnistNetwork(channel_1_num=params['channel_1_num'],
......
...@@ -21,8 +21,9 @@ ...@@ -21,8 +21,9 @@
import logging import logging
import math import math
import tempfile import tempfile
import tensorflow as tf import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data from tensorflow.examples.tutorials.mnist import input_data
FLAGS = None FLAGS = None
...@@ -168,13 +169,21 @@ def bias_variable(shape): ...@@ -168,13 +169,21 @@ def bias_variable(shape):
initial = tf.constant(0.1, shape=shape) initial = tf.constant(0.1, shape=shape)
return tf.Variable(initial) return tf.Variable(initial)
def download_mnist_retry(data_dir, max_num_retries=20):
    """Load MNIST from *data_dir*, retrying when a simultaneous download
    by another trial triggers ``AlreadyExistsError``.

    Args:
        data_dir: directory where the MNIST files are cached.
        max_num_retries: number of attempts before giving up.

    Returns:
        The datasets object from ``input_data.read_data_sets``.

    Raises:
        Exception: after exhausting all retries.
    """
    attempt = 0
    while attempt < max_num_retries:
        try:
            return input_data.read_data_sets(data_dir, one_hot=True)
        except tf.errors.AlreadyExistsError:
            # Shared cache is being written; sleep briefly before retrying.
            time.sleep(1)
        attempt += 1
    raise Exception("Failed to download MNIST.")
def main(params): def main(params):
''' '''
Main function, build mnist network, run and send result to NNI. Main function, build mnist network, run and send result to NNI.
''' '''
# Import data # Import data
mnist = input_data.read_data_sets(params['data_dir'], one_hot=True) mnist = download_mnist_retry(params['data_dir'])
print('Mnist download data done.') print('Mnist download data done.')
logger.debug('Mnist download data done.') logger.debug('Mnist download data done.')
......
...@@ -21,8 +21,9 @@ ...@@ -21,8 +21,9 @@
import logging import logging
import math import math
import tempfile import tempfile
import tensorflow as tf import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data from tensorflow.examples.tutorials.mnist import input_data
import nni import nni
...@@ -172,13 +173,21 @@ def bias_variable(shape): ...@@ -172,13 +173,21 @@ def bias_variable(shape):
initial = tf.constant(0.1, shape=shape) initial = tf.constant(0.1, shape=shape)
return tf.Variable(initial) return tf.Variable(initial)
def download_mnist_retry(data_dir, max_num_retries=20):
    """Read the MNIST data sets, retrying so that concurrently launched
    trials do not crash on a shared download directory.

    Args:
        data_dir: directory where the MNIST files are cached.
        max_num_retries: number of attempts before giving up.

    Returns:
        The datasets object from ``input_data.read_data_sets``.

    Raises:
        Exception: if every retry hit a download collision.
    """
    for _retry in range(max_num_retries):
        try:
            return input_data.read_data_sets(data_dir, one_hot=True)
        except tf.errors.AlreadyExistsError:
            # Files are being created by another process; retry after 1s.
            time.sleep(1)
    raise Exception("Failed to download MNIST.")
def main(params): def main(params):
''' '''
Main function, build mnist network, run and send result to NNI. Main function, build mnist network, run and send result to NNI.
''' '''
# Import data # Import data
mnist = input_data.read_data_sets(params['data_dir'], one_hot=True) mnist = download_mnist_retry(params['data_dir'])
print('Mnist download data done.') print('Mnist download data done.')
logger.debug('Mnist download data done.') logger.debug('Mnist download data done.')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment