Commit 6172f113 authored by Vighnesh Birodkar, committed by TF Object Detection Team
Browse files

Standardize num steps per iteration.

PiperOrigin-RevId: 370720011
parent 36e9af47
...@@ -70,7 +70,8 @@ def _get_config_kwarg_overrides(): ...@@ -70,7 +70,8 @@ def _get_config_kwarg_overrides():
return { return {
'train_input_path': data_path, 'train_input_path': data_path,
'eval_input_path': data_path, 'eval_input_path': data_path,
'label_map_path': label_map_path 'label_map_path': label_map_path,
'train_input_reader': {'batch_size': 1}
} }
...@@ -98,6 +99,7 @@ class ModelLibTest(tf.test.TestCase): ...@@ -98,6 +99,7 @@ class ModelLibTest(tf.test.TestCase):
model_dir=model_dir, model_dir=model_dir,
train_steps=train_steps, train_steps=train_steps,
checkpoint_every_n=1, checkpoint_every_n=1,
num_steps_per_iteration=1,
**config_kwarg_overrides) **config_kwarg_overrides)
model_lib_v2.eval_continuously( model_lib_v2.eval_continuously(
...@@ -149,7 +151,7 @@ class SimpleModel(model.DetectionModel): ...@@ -149,7 +151,7 @@ class SimpleModel(model.DetectionModel):
def fake_model_builder(*_, **__): def fake_model_builder(*_, **__):
return SimpleModel() return SimpleModel()
FAKE_BUILDER_MAP = {'build': fake_model_builder} FAKE_BUILDER_MAP = {'detection_model_fn_base': fake_model_builder}
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.') @unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
...@@ -161,7 +163,7 @@ class ModelCheckpointTest(tf.test.TestCase): ...@@ -161,7 +163,7 @@ class ModelCheckpointTest(tf.test.TestCase):
strategy = tf2.distribute.OneDeviceStrategy(device='/cpu:0') strategy = tf2.distribute.OneDeviceStrategy(device='/cpu:0')
with mock.patch.dict( with mock.patch.dict(
exporter_lib_v2.INPUT_BUILDER_UTIL_MAP, FAKE_BUILDER_MAP): model_lib_v2.MODEL_BUILD_UTIL_MAP, FAKE_BUILDER_MAP):
model_dir = tempfile.mkdtemp(dir=self.get_temp_dir()) model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
pipeline_config_path = get_pipeline_config_path(MODEL_NAME_FOR_TEST) pipeline_config_path = get_pipeline_config_path(MODEL_NAME_FOR_TEST)
...@@ -173,8 +175,8 @@ class ModelCheckpointTest(tf.test.TestCase): ...@@ -173,8 +175,8 @@ class ModelCheckpointTest(tf.test.TestCase):
with strategy.scope(): with strategy.scope():
model_lib_v2.train_loop( model_lib_v2.train_loop(
new_pipeline_config_path, model_dir=model_dir, new_pipeline_config_path, model_dir=model_dir,
train_steps=20, checkpoint_every_n=2, checkpoint_max_to_keep=3, train_steps=5, checkpoint_every_n=2, checkpoint_max_to_keep=3,
**config_kwarg_overrides num_steps_per_iteration=1, **config_kwarg_overrides
) )
ckpt_files = tf.io.gfile.glob(os.path.join(model_dir, 'ckpt-*.index')) ckpt_files = tf.io.gfile.glob(os.path.join(model_dir, 'ckpt-*.index'))
self.assertEqual(len(ckpt_files), 3, self.assertEqual(len(ckpt_files), 3,
...@@ -266,6 +268,7 @@ class MetricsExportTest(tf.test.TestCase): ...@@ -266,6 +268,7 @@ class MetricsExportTest(tf.test.TestCase):
train_steps=train_steps, train_steps=train_steps,
checkpoint_every_n=100, checkpoint_every_n=100,
performance_summary_exporter=export, performance_summary_exporter=export,
num_steps_per_iteration=1,
**_get_config_kwarg_overrides()) **_get_config_kwarg_overrides())
......
...@@ -39,6 +39,7 @@ from object_detection.utils import visualization_utils as vutils ...@@ -39,6 +39,7 @@ from object_detection.utils import visualization_utils as vutils
MODEL_BUILD_UTIL_MAP = model_lib.MODEL_BUILD_UTIL_MAP MODEL_BUILD_UTIL_MAP = model_lib.MODEL_BUILD_UTIL_MAP
NUM_STEPS_PER_ITERATION = 100
RESTORE_MAP_ERROR_TEMPLATE = ( RESTORE_MAP_ERROR_TEMPLATE = (
...@@ -442,6 +443,7 @@ def train_loop( ...@@ -442,6 +443,7 @@ def train_loop(
checkpoint_max_to_keep=7, checkpoint_max_to_keep=7,
record_summaries=True, record_summaries=True,
performance_summary_exporter=None, performance_summary_exporter=None,
num_steps_per_iteration=NUM_STEPS_PER_ITERATION,
**kwargs): **kwargs):
"""Trains a model using eager + functions. """Trains a model using eager + functions.
...@@ -473,6 +475,8 @@ def train_loop( ...@@ -473,6 +475,8 @@ def train_loop(
int, the number of most recent checkpoints to keep in the model directory. int, the number of most recent checkpoints to keep in the model directory.
record_summaries: Boolean, whether or not to record summaries. record_summaries: Boolean, whether or not to record summaries.
performance_summary_exporter: function for exporting performance metrics. performance_summary_exporter: function for exporting performance metrics.
num_steps_per_iteration: int, The number of training steps to perform
in each iteration.
**kwargs: Additional keyword arguments for configuration override. **kwargs: Additional keyword arguments for configuration override.
""" """
## Parse the configs ## Parse the configs
...@@ -577,13 +581,6 @@ def train_loop( ...@@ -577,13 +581,6 @@ def train_loop(
else: else:
summary_writer = tf2.summary.create_noop_writer() summary_writer = tf2.summary.create_noop_writer()
if use_tpu:
num_steps_per_iteration = 100
else:
# TODO(b/135933080) Explore setting to 100 when GPU performance issues
# are fixed.
num_steps_per_iteration = 1
with summary_writer.as_default(): with summary_writer.as_default():
with strategy.scope(): with strategy.scope():
with tf.compat.v2.summary.record_if( with tf.compat.v2.summary.record_if(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment