Commit ee3bfa1e authored by Vighnesh Birodkar, committed by TF Object Detection Team

Add options in TF2 launch script for summaries and checkpoints.

PiperOrigin-RevId: 322828673
parent 2ae9c3a6
@@ -23,6 +23,7 @@ import os
 import time
 import tensorflow.compat.v1 as tf
+import tensorflow.compat.v2 as tf2
 from object_detection import eval_util
 from object_detection import inputs
@@ -414,8 +415,9 @@ def train_loop(
     train_steps=None,
     use_tpu=False,
     save_final_config=False,
-    checkpoint_every_n=5000,
+    checkpoint_every_n=1000,
     checkpoint_max_to_keep=7,
+    record_summaries=True,
     **kwargs):
   """Trains a model using eager + functions.
@@ -445,6 +447,7 @@ def train_loop(
       Checkpoint every n training steps.
     checkpoint_max_to_keep:
       int, the number of most recent checkpoints to keep in the model directory.
+    record_summaries: Boolean, whether or not to record summaries.
     **kwargs: Additional keyword arguments for configuration override.
   """
   ## Parse the configs
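Taken together, the updated signature lets callers of train_loop tune checkpoint frequency and switch summaries off without going through the launch script. A minimal sketch of calling it directly under a distribution strategy; the config path, model directory, and numeric values below are placeholders, not part of this change:

import tensorflow.compat.v2 as tf
from object_detection import model_lib_v2

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  model_lib_v2.train_loop(
      pipeline_config_path='/tmp/pipeline.config',  # placeholder path
      model_dir='/tmp/od_model',                    # placeholder path
      train_steps=10000,                            # placeholder value
      checkpoint_every_n=500,    # checkpoint more often than the new default of 1000
      record_summaries=False)    # new option: skip writing TensorBoard summaries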
@@ -531,8 +534,11 @@
   # is the chief.
   summary_writer_filepath = get_filepath(strategy,
                                          os.path.join(model_dir, 'train'))
+  if record_summaries:
+    summary_writer = tf.compat.v2.summary.create_file_writer(
+        summary_writer_filepath)
+  else:
+    summary_writer = tf2.summary.create_noop_writer()

   if use_tpu:
     num_steps_per_iteration = 100
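The no-op writer keeps the rest of the loop unchanged: downstream code can still enter summary_writer.as_default() and call the tf.summary ops, and those calls simply do nothing when record_summaries is False. A standalone sketch of the pattern, assuming TF2; the logdir and scalar name are illustrative only:

import tensorflow.compat.v2 as tf

record_summaries = False  # illustrative toggle

if record_summaries:
  writer = tf.summary.create_file_writer('/tmp/train')  # placeholder logdir
else:
  writer = tf.summary.create_noop_writer()

with writer.as_default():
  # Writes an event file with the file writer; silently discarded with the
  # no-op writer.
  tf.summary.scalar('loss', 0.5, step=0)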
@@ -604,6 +610,8 @@
       if num_steps_per_iteration > 1:
         for _ in tf.range(num_steps_per_iteration - 1):
+          # Following suggestion on yaqs/5402607292645376
+          with tf.name_scope(''):
+            _sample_and_train(strategy, train_step_fn, data_iterator)

       return _sample_and_train(strategy, train_step_fn, data_iterator)
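The empty name scope presumably keeps ops created on the extra loop iterations from accumulating nested scopes: with tf.compat.v1 semantics, entering a name scope with an empty string resets the current scope to the top level. A small graph-mode illustration of that behavior (module alias and op names are arbitrary):

import tensorflow.compat.v1 as tf1

tf1.disable_eager_execution()
with tf1.name_scope('outer'):
  a = tf1.constant(1.0, name='a')    # op name: 'outer/a'
  with tf1.name_scope(''):           # empty string resets to the root name scope
    b = tf1.constant(2.0, name='b')  # op name: 'b', not 'outer/b'
print(a.op.name, b.op.name)  # -> outer/a b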
@@ -62,6 +62,11 @@ flags.DEFINE_integer(
     'num_workers', 1, 'When num_workers > 1, training uses '
     'MultiWorkerMirroredStrategy. When num_workers = 1 it uses '
     'MirroredStrategy.')
+flags.DEFINE_integer(
+    'checkpoint_every_n', 1000, 'Integer defining how often we checkpoint.')
+flags.DEFINE_boolean('record_summaries', True,
+                     ('Whether or not to record summaries during'
+                      ' training.'))

 FLAGS = flags.FLAGS
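For context on the flag mechanics: absl's DEFINE_integer and DEFINE_boolean register module-level flags that are parsed from the command line when the app starts, and boolean flags also accept a negated --no<name> form. A self-contained sketch, with the flag names taken from the diff and everything else illustrative:

from absl import app
from absl import flags

flags.DEFINE_integer(
    'checkpoint_every_n', 1000, 'Integer defining how often we checkpoint.')
flags.DEFINE_boolean('record_summaries', True,
                     'Whether or not to record summaries during training.')

FLAGS = flags.FLAGS


def main(argv):
  del argv  # unused
  # e.g. '--checkpoint_every_n=500 --norecord_summaries' yields (500, False)
  print(FLAGS.checkpoint_every_n, FLAGS.record_summaries)


if __name__ == '__main__':
  app.run(main)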
@@ -100,7 +105,9 @@ def main(unused_argv):
         pipeline_config_path=FLAGS.pipeline_config_path,
         model_dir=FLAGS.model_dir,
         train_steps=FLAGS.num_train_steps,
-        use_tpu=FLAGS.use_tpu)
+        use_tpu=FLAGS.use_tpu,
+        checkpoint_every_n=FLAGS.checkpoint_every_n,
+        record_summaries=FLAGS.record_summaries)

 if __name__ == '__main__':
   tf.compat.v1.app.run()
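End to end, the new flags reach train_loop through main() above. A hedged sketch of driving the updated launch script programmatically; it assumes the object_detection package is importable, the paths and values are placeholders, and running the script from the command line with the same flags is equivalent:

from absl import app
from object_detection import model_main_tf2

argv = [
    'model_main_tf2.py',                            # argv[0]: program name for absl
    '--pipeline_config_path=/tmp/pipeline.config',  # placeholder path
    '--model_dir=/tmp/od_model',                    # placeholder path
    '--checkpoint_every_n=500',                     # new flag from this change
    '--norecord_summaries',                         # negated form of the new boolean flag
]
app.run(model_main_tf2.main, argv=argv)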