Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
b60dc237
Commit
b60dc237
authored
Feb 21, 2020
by
Will Cromar
Committed by
A. Unique TensorFlower
Feb 21, 2020
Browse files
Write examples/second and steps/second summaries in TimeHistory callback.
PiperOrigin-RevId: 296507807
parent
706a0bd9
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
87 additions
and
47 deletions
+87
-47
official/staging/training/utils.py
official/staging/training/utils.py
+10
-2
official/utils/misc/keras_utils.py
official/utils/misc/keras_utils.py
+59
-16
official/vision/image_classification/common.py
official/vision/image_classification/common.py
+7
-6
official/vision/image_classification/resnet_ctl_imagenet_main.py
...l/vision/image_classification/resnet_ctl_imagenet_main.py
+6
-11
official/vision/image_classification/resnet_runnable.py
official/vision/image_classification/resnet_runnable.py
+5
-12
No files found.
official/staging/training/utils.py
View file @
b60dc237
...
@@ -298,13 +298,16 @@ class EpochHelper(object):
...
@@ -298,13 +298,16 @@ class EpochHelper(object):
self
.
_epoch_steps
=
epoch_steps
self
.
_epoch_steps
=
epoch_steps
self
.
_global_step
=
global_step
self
.
_global_step
=
global_step
self
.
_current_epoch
=
None
self
.
_current_epoch
=
None
self
.
_epoch_start_step
=
None
self
.
_in_epoch
=
False
self
.
_in_epoch
=
False
def
epoch_begin
(
self
):
def
epoch_begin
(
self
):
"""Returns whether a new epoch should begin."""
"""Returns whether a new epoch should begin."""
if
self
.
_in_epoch
:
if
self
.
_in_epoch
:
return
False
return
False
self
.
_current_epoch
=
self
.
_global_step
.
numpy
()
/
self
.
_epoch_steps
current_step
=
self
.
_global_step
.
numpy
()
self
.
_epoch_start_step
=
current_step
self
.
_current_epoch
=
current_step
//
self
.
_epoch_steps
self
.
_in_epoch
=
True
self
.
_in_epoch
=
True
return
True
return
True
...
@@ -313,13 +316,18 @@ class EpochHelper(object):
...
@@ -313,13 +316,18 @@ class EpochHelper(object):
if
not
self
.
_in_epoch
:
if
not
self
.
_in_epoch
:
raise
ValueError
(
"`epoch_end` can only be called inside an epoch"
)
raise
ValueError
(
"`epoch_end` can only be called inside an epoch"
)
current_step
=
self
.
_global_step
.
numpy
()
current_step
=
self
.
_global_step
.
numpy
()
epoch
=
current_step
/
self
.
_epoch_steps
epoch
=
current_step
/
/
self
.
_epoch_steps
if
epoch
>
self
.
_current_epoch
:
if
epoch
>
self
.
_current_epoch
:
self
.
_in_epoch
=
False
self
.
_in_epoch
=
False
return
True
return
True
return
False
return
False
@
property
def
batch_index
(
self
):
"""Index of the next batch within the current epoch."""
return
self
.
_global_step
.
numpy
()
-
self
.
_epoch_start_step
@
property
@
property
def
current_epoch
(
self
):
def
current_epoch
(
self
):
return
self
.
_current_epoch
return
self
.
_current_epoch
official/utils/misc/keras_utils.py
View file @
b60dc237
...
@@ -44,17 +44,28 @@ class BatchTimestamp(object):
...
@@ -44,17 +44,28 @@ class BatchTimestamp(object):
class
TimeHistory
(
tf
.
keras
.
callbacks
.
Callback
):
class
TimeHistory
(
tf
.
keras
.
callbacks
.
Callback
):
"""Callback for Keras models."""
"""Callback for Keras models."""
def
__init__
(
self
,
batch_size
,
log_steps
):
def
__init__
(
self
,
batch_size
,
log_steps
,
logdir
=
None
):
"""Callback for logging performance.
"""Callback for logging performance.
Args:
Args:
batch_size: Total batch size.
batch_size: Total batch size.
log_steps: Interval of steps between logging of batch level stats.
log_steps: Interval of steps between logging of batch level stats.
logdir: Optional directory to write TensorBoard summaries.
"""
"""
# TODO(wcromar): remove this parameter and rely on `logs` parameter of
# on_train_batch_end()
self
.
batch_size
=
batch_size
self
.
batch_size
=
batch_size
super
(
TimeHistory
,
self
).
__init__
()
super
(
TimeHistory
,
self
).
__init__
()
self
.
log_steps
=
log_steps
self
.
log_steps
=
log_steps
self
.
global_steps
=
0
self
.
last_log_step
=
0
self
.
steps_before_epoch
=
0
self
.
steps_in_epoch
=
0
self
.
start_time
=
None
if
logdir
:
self
.
summary_writer
=
tf
.
summary
.
create_file_writer
(
logdir
)
else
:
self
.
summary_writer
=
None
# Logs start of step 1 then end of each step based on log_steps interval.
# Logs start of step 1 then end of each step based on log_steps interval.
self
.
timestamp_log
=
[]
self
.
timestamp_log
=
[]
...
@@ -62,38 +73,70 @@ class TimeHistory(tf.keras.callbacks.Callback):
...
@@ -62,38 +73,70 @@ class TimeHistory(tf.keras.callbacks.Callback):
# Records the time each epoch takes to run from start to finish of epoch.
# Records the time each epoch takes to run from start to finish of epoch.
self
.
epoch_runtime_log
=
[]
self
.
epoch_runtime_log
=
[]
@
property
def
global_steps
(
self
):
"""The current 1-indexed global step."""
return
self
.
steps_before_epoch
+
self
.
steps_in_epoch
@
property
def
average_steps_per_second
(
self
):
"""The average training steps per second across all epochs."""
return
self
.
global_steps
/
sum
(
self
.
epoch_runtime_log
)
@
property
def
average_examples_per_second
(
self
):
"""The average number of training examples per second across all epochs."""
return
self
.
average_steps_per_second
*
self
.
batch_size
def
on_train_end
(
self
,
logs
=
None
):
def
on_train_end
(
self
,
logs
=
None
):
self
.
train_finish_time
=
time
.
time
()
self
.
train_finish_time
=
time
.
time
()
if
self
.
summary_writer
:
self
.
summary_writer
.
flush
()
def
on_epoch_begin
(
self
,
epoch
,
logs
=
None
):
def
on_epoch_begin
(
self
,
epoch
,
logs
=
None
):
self
.
epoch_start
=
time
.
time
()
self
.
epoch_start
=
time
.
time
()
def
on_batch_begin
(
self
,
batch
,
logs
=
None
):
def
on_batch_begin
(
self
,
batch
,
logs
=
None
):
self
.
global_steps
+=
1
if
not
self
.
start_time
:
if
self
.
global_steps
==
1
:
self
.
start_time
=
time
.
time
()
self
.
start_time
=
time
.
time
()
# Record the timestamp of the first global step
if
not
self
.
timestamp_log
:
self
.
timestamp_log
.
append
(
BatchTimestamp
(
self
.
global_steps
,
self
.
timestamp_log
.
append
(
BatchTimestamp
(
self
.
global_steps
,
self
.
start_time
))
self
.
start_time
))
def
on_batch_end
(
self
,
batch
,
logs
=
None
):
def
on_batch_end
(
self
,
batch
,
logs
=
None
):
"""Records elapse time of the batch and calculates examples per second."""
"""Records elapse time of the batch and calculates examples per second."""
if
self
.
global_steps
%
self
.
log_steps
==
0
:
self
.
steps_in_epoch
=
batch
+
1
timestamp
=
time
.
time
()
steps_since_last_log
=
self
.
global_steps
-
self
.
last_log_step
elapsed_time
=
timestamp
-
self
.
start_time
if
steps_since_last_log
>=
self
.
log_steps
:
examples_per_second
=
(
self
.
batch_size
*
self
.
log_steps
)
/
elapsed_time
now
=
time
.
time
()
self
.
timestamp_log
.
append
(
BatchTimestamp
(
self
.
global_steps
,
timestamp
))
elapsed_time
=
now
-
self
.
start_time
steps_per_second
=
steps_since_last_log
/
elapsed_time
examples_per_second
=
steps_per_second
*
self
.
batch_size
self
.
timestamp_log
.
append
(
BatchTimestamp
(
self
.
global_steps
,
now
))
logging
.
info
(
logging
.
info
(
"BenchmarkMetric: {'global step':%d, 'time_taken': %f,"
"TimeHistory: %.2f examples/second between steps %d and %d"
,
"'examples_per_second': %f}"
,
examples_per_second
,
self
.
last_log_step
,
self
.
global_steps
)
self
.
global_steps
,
elapsed_time
,
examples_per_second
)
self
.
start_time
=
timestamp
if
self
.
summary_writer
:
with
self
.
summary_writer
.
as_default
():
tf
.
summary
.
scalar
(
'global_step/sec'
,
steps_per_second
,
self
.
global_steps
)
tf
.
summary
.
scalar
(
'examples/sec'
,
examples_per_second
,
self
.
global_steps
)
self
.
last_log_step
=
self
.
global_steps
self
.
start_time
=
None
def
on_epoch_end
(
self
,
epoch
,
logs
=
None
):
def
on_epoch_end
(
self
,
epoch
,
logs
=
None
):
epoch_run_time
=
time
.
time
()
-
self
.
epoch_start
epoch_run_time
=
time
.
time
()
-
self
.
epoch_start
self
.
epoch_runtime_log
.
append
(
epoch_run_time
)
self
.
epoch_runtime_log
.
append
(
epoch_run_time
)
logging
.
info
(
"BenchmarkMetric: {'epoch':%d, 'time_taken': %f}"
,
self
.
steps_before_epoch
+=
self
.
steps_in_epoch
epoch
,
epoch_run_time
)
self
.
steps_in_epoch
=
0
def
get_profiler_callback
(
model_dir
,
profile_steps
,
enable_tensorboard
,
def
get_profiler_callback
(
model_dir
,
profile_steps
,
enable_tensorboard
,
...
...
official/vision/image_classification/common.py
View file @
b60dc237
...
@@ -188,7 +188,10 @@ def get_callbacks(
...
@@ -188,7 +188,10 @@ def get_callbacks(
enable_checkpoint_and_export
=
False
,
enable_checkpoint_and_export
=
False
,
model_dir
=
None
):
model_dir
=
None
):
"""Returns common callbacks."""
"""Returns common callbacks."""
time_callback
=
keras_utils
.
TimeHistory
(
FLAGS
.
batch_size
,
FLAGS
.
log_steps
)
time_callback
=
keras_utils
.
TimeHistory
(
FLAGS
.
batch_size
,
FLAGS
.
log_steps
,
logdir
=
FLAGS
.
model_dir
if
FLAGS
.
enable_tensorboard
else
None
)
callbacks
=
[
time_callback
]
callbacks
=
[
time_callback
]
if
not
FLAGS
.
use_tensor_lr
and
learning_rate_schedule_fn
:
if
not
FLAGS
.
use_tensor_lr
and
learning_rate_schedule_fn
:
...
@@ -265,11 +268,9 @@ def build_stats(history, eval_output, callbacks):
...
@@ -265,11 +268,9 @@ def build_stats(history, eval_output, callbacks):
timestamp_log
=
callback
.
timestamp_log
timestamp_log
=
callback
.
timestamp_log
stats
[
'step_timestamp_log'
]
=
timestamp_log
stats
[
'step_timestamp_log'
]
=
timestamp_log
stats
[
'train_finish_time'
]
=
callback
.
train_finish_time
stats
[
'train_finish_time'
]
=
callback
.
train_finish_time
if
len
(
timestamp_log
)
>
1
:
if
callback
.
epoch_runtime_log
:
stats
[
'avg_exp_per_second'
]
=
(
stats
[
'avg_exp_per_second'
]
=
callback
.
average_examples_per_second
callback
.
batch_size
*
callback
.
log_steps
*
(
len
(
callback
.
timestamp_log
)
-
1
)
/
(
timestamp_log
[
-
1
].
timestamp
-
timestamp_log
[
0
].
timestamp
))
return
stats
return
stats
...
...
official/vision/image_classification/resnet_ctl_imagenet_main.py
View file @
b60dc237
...
@@ -64,15 +64,8 @@ def build_stats(runnable, time_callback):
...
@@ -64,15 +64,8 @@ def build_stats(runnable, time_callback):
timestamp_log
=
time_callback
.
timestamp_log
timestamp_log
=
time_callback
.
timestamp_log
stats
[
'step_timestamp_log'
]
=
timestamp_log
stats
[
'step_timestamp_log'
]
=
timestamp_log
stats
[
'train_finish_time'
]
=
time_callback
.
train_finish_time
stats
[
'train_finish_time'
]
=
time_callback
.
train_finish_time
if
len
(
timestamp_log
)
>
1
:
if
time_callback
.
epoch_runtime_log
:
stats
[
'avg_exp_per_second'
]
=
(
stats
[
'avg_exp_per_second'
]
=
time_callback
.
average_examples_per_second
time_callback
.
batch_size
*
time_callback
.
log_steps
*
(
len
(
time_callback
.
timestamp_log
)
-
1
)
/
(
timestamp_log
[
-
1
].
timestamp
-
timestamp_log
[
0
].
timestamp
))
avg_exp_per_second
=
tf
.
reduce_mean
(
runnable
.
examples_per_second_history
).
numpy
(),
stats
[
'avg_exp_per_second'
]
=
avg_exp_per_second
return
stats
return
stats
...
@@ -154,8 +147,10 @@ def run(flags_obj):
...
@@ -154,8 +147,10 @@ def run(flags_obj):
'total steps: %d; Eval %d steps'
,
train_epochs
,
per_epoch_steps
,
'total steps: %d; Eval %d steps'
,
train_epochs
,
per_epoch_steps
,
train_epochs
*
per_epoch_steps
,
eval_steps
)
train_epochs
*
per_epoch_steps
,
eval_steps
)
time_callback
=
keras_utils
.
TimeHistory
(
flags_obj
.
batch_size
,
time_callback
=
keras_utils
.
TimeHistory
(
flags_obj
.
log_steps
)
flags_obj
.
batch_size
,
flags_obj
.
log_steps
,
logdir
=
flags_obj
.
model_dir
if
flags_obj
.
enable_tensorboard
else
None
)
with
distribution_utils
.
get_strategy_scope
(
strategy
):
with
distribution_utils
.
get_strategy_scope
(
strategy
):
runnable
=
resnet_runnable
.
ResnetRunnable
(
flags_obj
,
time_callback
,
runnable
=
resnet_runnable
.
ResnetRunnable
(
flags_obj
,
time_callback
,
per_epoch_steps
)
per_epoch_steps
)
...
...
official/vision/image_classification/resnet_runnable.py
View file @
b60dc237
...
@@ -114,7 +114,6 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
...
@@ -114,7 +114,6 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
# Handling epochs.
# Handling epochs.
self
.
epoch_steps
=
epoch_steps
self
.
epoch_steps
=
epoch_steps
self
.
epoch_helper
=
utils
.
EpochHelper
(
epoch_steps
,
self
.
global_step
)
self
.
epoch_helper
=
utils
.
EpochHelper
(
epoch_steps
,
self
.
global_step
)
self
.
examples_per_second_history
=
[]
def
build_train_dataset
(
self
):
def
build_train_dataset
(
self
):
"""See base class."""
"""See base class."""
...
@@ -147,8 +146,8 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
...
@@ -147,8 +146,8 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
self
.
train_loss
.
reset_states
()
self
.
train_loss
.
reset_states
()
self
.
train_accuracy
.
reset_states
()
self
.
train_accuracy
.
reset_states
()
self
.
time_callback
.
on_batch_begin
(
self
.
global_step
)
self
.
_epoch_begin
()
self
.
_epoch_begin
()
self
.
time_callback
.
on_batch_begin
(
self
.
epoch_helper
.
batch_index
)
def
train_step
(
self
,
iterator
):
def
train_step
(
self
,
iterator
):
"""See base class."""
"""See base class."""
...
@@ -194,12 +193,13 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
...
@@ -194,12 +193,13 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
def
train_loop_end
(
self
):
def
train_loop_end
(
self
):
"""See base class."""
"""See base class."""
self
.
time_callback
.
on_batch_end
(
self
.
global_step
)
metrics
=
{
self
.
_epoch_end
()
return
{
'train_loss'
:
self
.
train_loss
.
result
(),
'train_loss'
:
self
.
train_loss
.
result
(),
'train_accuracy'
:
self
.
train_accuracy
.
result
(),
'train_accuracy'
:
self
.
train_accuracy
.
result
(),
}
}
self
.
time_callback
.
on_batch_end
(
self
.
epoch_helper
.
batch_index
-
1
)
self
.
_epoch_end
()
return
metrics
def
eval_begin
(
self
):
def
eval_begin
(
self
):
"""See base class."""
"""See base class."""
...
@@ -234,10 +234,3 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
...
@@ -234,10 +234,3 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
def
_epoch_end
(
self
):
def
_epoch_end
(
self
):
if
self
.
epoch_helper
.
epoch_end
():
if
self
.
epoch_helper
.
epoch_end
():
self
.
time_callback
.
on_epoch_end
(
self
.
epoch_helper
.
current_epoch
)
self
.
time_callback
.
on_epoch_end
(
self
.
epoch_helper
.
current_epoch
)
epoch_time
=
self
.
time_callback
.
epoch_runtime_log
[
-
1
]
steps_per_second
=
self
.
epoch_steps
/
epoch_time
examples_per_second
=
steps_per_second
*
self
.
flags_obj
.
batch_size
self
.
examples_per_second_history
.
append
(
examples_per_second
)
tf
.
summary
.
scalar
(
'global_step/sec'
,
steps_per_second
)
tf
.
summary
.
scalar
(
'examples/sec'
,
examples_per_second
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment