Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
1b77cd80
Commit
1b77cd80
authored
Oct 16, 2019
by
Yeqing Li
Committed by
A. Unique TensorFlower
Oct 16, 2019
Browse files
Enables timer callback and disables checkpoint saving in retinanet benchmark test.
PiperOrigin-RevId: 275080469
parent
cb913691
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
34 additions
and
12 deletions
+34
-12
official/benchmark/retinanet_benchmark.py
official/benchmark/retinanet_benchmark.py
+7
-6
official/modeling/training/distributed_executor.py
official/modeling/training/distributed_executor.py
+11
-1
official/vision/detection/main.py
official/vision/detection/main.py
+16
-5
No files found.
official/benchmark/retinanet_benchmark.py
View file @
1b77cd80
...
...
@@ -95,10 +95,8 @@ class DetectionBenchmarkBase(tf.test.Benchmark):
}]
if
self
.
timer_callback
:
metrics
.
append
({
'name'
:
'exp_per_second'
,
'value'
:
self
.
timer_callback
.
get_examples_per_sec
(
FLAGS
.
train_batch_size
)
'name'
:
'exp_per_second'
,
'value'
:
self
.
timer_callback
.
get_examples_per_sec
(
train_batch_size
)
})
else
:
metrics
.
append
({
...
...
@@ -134,7 +132,7 @@ class RetinanetBenchmarkBase(DetectionBenchmarkBase):
def
_run_detection_main
(
self
):
"""Starts detection job."""
return
detection
.
main
(
'unused_argv'
)
return
detection
.
run
(
callbacks
=
[
self
.
timer_callback
]
)
class
RetinanetAccuracy
(
RetinanetBenchmarkBase
):
...
...
@@ -166,7 +164,8 @@ class RetinanetAccuracy(RetinanetBenchmarkBase):
stats
=
summary
,
wall_time_sec
=
wall_time_sec
,
min_ap
=
min_ap
,
max_ap
=
max_ap
)
max_ap
=
max_ap
,
train_batch_size
=
self
.
params_override
[
'train'
][
'batch_size'
])
def
_setup
(
self
):
super
(
RetinanetAccuracy
,
self
).
_setup
()
...
...
@@ -228,6 +227,8 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
params
[
'eval'
][
'eval_samples'
]
=
8
FLAGS
.
params_override
=
json
.
dumps
(
params
)
FLAGS
.
model_dir
=
self
.
_get_model_dir
(
'real_benchmark_8_gpu_coco'
)
# Use negative value to avoid saving checkpoints.
FLAGS
.
save_checkpoint_freq
=
-
1
if
self
.
timer_callback
is
None
:
logging
.
error
(
'Cannot measure performance without timer callback'
)
else
:
...
...
official/modeling/training/distributed_executor.py
View file @
1b77cd80
...
...
@@ -103,6 +103,8 @@ def initialize_common_flags():
flags
.
DEFINE_integer
(
'task_index'
,
0
,
'If multi-worker training, the task_index of this worker.'
)
flags
.
DEFINE_integer
(
'save_checkpoint_freq'
,
None
,
'Number of steps to save checkpoint.'
)
def
strategy_flags_dict
():
...
...
@@ -447,6 +449,12 @@ class DistributedExecutor(object):
if
save_config
:
self
.
_save_config
(
model_dir
)
if
FLAGS
.
save_checkpoint_freq
:
save_freq
=
FLAGS
.
save_checkpoint_freq
else
:
save_freq
=
iterations_per_loop
last_save_checkpoint_step
=
0
params
=
self
.
_params
strategy
=
self
.
_strategy
# To reduce unnecessary send/receive input pipeline operation, we place
...
...
@@ -540,9 +548,11 @@ class DistributedExecutor(object):
# iterations_per_loop steps.
# To avoid repeated model saving, we do not save after the last
# step of training.
if
current_step
<
total_steps
:
if
save_freq
>
0
and
current_step
<
total_steps
and
(
current_step
-
last_save_checkpoint_step
)
>=
save_freq
:
_save_checkpoint
(
checkpoint
,
model_dir
,
checkpoint_name
.
format
(
step
=
current_step
))
last_save_checkpoint_step
=
current_step
if
test_step
:
eval_iterator
=
self
.
_get_input_iterator
(
eval_input_fn
,
strategy
)
...
...
official/vision/detection/main.py
View file @
1b77cd80
...
...
@@ -55,7 +55,10 @@ flags.DEFINE_string('eval_file_pattern', None, 'Location of ther eval data')
FLAGS
=
flags
.
FLAGS
def
run_executor
(
params
,
train_input_fn
=
None
,
eval_input_fn
=
None
):
def
run_executor
(
params
,
train_input_fn
=
None
,
eval_input_fn
=
None
,
callbacks
=
None
):
"""Runs Retinanet model on distribution strategy defined by the user."""
model_builder
=
model_factory
.
model_generator
(
params
)
...
...
@@ -92,6 +95,7 @@ def run_executor(params, train_input_fn=None, eval_input_fn=None):
iterations_per_loop
=
params
.
train
.
iterations_per_loop
,
total_steps
=
params
.
train
.
total_steps
,
init_checkpoint
=
model_builder
.
make_restore_checkpoint_fn
(),
custom_callbacks
=
callbacks
,
save_config
=
True
)
elif
FLAGS
.
mode
==
'eval'
:
...
...
@@ -124,9 +128,7 @@ def run_executor(params, train_input_fn=None, eval_input_fn=None):
raise
ValueError
(
'Mode not found: %s.'
%
FLAGS
.
mode
)
def
main
(
argv
):
del
argv
# Unused.
def
run
(
callbacks
=
None
):
params
=
config_factory
.
config_generator
(
FLAGS
.
model
)
params
=
params_dict
.
override_params_dict
(
...
...
@@ -171,7 +173,16 @@ def main(argv):
batch_size
=
params
.
eval
.
batch_size
,
num_examples
=
params
.
eval
.
eval_samples
)
return
run_executor
(
params
,
train_input_fn
=
train_input_fn
,
eval_input_fn
=
eval_input_fn
)
params
,
train_input_fn
=
train_input_fn
,
eval_input_fn
=
eval_input_fn
,
callbacks
=
callbacks
)
def
main
(
argv
):
del
argv
# Unused.
return
run
()
if
__name__
==
'__main__'
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment