Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
3254cabb
Unverified
Commit
3254cabb
authored
May 24, 2019
by
Toby Boyd
Committed by
GitHub
May 24, 2019
Browse files
Moved common keras code to utils. (#6859)
parent
b9cab01b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
53 additions
and
51 deletions
+53
-51
official/resnet/keras/keras_common.py
official/resnet/keras/keras_common.py
+4
-50
official/utils/misc/keras_utils.py
official/utils/misc/keras_utils.py
+49
-1
No files found.
official/resnet/keras/keras_common.py
View file @
3254cabb
...
@@ -30,7 +30,6 @@ import tensorflow as tf
...
@@ -30,7 +30,6 @@ import tensorflow as tf
from
official.utils.misc
import
keras_utils
from
official.utils.misc
import
keras_utils
# pylint: disable=ungrouped-imports
# pylint: disable=ungrouped-imports
from
tensorflow.core.protobuf
import
rewriter_config_pb2
from
tensorflow.core.protobuf
import
rewriter_config_pb2
from
tensorflow.python.eager
import
profiler
from
tensorflow.python.keras.optimizer_v2
import
(
gradient_descent
as
from
tensorflow.python.keras.optimizer_v2
import
(
gradient_descent
as
gradient_descent_v2
)
gradient_descent_v2
)
...
@@ -146,29 +145,6 @@ class PiecewiseConstantDecayWithWarmup(
...
@@ -146,29 +145,6 @@ class PiecewiseConstantDecayWithWarmup(
}
}
class
ProfilerCallback
(
tf
.
keras
.
callbacks
.
Callback
):
"""Save profiles in specified step range to log directory."""
def
__init__
(
self
,
log_dir
,
start_step
,
stop_step
):
super
(
ProfilerCallback
,
self
).
__init__
()
self
.
log_dir
=
log_dir
self
.
start_step
=
start_step
self
.
stop_step
=
stop_step
def
on_batch_begin
(
self
,
batch
,
logs
=
None
):
if
batch
==
self
.
start_step
:
profiler
.
start
()
tf
.
compat
.
v1
.
logging
.
info
(
'Profiler started at Step %s'
,
self
.
start_step
)
def
on_batch_end
(
self
,
batch
,
logs
=
None
):
if
batch
==
self
.
stop_step
:
results
=
profiler
.
stop
()
profiler
.
save
(
self
.
log_dir
,
results
)
tf
.
compat
.
v1
.
logging
.
info
(
'Profiler saved profiles for steps between %s and %s to %s'
,
self
.
start_step
,
self
.
stop_step
,
self
.
log_dir
)
def
get_config_proto_v1
():
def
get_config_proto_v1
():
"""Return config proto according to flag settings, or None to use default."""
"""Return config proto according to flag settings, or None to use default."""
config
=
None
config
=
None
...
@@ -250,37 +226,15 @@ def get_callbacks(learning_rate_schedule_fn, num_images):
...
@@ -250,37 +226,15 @@ def get_callbacks(learning_rate_schedule_fn, num_images):
callbacks
.
append
(
tensorboard_callback
)
callbacks
.
append
(
tensorboard_callback
)
if
FLAGS
.
profile_steps
:
if
FLAGS
.
profile_steps
:
profiler_callback
=
get_profiler_callback
()
profiler_callback
=
keras_utils
.
get_profiler_callback
(
FLAGS
.
model_dir
,
FLAGS
.
profile_steps
,
FLAGS
.
enable_tensorboard
)
callbacks
.
append
(
profiler_callback
)
callbacks
.
append
(
profiler_callback
)
return
callbacks
return
callbacks
def
get_profiler_callback
():
"""Validate profile_steps flag value and return profiler callback."""
profile_steps_error_message
=
(
'profile_steps must be a comma separated pair of positive integers, '
'specifying the first and last steps to be profiled.'
)
try
:
profile_steps
=
[
int
(
i
)
for
i
in
FLAGS
.
profile_steps
.
split
(
','
)]
except
ValueError
:
raise
ValueError
(
profile_steps_error_message
)
if
len
(
profile_steps
)
!=
2
:
raise
ValueError
(
profile_steps_error_message
)
start_step
,
stop_step
=
profile_steps
if
start_step
<
0
or
start_step
>
stop_step
:
raise
ValueError
(
profile_steps_error_message
)
if
FLAGS
.
enable_tensorboard
:
tf
.
compat
.
v1
.
logging
.
warn
(
'Both TensorBoard and profiler callbacks are used. Note that the '
'TensorBoard callback profiles the 2nd step (unless otherwise '
'specified). Please make sure the steps profiled by the two callbacks '
'do not overlap.'
)
return
ProfilerCallback
(
FLAGS
.
model_dir
,
start_step
,
stop_step
)
def
build_stats
(
history
,
eval_output
,
callbacks
):
def
build_stats
(
history
,
eval_output
,
callbacks
):
"""Normalizes and returns dictionary of stats.
"""Normalizes and returns dictionary of stats.
...
...
official/utils/misc/keras_utils.py
View file @
3254cabb
...
@@ -20,8 +20,8 @@ from __future__ import print_function
...
@@ -20,8 +20,8 @@ from __future__ import print_function
import
time
import
time
from
absl
import
flags
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorflow.python.eager
import
profiler
class
BatchTimestamp
(
object
):
class
BatchTimestamp
(
object
):
...
@@ -80,3 +80,51 @@ class TimeHistory(tf.keras.callbacks.Callback):
...
@@ -80,3 +80,51 @@ class TimeHistory(tf.keras.callbacks.Callback):
"BenchmarkMetric: {'num_batches':%d, 'time_taken': %f,"
"BenchmarkMetric: {'num_batches':%d, 'time_taken': %f,"
"'examples_per_second': %f}"
%
"'examples_per_second': %f}"
%
(
batch
,
elapsed_time
,
examples_per_second
))
(
batch
,
elapsed_time
,
examples_per_second
))
def
get_profiler_callback
(
model_dir
,
profile_steps
,
enable_tensorboard
):
"""Validate profile_steps flag value and return profiler callback."""
profile_steps_error_message
=
(
'profile_steps must be a comma separated pair of positive integers, '
'specifying the first and last steps to be profiled.'
)
try
:
profile_steps
=
[
int
(
i
)
for
i
in
profile_steps
.
split
(
','
)]
except
ValueError
:
raise
ValueError
(
profile_steps_error_message
)
if
len
(
profile_steps
)
!=
2
:
raise
ValueError
(
profile_steps_error_message
)
start_step
,
stop_step
=
profile_steps
if
start_step
<
0
or
start_step
>
stop_step
:
raise
ValueError
(
profile_steps_error_message
)
if
enable_tensorboard
:
tf
.
compat
.
v1
.
logging
.
warn
(
'Both TensorBoard and profiler callbacks are used. Note that the '
'TensorBoard callback profiles the 2nd step (unless otherwise '
'specified). Please make sure the steps profiled by the two callbacks '
'do not overlap.'
)
return
ProfilerCallback
(
model_dir
,
start_step
,
stop_step
)
class
ProfilerCallback
(
tf
.
keras
.
callbacks
.
Callback
):
"""Save profiles in specified step range to log directory."""
def
__init__
(
self
,
log_dir
,
start_step
,
stop_step
):
super
(
ProfilerCallback
,
self
).
__init__
()
self
.
log_dir
=
log_dir
self
.
start_step
=
start_step
self
.
stop_step
=
stop_step
def
on_batch_begin
(
self
,
batch
,
logs
=
None
):
if
batch
==
self
.
start_step
:
profiler
.
start
()
tf
.
compat
.
v1
.
logging
.
info
(
'Profiler started at Step %s'
,
self
.
start_step
)
def
on_batch_end
(
self
,
batch
,
logs
=
None
):
if
batch
==
self
.
stop_step
:
results
=
profiler
.
stop
()
profiler
.
save
(
self
.
log_dir
,
results
)
tf
.
compat
.
v1
.
logging
.
info
(
'Profiler saved profiles for steps between %s and %s to %s'
,
self
.
start_step
,
self
.
stop_step
,
self
.
log_dir
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment