"docs/en_US/vscode:/vscode.git/clone" did not exist on "ba42b6a44de409494e2fe2619b468a56543f190a"
Commit 2b4a7d1f authored by Lee's avatar Lee Committed by chicm-ms
Browse files

Fix tensorflow MNIST dataset download bug in multiple trials (#815)

* Fix tensorflow MNIST dataset download bug in multiple trials

* Modify the maxTrialNum and trialConcurrency in config
parent 2fa77bcc
......@@ -4,8 +4,9 @@ import argparse
import logging
import math
import tempfile
import tensorflow as tf
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
FLAGS = None
......@@ -152,12 +153,21 @@ def bias_variable(shape):
return tf.Variable(initial)
def download_mnist_retry(data_dir, max_num_retries=20):
"""Try to download mnist dataset and avoid errors"""
for _ in range(max_num_retries):
try:
return input_data.read_data_sets(data_dir, one_hot=True)
except tf.errors.AlreadyExistsError:
time.sleep(1)
raise Exception("Failed to download MNIST.")
def main(params):
'''
Main function, build mnist network, run and send result to NNI.
'''
# Import data
mnist = input_data.read_data_sets(params['data_dir'], one_hot=True)
mnist = download_mnist_retry(params['data_dir'])
print('Mnist download data done.')
logger.debug('Mnist download data done.')
......
'''
mnist.py is an example to show: how to use iterative search space to tune architecture network for mnist.
'''
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import, division, print_function
import argparse
import codecs
import json
import logging
import math
import sys
import tempfile
import tensorflow as tf
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import nni
logger = logging.getLogger('mnist_cascading_search_space')
FLAGS = None
......@@ -95,9 +89,18 @@ class MnistNetwork(object):
child_accuracy = tf.equal(tf.argmax(output_layer, 1), tf.argmax(self.y, 1))
self.accuracy = tf.reduce_mean(tf.cast(child_accuracy, "float")) # add a reduce_mean
def download_mnist_retry(data_dir, max_num_retries=20):
"""Try to download mnist dataset and avoid errors"""
for _ in range(max_num_retries):
try:
return input_data.read_data_sets(data_dir, one_hot=True)
except tf.errors.AlreadyExistsError:
time.sleep(1)
raise Exception("Failed to download MNIST.")
def main(params):
# Import data
mnist = input_data.read_data_sets(params['data_dir'], one_hot=True)
mnist = download_mnist_retry(params['data_dir'])
# Create the model
# Build the graph for the deep net
......
......@@ -54,16 +54,19 @@ import nni
flags = tf.app.flags
flags.DEFINE_string("data_dir", "/tmp/mnist-data",
"Directory for storing mnist data")
flags.DEFINE_boolean("download_only", False,
flags.DEFINE_boolean(
"download_only", False,
"Only perform downloading of data; Do not proceed to "
"session preparation, model definition or training")
flags.DEFINE_integer("task_index", None,
"Worker task index, should be >= 0. task_index=0 is "
flags.DEFINE_integer(
"task_index", None, "Worker task index, should be >= 0. task_index=0 is "
"the master worker task the performs the variable "
"initialization ")
flags.DEFINE_integer("num_gpus", 1, "Total number of gpus for each machine."
flags.DEFINE_integer(
"num_gpus", 1, "Total number of gpus for each machine."
"If you don't use GPU, please set it to '0'")
flags.DEFINE_integer("replicas_to_aggregate", None,
flags.DEFINE_integer(
"replicas_to_aggregate", None,
"Number of replicas to aggregate before parameter update"
"is applied (For sync_replicas mode only; default: "
"num_workers)")
......@@ -96,6 +99,7 @@ IMAGE_PIXELS = 28
# {'cluster': cluster,
# 'task': {'type': 'worker', 'index': 1}})
def generate_default_params():
'''
Generate default hyper parameters
......@@ -106,6 +110,15 @@ def generate_default_params():
'hidden_units': 100,
}
def download_mnist_retry(data_dir, max_num_retries=20):
"""Try to download mnist dataset and avoid errors"""
for _ in range(max_num_retries):
try:
return input_data.read_data_sets(data_dir, one_hot=True)
except tf.errors.AlreadyExistsError:
time.sleep(1)
raise Exception("Failed to download MNIST.")
def main(unused_argv):
# Receive NNI hyper parameter and update it onto default params
RECEIVED_PARAMS = nni.get_next_parameter()
......@@ -124,7 +137,7 @@ def main(unused_argv):
FLAGS.job_name = task_type
FLAGS.task_index = task_index
mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
mnist = download_mnist_retry(FLAGS.data_dir)
if FLAGS.download_only:
sys.exit(0)
......@@ -206,7 +219,8 @@ def main(unused_argv):
hid = tf.nn.relu(hid_lin)
y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
cross_entropy = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
cross_entropy = -tf.reduce_sum(
y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
opt = tf.train.AdamOptimizer(PARAMS['learning_rate'])
......@@ -258,8 +272,9 @@ def main(unused_argv):
sess_config = tf.ConfigProto(
allow_soft_placement=True,
log_device_placement=False,
device_filters=["/job:ps",
"/job:worker/task:%d" % FLAGS.task_index])
device_filters=[
"/job:ps", "/job:worker/task:%d" % FLAGS.task_index
])
# The chief worker (task_index==0) session will prepare the session,
# while the remaining workers will wait for the preparation to complete.
......@@ -273,9 +288,11 @@ def main(unused_argv):
server_grpc_url = "grpc://" + worker_spec[FLAGS.task_index]
print("Using existing server at: %s" % server_grpc_url)
sess = sv.prepare_or_wait_for_session(server_grpc_url, config=sess_config)
sess = sv.prepare_or_wait_for_session(
server_grpc_url, config=sess_config)
else:
sess = sv.prepare_or_wait_for_session(server.target, config=sess_config)
sess = sv.prepare_or_wait_for_session(
server.target, config=sess_config)
print("Worker %d: Session initialization complete." % FLAGS.task_index)
......@@ -302,9 +319,14 @@ def main(unused_argv):
(now, FLAGS.task_index, local_step, step))
if step > 0 and step % 5000 == 0 and is_chief:
val_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
val_feed = {
x: mnist.validation.images,
y_: mnist.validation.labels
}
interim_val_xent = sess.run(cross_entropy, feed_dict=val_feed)
print("After %d training step(s), validation cross entropy = %g" % (step, interim_val_xent))
print(
"After %d training step(s), validation cross entropy = %g"
% (step, interim_val_xent))
# Only chief worker can report intermediate metrics
nni.report_intermediate_result(interim_val_xent)
......
......@@ -3,8 +3,9 @@
import logging
import math
import tempfile
import tensorflow as tf
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import nni
......@@ -142,13 +143,21 @@ def bias_variable(shape):
initial = tf.constant(0.1, shape=shape)
return tf.Variable(initial)
def download_mnist_retry(data_dir, max_num_retries=20):
"""Try to download mnist dataset and avoid errors"""
for _ in range(max_num_retries):
try:
return input_data.read_data_sets(data_dir, one_hot=True)
except tf.errors.AlreadyExistsError:
time.sleep(1)
raise Exception("Failed to download MNIST.")
def main(params):
'''
Main function, build mnist network, run and send result to NNI.
'''
# Import data
mnist = input_data.read_data_sets(params['data_dir'], one_hot=True)
mnist = download_mnist_retry(params['data_dir'])
print('Mnist download data done.')
logger.debug('Mnist download data done.')
......
......@@ -4,8 +4,9 @@ import argparse
import logging
import math
import tempfile
import tensorflow as tf
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import nni
......@@ -143,13 +144,21 @@ def bias_variable(shape):
initial = tf.constant(0.1, shape=shape)
return tf.Variable(initial)
def download_mnist_retry(data_dir, max_num_retries=20):
"""Try to download mnist dataset and avoid errors"""
for _ in range(max_num_retries):
try:
return input_data.read_data_sets(data_dir, one_hot=True)
except tf.errors.AlreadyExistsError:
time.sleep(1)
raise Exception("Failed to download MNIST.")
def main(params):
'''
Main function, build mnist network, run and send result to NNI.
'''
# Import data
mnist = input_data.read_data_sets(params['data_dir'], one_hot=True)
mnist = download_mnist_retry(params['data_dir'])
print('Mnist download data done.')
logger.debug('Mnist download data done.')
......
......@@ -3,8 +3,9 @@ import argparse
import logging
import math
import tempfile
import tensorflow as tf
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
FLAGS = None
......@@ -143,13 +144,21 @@ def bias_variable(shape):
initial = tf.constant(0.1, shape=shape)
return tf.Variable(initial)
def download_mnist_retry(data_dir, max_num_retries=20):
"""Try to download mnist dataset and avoid errors"""
for _ in range(max_num_retries):
try:
return input_data.read_data_sets(data_dir, one_hot=True)
except tf.errors.AlreadyExistsError:
time.sleep(1)
raise Exception("Failed to download MNIST.")
def main(params):
'''
Main function, build mnist network, run and send result to NNI.
'''
# Import data
mnist = input_data.read_data_sets(params['data_dir'], one_hot=True)
mnist = download_mnist_retry(params['data_dir'])
print('Mnist download data done.')
logger.debug('Mnist download data done.')
......
authorName: nni
experimentName: default_test
maxExecDuration: 5m
maxTrialNum: 2
trialConcurrency: 1
maxTrialNum: 4
trialConcurrency: 2
tuner:
builtinTunerName: Random
......
authorName: nni
experimentName: default_test
maxExecDuration: 5m
maxTrialNum: 2
trialConcurrency: 1
maxTrialNum: 4
trialConcurrency: 2
searchSpacePath: ./mnist_search_space.json
tuner:
......
import nni
"""A deep MNIST classifier using convolutional layers."""
import logging
import math
import tempfile
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import nni
FLAGS = None
logger = logging.getLogger('mnist_AutoML')
......@@ -123,12 +127,23 @@ def bias_variable(shape):
initial = tf.constant(0.1, shape=shape)
return tf.Variable(initial)
def download_mnist_retry(data_dir, max_num_retries=20):
"""Try to download mnist dataset and avoid errors"""
for _ in range(max_num_retries):
try:
return input_data.read_data_sets(data_dir, one_hot=True)
except tf.errors.AlreadyExistsError:
time.sleep(1)
raise Exception("Failed to download MNIST.")
def main(params):
"""
Main function, build mnist network, run and send result to NNI.
"""
mnist = input_data.read_data_sets(params['data_dir'], one_hot=True)
def main(params):
# Import data
mnist = download_mnist_retry(params['data_dir'])
print('Mnist download data done.')
logger.debug('Mnist download data done.')
mnist_network = MnistNetwork(channel_1_num=params['channel_1_num'],
......
......@@ -21,8 +21,9 @@
import logging
import math
import tempfile
import tensorflow as tf
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
FLAGS = None
......@@ -168,13 +169,21 @@ def bias_variable(shape):
initial = tf.constant(0.1, shape=shape)
return tf.Variable(initial)
def download_mnist_retry(data_dir, max_num_retries=20):
"""Try to download mnist dataset and avoid errors"""
for _ in range(max_num_retries):
try:
return input_data.read_data_sets(data_dir, one_hot=True)
except tf.errors.AlreadyExistsError:
time.sleep(1)
raise Exception("Failed to download MNIST.")
def main(params):
'''
Main function, build mnist network, run and send result to NNI.
'''
# Import data
mnist = input_data.read_data_sets(params['data_dir'], one_hot=True)
mnist = download_mnist_retry(params['data_dir'])
print('Mnist download data done.')
logger.debug('Mnist download data done.')
......
......@@ -21,8 +21,9 @@
import logging
import math
import tempfile
import tensorflow as tf
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import nni
......@@ -172,13 +173,21 @@ def bias_variable(shape):
initial = tf.constant(0.1, shape=shape)
return tf.Variable(initial)
def download_mnist_retry(data_dir, max_num_retries=20):
"""Try to download mnist dataset and avoid errors"""
for _ in range(max_num_retries):
try:
return input_data.read_data_sets(data_dir, one_hot=True)
except tf.errors.AlreadyExistsError:
time.sleep(1)
raise Exception("Failed to download MNIST.")
def main(params):
'''
Main function, build mnist network, run and send result to NNI.
'''
# Import data
mnist = input_data.read_data_sets(params['data_dir'], one_hot=True)
mnist = download_mnist_retry(params['data_dir'])
print('Mnist download data done.')
logger.debug('Mnist download data done.')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment