Unverified Commit d4a4dd04 authored by Karmel Allison's avatar Karmel Allison Committed by GitHub
Browse files

Adding thread args back in, with allow_soft_placement (#3533)

parent e029542a
...@@ -609,8 +609,18 @@ def resnet_main(flags, model_function, input_function): ...@@ -609,8 +609,18 @@ def resnet_main(flags, model_function, input_function):
model_function, model_function,
loss_reduction=tf.losses.Reduction.MEAN) loss_reduction=tf.losses.Reduction.MEAN)
# Set up a RunConfig to only save checkpoints once per training cycle. # Create session config based on values of inter_op_parallelism_threads and
run_config = tf.estimator.RunConfig().replace(save_checkpoints_secs=1e9) # intra_op_parallelism_threads. Note that we default to having
# allow_soft_placement = True, which is required for multi-GPU and not
# harmful for other modes.
session_config = tf.ConfigProto(
inter_op_parallelism_threads=flags.inter_op_parallelism_threads,
intra_op_parallelism_threads=flags.intra_op_parallelism_threads,
allow_soft_placement=True)
# Set up a RunConfig to save checkpoint and set session config.
run_config = tf.estimator.RunConfig().replace(save_checkpoints_secs=1e9,
session_config=session_config)
classifier = tf.estimator.Estimator( classifier = tf.estimator.Estimator(
model_fn=model_function, model_dir=flags.model_dir, config=run_config, model_fn=model_function, model_dir=flags.model_dir, config=run_config,
params={ params={
...@@ -706,3 +716,13 @@ class ResnetArgParser(argparse.ArgumentParser): ...@@ -706,3 +716,13 @@ class ResnetArgParser(argparse.ArgumentParser):
help='If set, use fake data (zeroes) instead of a real dataset. ' help='If set, use fake data (zeroes) instead of a real dataset. '
'This mode is useful for performance debugging, as it removes ' 'This mode is useful for performance debugging, as it removes '
'input processing steps, but will not learn anything.') 'input processing steps, but will not learn anything.')
self.add_argument(
'--inter_op_parallelism_threads', type=int, default=0,
help='Number of inter_op_parallelism_threads to use for CPU. '
'See TensorFlow config.proto for details.')
self.add_argument(
'--intra_op_parallelism_threads', type=int, default=0,
help='Number of intra_op_parallelism_threads to use for CPU. '
'See TensorFlow config.proto for details.')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment