Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
1b77cd80
Commit 1b77cd80, authored Oct 16, 2019 by Yeqing Li; committed by A. Unique TensorFlower on Oct 16, 2019
Browse files
Enables timer callback and disables checkpoint saving in retinanet benchmark test.
PiperOrigin-RevId: 275080469
parent
cb913691
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing 3 changed files with 34 additions and 12 deletions
+34
-12
official/benchmark/retinanet_benchmark.py
official/benchmark/retinanet_benchmark.py
+7
-6
official/modeling/training/distributed_executor.py
official/modeling/training/distributed_executor.py
+11
-1
official/vision/detection/main.py
official/vision/detection/main.py
+16
-5
No files found.
official/benchmark/retinanet_benchmark.py
View file @
1b77cd80
...
@@ -95,10 +95,8 @@ class DetectionBenchmarkBase(tf.test.Benchmark):
...
@@ -95,10 +95,8 @@ class DetectionBenchmarkBase(tf.test.Benchmark):
}]
}]
if
self
.
timer_callback
:
if
self
.
timer_callback
:
metrics
.
append
({
metrics
.
append
({
'name'
:
'name'
:
'exp_per_second'
,
'exp_per_second'
,
'value'
:
self
.
timer_callback
.
get_examples_per_sec
(
train_batch_size
)
'value'
:
self
.
timer_callback
.
get_examples_per_sec
(
FLAGS
.
train_batch_size
)
})
})
else
:
else
:
metrics
.
append
({
metrics
.
append
({
...
@@ -134,7 +132,7 @@ class RetinanetBenchmarkBase(DetectionBenchmarkBase):
...
@@ -134,7 +132,7 @@ class RetinanetBenchmarkBase(DetectionBenchmarkBase):
def
_run_detection_main
(
self
):
def
_run_detection_main
(
self
):
"""Starts detection job."""
"""Starts detection job."""
return
detection
.
main
(
'unused_argv'
)
return
detection
.
run
(
callbacks
=
[
self
.
timer_callback
]
)
class
RetinanetAccuracy
(
RetinanetBenchmarkBase
):
class
RetinanetAccuracy
(
RetinanetBenchmarkBase
):
...
@@ -166,7 +164,8 @@ class RetinanetAccuracy(RetinanetBenchmarkBase):
...
@@ -166,7 +164,8 @@ class RetinanetAccuracy(RetinanetBenchmarkBase):
stats
=
summary
,
stats
=
summary
,
wall_time_sec
=
wall_time_sec
,
wall_time_sec
=
wall_time_sec
,
min_ap
=
min_ap
,
min_ap
=
min_ap
,
max_ap
=
max_ap
)
max_ap
=
max_ap
,
train_batch_size
=
self
.
params_override
[
'train'
][
'batch_size'
])
def
_setup
(
self
):
def
_setup
(
self
):
super
(
RetinanetAccuracy
,
self
).
_setup
()
super
(
RetinanetAccuracy
,
self
).
_setup
()
...
@@ -228,6 +227,8 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
...
@@ -228,6 +227,8 @@ class RetinanetBenchmarkReal(RetinanetAccuracy):
params
[
'eval'
][
'eval_samples'
]
=
8
params
[
'eval'
][
'eval_samples'
]
=
8
FLAGS
.
params_override
=
json
.
dumps
(
params
)
FLAGS
.
params_override
=
json
.
dumps
(
params
)
FLAGS
.
model_dir
=
self
.
_get_model_dir
(
'real_benchmark_8_gpu_coco'
)
FLAGS
.
model_dir
=
self
.
_get_model_dir
(
'real_benchmark_8_gpu_coco'
)
# Use negative value to avoid saving checkpoints.
FLAGS
.
save_checkpoint_freq
=
-
1
if
self
.
timer_callback
is
None
:
if
self
.
timer_callback
is
None
:
logging
.
error
(
'Cannot measure performance without timer callback'
)
logging
.
error
(
'Cannot measure performance without timer callback'
)
else
:
else
:
...
...
official/modeling/training/distributed_executor.py
View file @
1b77cd80
...
@@ -103,6 +103,8 @@ def initialize_common_flags():
...
@@ -103,6 +103,8 @@ def initialize_common_flags():
flags
.
DEFINE_integer
(
flags
.
DEFINE_integer
(
'task_index'
,
0
,
'task_index'
,
0
,
'If multi-worker training, the task_index of this worker.'
)
'If multi-worker training, the task_index of this worker.'
)
flags
.
DEFINE_integer
(
'save_checkpoint_freq'
,
None
,
'Number of steps to save checkpoint.'
)
def
strategy_flags_dict
():
def
strategy_flags_dict
():
...
@@ -447,6 +449,12 @@ class DistributedExecutor(object):
...
@@ -447,6 +449,12 @@ class DistributedExecutor(object):
if
save_config
:
if
save_config
:
self
.
_save_config
(
model_dir
)
self
.
_save_config
(
model_dir
)
if
FLAGS
.
save_checkpoint_freq
:
save_freq
=
FLAGS
.
save_checkpoint_freq
else
:
save_freq
=
iterations_per_loop
last_save_checkpoint_step
=
0
params
=
self
.
_params
params
=
self
.
_params
strategy
=
self
.
_strategy
strategy
=
self
.
_strategy
# To reduce unnecessary send/receive input pipeline operation, we place
# To reduce unnecessary send/receive input pipeline operation, we place
...
@@ -540,9 +548,11 @@ class DistributedExecutor(object):
...
@@ -540,9 +548,11 @@ class DistributedExecutor(object):
# iterations_per_loop steps.
# iterations_per_loop steps.
# To avoid repeated model saving, we do not save after the last
# To avoid repeated model saving, we do not save after the last
# step of training.
# step of training.
if
current_step
<
total_steps
:
if
save_freq
>
0
and
current_step
<
total_steps
and
(
current_step
-
last_save_checkpoint_step
)
>=
save_freq
:
_save_checkpoint
(
checkpoint
,
model_dir
,
_save_checkpoint
(
checkpoint
,
model_dir
,
checkpoint_name
.
format
(
step
=
current_step
))
checkpoint_name
.
format
(
step
=
current_step
))
last_save_checkpoint_step
=
current_step
if
test_step
:
if
test_step
:
eval_iterator
=
self
.
_get_input_iterator
(
eval_input_fn
,
strategy
)
eval_iterator
=
self
.
_get_input_iterator
(
eval_input_fn
,
strategy
)
...
...
official/vision/detection/main.py
View file @
1b77cd80
...
@@ -55,7 +55,10 @@ flags.DEFINE_string('eval_file_pattern', None, 'Location of ther eval data')
...
@@ -55,7 +55,10 @@ flags.DEFINE_string('eval_file_pattern', None, 'Location of ther eval data')
FLAGS
=
flags
.
FLAGS
FLAGS
=
flags
.
FLAGS
def
run_executor
(
params
,
train_input_fn
=
None
,
eval_input_fn
=
None
):
def
run_executor
(
params
,
train_input_fn
=
None
,
eval_input_fn
=
None
,
callbacks
=
None
):
"""Runs Retinanet model on distribution strategy defined by the user."""
"""Runs Retinanet model on distribution strategy defined by the user."""
model_builder
=
model_factory
.
model_generator
(
params
)
model_builder
=
model_factory
.
model_generator
(
params
)
...
@@ -92,6 +95,7 @@ def run_executor(params, train_input_fn=None, eval_input_fn=None):
...
@@ -92,6 +95,7 @@ def run_executor(params, train_input_fn=None, eval_input_fn=None):
iterations_per_loop
=
params
.
train
.
iterations_per_loop
,
iterations_per_loop
=
params
.
train
.
iterations_per_loop
,
total_steps
=
params
.
train
.
total_steps
,
total_steps
=
params
.
train
.
total_steps
,
init_checkpoint
=
model_builder
.
make_restore_checkpoint_fn
(),
init_checkpoint
=
model_builder
.
make_restore_checkpoint_fn
(),
custom_callbacks
=
callbacks
,
save_config
=
True
)
save_config
=
True
)
elif
FLAGS
.
mode
==
'eval'
:
elif
FLAGS
.
mode
==
'eval'
:
...
@@ -124,9 +128,7 @@ def run_executor(params, train_input_fn=None, eval_input_fn=None):
...
@@ -124,9 +128,7 @@ def run_executor(params, train_input_fn=None, eval_input_fn=None):
raise
ValueError
(
'Mode not found: %s.'
%
FLAGS
.
mode
)
raise
ValueError
(
'Mode not found: %s.'
%
FLAGS
.
mode
)
def
main
(
argv
):
def
run
(
callbacks
=
None
):
del
argv
# Unused.
params
=
config_factory
.
config_generator
(
FLAGS
.
model
)
params
=
config_factory
.
config_generator
(
FLAGS
.
model
)
params
=
params_dict
.
override_params_dict
(
params
=
params_dict
.
override_params_dict
(
...
@@ -171,7 +173,16 @@ def main(argv):
...
@@ -171,7 +173,16 @@ def main(argv):
batch_size
=
params
.
eval
.
batch_size
,
batch_size
=
params
.
eval
.
batch_size
,
num_examples
=
params
.
eval
.
eval_samples
)
num_examples
=
params
.
eval
.
eval_samples
)
return
run_executor
(
return
run_executor
(
params
,
train_input_fn
=
train_input_fn
,
eval_input_fn
=
eval_input_fn
)
params
,
train_input_fn
=
train_input_fn
,
eval_input_fn
=
eval_input_fn
,
callbacks
=
callbacks
)
def
main
(
argv
):
del
argv
# Unused.
return
run
()
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment