Commit ee3bfa1e authored by Vighnesh Birodkar, committed by TF Object Detection Team

Add options in TF2 launch script for summaries and checkpoints.

PiperOrigin-RevId: 322828673
parent 2ae9c3a6
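
The change threads two new options through `train_loop`: `checkpoint_every_n`, whose default drops from 5000 to 1000 steps, and `record_summaries`, which can disable summary writing entirely. A minimal sketch of calling the loop directly with the new keyword arguments; the `model_lib_v2` module path is assumed from the TF Object Detection API layout, and the config paths are placeholders:

```python
# Hypothetical direct call; module path and file paths are assumptions.
from object_detection import model_lib_v2

model_lib_v2.train_loop(
    pipeline_config_path='path/to/pipeline.config',  # placeholder
    model_dir='path/to/model_dir',                   # placeholder
    checkpoint_every_n=1000,   # new default, lowered from 5000
    record_summaries=False,    # new option: skip TensorBoard summaries
)
```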
@@ -23,6 +23,7 @@
 import os
 import time
 import tensorflow.compat.v1 as tf
+import tensorflow.compat.v2 as tf2
 from object_detection import eval_util
 from object_detection import inputs
@@ -414,8 +415,9 @@ def train_loop(
     train_steps=None,
     use_tpu=False,
     save_final_config=False,
-    checkpoint_every_n=5000,
+    checkpoint_every_n=1000,
     checkpoint_max_to_keep=7,
+    record_summaries=True,
     **kwargs):
   """Trains a model using eager + functions.
@@ -445,6 +447,7 @@ def train_loop(
       Checkpoint every n training steps.
     checkpoint_max_to_keep:
       int, the number of most recent checkpoints to keep in the model directory.
+    record_summaries: Boolean, whether or not to record summaries.
     **kwargs: Additional keyword arguments for configuration override.
   """
   ## Parse the configs
@@ -531,8 +534,11 @@ def train_loop(
   # is the chief.
   summary_writer_filepath = get_filepath(strategy,
                                          os.path.join(model_dir, 'train'))
-  summary_writer = tf.compat.v2.summary.create_file_writer(
-      summary_writer_filepath)
+  if record_summaries:
+    summary_writer = tf.compat.v2.summary.create_file_writer(
+        summary_writer_filepath)
+  else:
+    summary_writer = tf2.summary.create_noop_writer()

   if use_tpu:
     num_steps_per_iteration = 100
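
When `record_summaries` is false, the loop installs a no-op writer rather than guarding every summary call. A minimal sketch of that behavior under TF 2.x: anything written while the no-op writer is the default is silently dropped, so the rest of the training code stays unchanged.

```python
import tensorflow.compat.v2 as tf2

# Summaries issued under a no-op writer are discarded; nothing is
# written to disk and no event files are created.
writer = tf2.summary.create_noop_writer()
with writer.as_default():
    tf2.summary.scalar('loss', 0.5, step=0)  # silently dropped
```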
@@ -604,7 +610,9 @@ def train_loop(
   if num_steps_per_iteration > 1:
     for _ in tf.range(num_steps_per_iteration - 1):
-      _sample_and_train(strategy, train_step_fn, data_iterator)
+      # Following suggestion on yaqs/5402607292645376
+      with tf.name_scope(''):
+        _sample_and_train(strategy, train_step_fn, data_iterator)

   return _sample_and_train(strategy, train_step_fn, data_iterator)
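
Per the documentation for `tf.Graph.name_scope`, a `None` or empty-string name resets the current name scope to the top level, so ops created on repeated loop iterations do not accumulate nested scope prefixes. A small illustration under TF1 graph semantics (the op names in the comments are what the runtime assigns):

```python
import tensorflow.compat.v1 as tf

with tf.Graph().as_default():
    with tf.name_scope('outer'):
        a = tf.constant(1, name='a')      # op name: 'outer/a'
        with tf.name_scope(''):           # empty name resets to top level
            b = tf.constant(2, name='b')  # op name: 'b'
    print(a.op.name, b.op.name)           # -> outer/a b
```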
...
@@ -62,6 +62,11 @@
 flags.DEFINE_integer(
     'num_workers', 1, 'When num_workers > 1, training uses '
     'MultiWorkerMirroredStrategy. When num_workers = 1 it uses '
     'MirroredStrategy.')
+flags.DEFINE_integer(
+    'checkpoint_every_n', 1000, 'Integer defining how often we checkpoint.')
+flags.DEFINE_boolean('record_summaries', True,
+                     ('Whether or not to record summaries during'
+                      ' training.'))

 FLAGS = flags.FLAGS
@@ -100,7 +105,9 @@ def main(unused_argv):
         pipeline_config_path=FLAGS.pipeline_config_path,
         model_dir=FLAGS.model_dir,
         train_steps=FLAGS.num_train_steps,
-        use_tpu=FLAGS.use_tpu)
+        use_tpu=FLAGS.use_tpu,
+        checkpoint_every_n=FLAGS.checkpoint_every_n,
+        record_summaries=FLAGS.record_summaries)

 if __name__ == '__main__':
   tf.compat.v1.app.run()
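
With the flags wired into `main`, both options are settable at launch time. A hypothetical invocation, assuming the entry point file is `model_main_tf2.py` (the file name is inferred, not shown in this diff); absl booleans also accept the `--noflag` negation form:

```python
# Hypothetical launch command, expressed via subprocess; all paths are
# placeholders and the script name is an assumption.
import subprocess

subprocess.run([
    'python', 'model_main_tf2.py',
    '--pipeline_config_path=path/to/pipeline.config',
    '--model_dir=path/to/model_dir',
    '--checkpoint_every_n=500',   # override the new default of 1000
    '--norecord_summaries',       # absl-style boolean negation
], check=True)
```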