Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
6172f113
Commit
6172f113
authored
Apr 27, 2021
by
Vighnesh Birodkar
Committed by
TF Object Detection Team
Apr 27, 2021
Browse files
Standardize num steps per iteration.
PiperOrigin-RevId: 370720011
parent
36e9af47
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
12 deletions
+12
-12
research/object_detection/model_lib_tf2_test.py
research/object_detection/model_lib_tf2_test.py
+8
-5
research/object_detection/model_lib_v2.py
research/object_detection/model_lib_v2.py
+4
-7
No files found.
research/object_detection/model_lib_tf2_test.py
View file @
6172f113
...
...
@@ -70,7 +70,8 @@ def _get_config_kwarg_overrides():
return
{
'train_input_path'
:
data_path
,
'eval_input_path'
:
data_path
,
'label_map_path'
:
label_map_path
'label_map_path'
:
label_map_path
,
'train_input_reader'
:
{
'batch_size'
:
1
}
}
...
...
@@ -98,6 +99,7 @@ class ModelLibTest(tf.test.TestCase):
model_dir
=
model_dir
,
train_steps
=
train_steps
,
checkpoint_every_n
=
1
,
num_steps_per_iteration
=
1
,
**
config_kwarg_overrides
)
model_lib_v2
.
eval_continuously
(
...
...
@@ -149,7 +151,7 @@ class SimpleModel(model.DetectionModel):
def
fake_model_builder
(
*
_
,
**
__
):
return
SimpleModel
()
FAKE_BUILDER_MAP
=
{
'
build
'
:
fake_model_builder
}
FAKE_BUILDER_MAP
=
{
'
detection_model_fn_base
'
:
fake_model_builder
}
@
unittest
.
skipIf
(
tf_version
.
is_tf1
(),
'Skipping TF2.X only test.'
)
...
...
@@ -161,7 +163,7 @@ class ModelCheckpointTest(tf.test.TestCase):
strategy
=
tf2
.
distribute
.
OneDeviceStrategy
(
device
=
'/cpu:0'
)
with
mock
.
patch
.
dict
(
exporter
_lib_v2
.
INPUT
_BUILD
ER
_UTIL_MAP
,
FAKE_BUILDER_MAP
):
model
_lib_v2
.
MODEL
_BUILD_UTIL_MAP
,
FAKE_BUILDER_MAP
):
model_dir
=
tempfile
.
mkdtemp
(
dir
=
self
.
get_temp_dir
())
pipeline_config_path
=
get_pipeline_config_path
(
MODEL_NAME_FOR_TEST
)
...
...
@@ -173,8 +175,8 @@ class ModelCheckpointTest(tf.test.TestCase):
with
strategy
.
scope
():
model_lib_v2
.
train_loop
(
new_pipeline_config_path
,
model_dir
=
model_dir
,
train_steps
=
20
,
checkpoint_every_n
=
2
,
checkpoint_max_to_keep
=
3
,
**
config_kwarg_overrides
train_steps
=
5
,
checkpoint_every_n
=
2
,
checkpoint_max_to_keep
=
3
,
num_steps_per_iteration
=
1
,
**
config_kwarg_overrides
)
ckpt_files
=
tf
.
io
.
gfile
.
glob
(
os
.
path
.
join
(
model_dir
,
'ckpt-*.index'
))
self
.
assertEqual
(
len
(
ckpt_files
),
3
,
...
...
@@ -266,6 +268,7 @@ class MetricsExportTest(tf.test.TestCase):
train_steps
=
train_steps
,
checkpoint_every_n
=
100
,
performance_summary_exporter
=
export
,
num_steps_per_iteration
=
1
,
**
_get_config_kwarg_overrides
())
...
...
research/object_detection/model_lib_v2.py
View file @
6172f113
...
...
@@ -39,6 +39,7 @@ from object_detection.utils import visualization_utils as vutils
MODEL_BUILD_UTIL_MAP
=
model_lib
.
MODEL_BUILD_UTIL_MAP
NUM_STEPS_PER_ITERATION
=
100
RESTORE_MAP_ERROR_TEMPLATE
=
(
...
...
@@ -442,6 +443,7 @@ def train_loop(
checkpoint_max_to_keep
=
7
,
record_summaries
=
True
,
performance_summary_exporter
=
None
,
num_steps_per_iteration
=
NUM_STEPS_PER_ITERATION
,
**
kwargs
):
"""Trains a model using eager + functions.
...
...
@@ -473,6 +475,8 @@ def train_loop(
int, the number of most recent checkpoints to keep in the model directory.
record_summaries: Boolean, whether or not to record summaries.
performance_summary_exporter: function for exporting performance metrics.
num_steps_per_iteration: int, The number of training steps to perform
in each iteration.
**kwargs: Additional keyword arguments for configuration override.
"""
## Parse the configs
...
...
@@ -577,13 +581,6 @@ def train_loop(
else
:
summary_writer
=
tf2
.
summary
.
create_noop_writer
()
if
use_tpu
:
num_steps_per_iteration
=
100
else
:
# TODO(b/135933080) Explore setting to 100 when GPU performance issues
# are fixed.
num_steps_per_iteration
=
1
with
summary_writer
.
as_default
():
with
strategy
.
scope
():
with
tf
.
compat
.
v2
.
summary
.
record_if
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment