Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
b55c9da0
Commit
b55c9da0
authored
Apr 01, 2020
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Apr 01, 2020
Browse files
Fix lr in callback
PiperOrigin-RevId: 304250237
parent
f276d472
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
19 deletions
+21
-19
official/vision/image_classification/callbacks.py
official/vision/image_classification/callbacks.py
+21
-19
No files found.
official/vision/image_classification/callbacks.py
View file @
b55c9da0
...
...
@@ -42,19 +42,22 @@ def get_callbacks(model_checkpoint: bool = True,
callbacks
=
[]
if
model_checkpoint
:
ckpt_full_path
=
os
.
path
.
join
(
model_dir
,
'model.ckpt-{epoch:04d}'
)
callbacks
.
append
(
tf
.
keras
.
callbacks
.
ModelCheckpoint
(
ckpt_full_path
,
save_weights_only
=
True
,
verbose
=
1
))
callbacks
.
append
(
tf
.
keras
.
callbacks
.
ModelCheckpoint
(
ckpt_full_path
,
save_weights_only
=
True
,
verbose
=
1
))
if
include_tensorboard
:
callbacks
.
append
(
CustomTensorBoard
(
log_dir
=
model_dir
,
track_lr
=
track_lr
,
initial_step
=
initial_step
,
write_images
=
write_model_weights
))
callbacks
.
append
(
CustomTensorBoard
(
log_dir
=
model_dir
,
track_lr
=
track_lr
,
initial_step
=
initial_step
,
write_images
=
write_model_weights
))
if
time_history
:
callbacks
.
append
(
keras_utils
.
TimeHistory
(
batch_size
,
log_steps
,
logdir
=
model_dir
if
include_tensorboard
else
None
))
callbacks
.
append
(
keras_utils
.
TimeHistory
(
batch_size
,
log_steps
,
logdir
=
model_dir
if
include_tensorboard
else
None
))
return
callbacks
...
...
@@ -74,13 +77,14 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
- Global learning rate
Attributes:
log_dir: the path of the directory where to save the log files to be
parsed
by TensorBoard.
log_dir: the path of the directory where to save the log files to be
parsed
by TensorBoard.
track_lr: `bool`, whether or not to track the global learning rate.
initial_step: the initial step, used for preemption recovery.
**kwargs: Additional arguments for backwards compatibility. Possible key
is
`period`.
**kwargs: Additional arguments for backwards compatibility. Possible key
is
`period`.
"""
# TODO(b/146499062): track params, flops, log lr, l2 loss,
# classification loss
...
...
@@ -130,10 +134,8 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
def
_calculate_lr
(
self
)
->
int
:
"""Calculates the learning rate given the current step."""
lr
=
self
.
_get_base_optimizer
().
lr
if
callable
(
lr
):
lr
=
lr
(
self
.
step
)
return
get_scalar_from_tensor
(
lr
)
return
get_scalar_from_tensor
(
self
.
_get_base_optimizer
().
_decayed_lr
(
var_dtype
=
tf
.
float32
))
def
_get_base_optimizer
(
self
)
->
tf
.
keras
.
optimizers
.
Optimizer
:
"""Get the base optimizer used by the current model."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment