Unverified Commit c2f755c4 authored by Joel Shor's avatar Joel Shor Committed by GitHub
Browse files

Merge pull request #5086 from Intel-tensorflow/pr-set-inter-intra-op-dcgan

Add intra/inter op support for dcgan model
parents 5be37277 ebc97a77
......@@ -66,6 +66,13 @@ flags.DEFINE_integer('max_number_of_evaluations', None,
flags.DEFINE_boolean('write_to_disk', True, 'If `True`, run images to disk.')
flags.DEFINE_integer(
'inter_op_parallelism_threads', 0,
'Number of threads to use for inter-op parallelism. If left as default value of 0, the system will pick an appropriate number.')
flags.DEFINE_integer(
'intra_op_parallelism_threads', 0,
'Number of threads to use for intra-op parallelism. If left as default value of 0, the system will pick an appropriate number.')
def main(_, run_eval_loop=True):
# Fetch and generate images to run through Inception.
......@@ -119,12 +126,16 @@ def main(_, run_eval_loop=True):
# For unit testing, use `run_eval_loop=False`.
if not run_eval_loop: return
sess_config = tf.ConfigProto(
inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads,
intra_op_parallelism_threads=FLAGS.intra_op_parallelism_threads)
tf.contrib.training.evaluate_repeatedly(
FLAGS.checkpoint_dir,
master=FLAGS.master,
hooks=[tf.contrib.training.SummaryAtEndHook(FLAGS.eval_dir),
tf.contrib.training.StopAfterNEvalsHook(1)],
eval_ops=image_write_ops,
config=sess_config,
max_number_of_evaluations=FLAGS.max_number_of_evaluations)
......
......@@ -68,6 +68,13 @@ flags.DEFINE_integer(
'backup_workers', 1,
'Number of workers to be kept as backup in the sync replicas case.')
flags.DEFINE_integer(
'inter_op_parallelism_threads', 0,
'Number of threads to use for inter-op parallelism. If left as default value of 0, the system will pick an appropriate number.')
flags.DEFINE_integer(
'intra_op_parallelism_threads', 0,
'Number of threads to use for intra-op parallelism. If left as default value of 0, the system will pick an appropriate number.')
FLAGS = flags.FLAGS
......@@ -134,6 +141,9 @@ def main(_):
tf.as_string(tf.train.get_or_create_global_step())],
name='status_message')
if FLAGS.max_number_of_steps == 0: return
sess_config = tf.ConfigProto(
inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads,
intra_op_parallelism_threads=FLAGS.intra_op_parallelism_threads)
tfgan.gan_train(
train_ops,
hooks=(
......@@ -142,7 +152,8 @@ def main(_):
sync_hooks),
logdir=FLAGS.train_log_dir,
master=FLAGS.master,
is_chief=FLAGS.task == 0)
is_chief=FLAGS.task == 0,
config=sess_config)
def _learning_rate():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment