Commit ebc97a77 authored by Jiang,Zhoulong's avatar Jiang,Zhoulong
Browse files

add intra/inter op support for dcgan model

parent fb6bc29b
...@@ -66,6 +66,13 @@ flags.DEFINE_integer('max_number_of_evaluations', None,
flags.DEFINE_boolean('write_to_disk', True, 'If `True`, run images to disk.')
flags.DEFINE_integer(
'inter_op_parallelism_threads', 0,
'Number of threads to use for inter-op parallelism. If left as default value of 0, the system will pick an appropriate number.')
flags.DEFINE_integer(
'intra_op_parallelism_threads', 0,
'Number of threads to use for intra-op parallelism. If left as default value of 0, the system will pick an appropriate number.')
def main(_, run_eval_loop=True):
# Fetch and generate images to run through Inception.
...@@ -119,12 +126,16 @@ def main(_, run_eval_loop=True):
# For unit testing, use `run_eval_loop=False`.
if not run_eval_loop: return
sess_config = tf.ConfigProto(
inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads,
intra_op_parallelism_threads=FLAGS.intra_op_parallelism_threads)
tf.contrib.training.evaluate_repeatedly(
FLAGS.checkpoint_dir,
master=FLAGS.master,
hooks=[tf.contrib.training.SummaryAtEndHook(FLAGS.eval_dir),
tf.contrib.training.StopAfterNEvalsHook(1)],
eval_ops=image_write_ops,
config=sess_config,
max_number_of_evaluations=FLAGS.max_number_of_evaluations)
...
...@@ -68,6 +68,13 @@ flags.DEFINE_integer(
'backup_workers', 1,
'Number of workers to be kept as backup in the sync replicas case.')
flags.DEFINE_integer(
'inter_op_parallelism_threads', 0,
'Number of threads to use for inter-op parallelism. If left as default value of 0, the system will pick an appropriate number.')
flags.DEFINE_integer(
'intra_op_parallelism_threads', 0,
'Number of threads to use for intra-op parallelism. If left as default value of 0, the system will pick an appropriate number.')
FLAGS = flags.FLAGS
...@@ -134,6 +141,9 @@ def main(_):
tf.as_string(tf.train.get_or_create_global_step())],
name='status_message')
if FLAGS.max_number_of_steps == 0: return
sess_config = tf.ConfigProto(
inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads,
intra_op_parallelism_threads=FLAGS.intra_op_parallelism_threads)
tfgan.gan_train(
train_ops,
hooks=(
...@@ -142,7 +152,8 @@ def main(_):
sync_hooks),
logdir=FLAGS.train_log_dir,
master=FLAGS.master,
is_chief=FLAGS.task == 0,
config=sess_config)
def _learning_rate():
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment