Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
ee3bfa1e
Commit
ee3bfa1e
authored
Jul 23, 2020
by
Vighnesh Birodkar
Committed by
TF Object Detection Team
Jul 23, 2020
Browse files
Add options in TF2 launch script for summaries and checkpoints.
PiperOrigin-RevId: 322828673
parent
2ae9c3a6
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
20 additions
and
5 deletions
+20
-5
research/object_detection/model_lib_v2.py
research/object_detection/model_lib_v2.py
+12
-4
research/object_detection/model_main_tf2.py
research/object_detection/model_main_tf2.py
+8
-1
No files found.
research/object_detection/model_lib_v2.py
View file @
ee3bfa1e
...
...
@@ -23,6 +23,7 @@ import os
import
time
import
tensorflow.compat.v1
as
tf
import
tensorflow.compat.v2
as
tf2
from
object_detection
import
eval_util
from
object_detection
import
inputs
...
...
@@ -414,8 +415,9 @@ def train_loop(
train_steps
=
None
,
use_tpu
=
False
,
save_final_config
=
False
,
checkpoint_every_n
=
5
000
,
checkpoint_every_n
=
1
000
,
checkpoint_max_to_keep
=
7
,
record_summaries
=
True
,
**
kwargs
):
"""Trains a model using eager + functions.
...
...
@@ -445,6 +447,7 @@ def train_loop(
Checkpoint every n training steps.
checkpoint_max_to_keep:
int, the number of most recent checkpoints to keep in the model directory.
record_summaries: Boolean, whether or not to record summaries.
**kwargs: Additional keyword arguments for configuration override.
"""
## Parse the configs
...
...
@@ -531,8 +534,11 @@ def train_loop(
# is the chief.
summary_writer_filepath
=
get_filepath
(
strategy
,
os
.
path
.
join
(
model_dir
,
'train'
))
if
record_summaries
:
summary_writer
=
tf
.
compat
.
v2
.
summary
.
create_file_writer
(
summary_writer_filepath
)
else
:
summary_writer
=
tf2
.
summary
.
create_noop_writer
()
if
use_tpu
:
num_steps_per_iteration
=
100
...
...
@@ -604,6 +610,8 @@ def train_loop(
if
num_steps_per_iteration
>
1
:
for
_
in
tf
.
range
(
num_steps_per_iteration
-
1
):
# Following suggestion on yaqs/5402607292645376
with
tf
.
name_scope
(
''
):
_sample_and_train
(
strategy
,
train_step_fn
,
data_iterator
)
return
_sample_and_train
(
strategy
,
train_step_fn
,
data_iterator
)
...
...
research/object_detection/model_main_tf2.py
View file @
ee3bfa1e
...
...
@@ -62,6 +62,11 @@ flags.DEFINE_integer(
'num_workers'
,
1
,
'When num_workers > 1, training uses '
'MultiWorkerMirroredStrategy. When num_workers = 1 it uses '
'MirroredStrategy.'
)
flags
.
DEFINE_integer
(
'checkpoint_every_n'
,
1000
,
'Integer defining how often we checkpoint.'
)
flags
.
DEFINE_boolean
(
'record_summaries'
,
True
,
(
'Whether or not to record summaries during'
' training.'
))
FLAGS
=
flags
.
FLAGS
...
...
@@ -100,7 +105,9 @@ def main(unused_argv):
pipeline_config_path
=
FLAGS
.
pipeline_config_path
,
model_dir
=
FLAGS
.
model_dir
,
train_steps
=
FLAGS
.
num_train_steps
,
use_tpu
=
FLAGS
.
use_tpu
)
use_tpu
=
FLAGS
.
use_tpu
,
checkpoint_every_n
=
FLAGS
.
checkpoint_every_n
,
record_summaries
=
FLAGS
.
record_summaries
)
if
__name__
==
'__main__'
:
tf
.
compat
.
v1
.
app
.
run
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment