Commit b55c9da0 authored by Hongkun Yu, committed by A. Unique TensorFlower

Fix lr in callback

PiperOrigin-RevId: 304250237
parent f276d472
@@ -42,19 +42,22 @@ def get_callbacks(model_checkpoint: bool = True,
   callbacks = []
   if model_checkpoint:
     ckpt_full_path = os.path.join(model_dir, 'model.ckpt-{epoch:04d}')
-    callbacks.append(tf.keras.callbacks.ModelCheckpoint(
-        ckpt_full_path, save_weights_only=True, verbose=1))
+    callbacks.append(
+        tf.keras.callbacks.ModelCheckpoint(
+            ckpt_full_path, save_weights_only=True, verbose=1))
   if include_tensorboard:
-    callbacks.append(CustomTensorBoard(
-        log_dir=model_dir,
-        track_lr=track_lr,
-        initial_step=initial_step,
-        write_images=write_model_weights))
+    callbacks.append(
+        CustomTensorBoard(
+            log_dir=model_dir,
+            track_lr=track_lr,
+            initial_step=initial_step,
+            write_images=write_model_weights))
   if time_history:
-    callbacks.append(keras_utils.TimeHistory(
-        batch_size,
-        log_steps,
-        logdir=model_dir if include_tensorboard else None))
+    callbacks.append(
+        keras_utils.TimeHistory(
+            batch_size,
+            log_steps,
+            logdir=model_dir if include_tensorboard else None))
   return callbacks
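
For context, here is a hedged sketch of how `get_callbacks` might be wired into training. The import path, model, dataset, and all flag values are assumptions for illustration (the diff only shows the function body, and the parameter names are inferred from it), not part of this commit:

# Hypothetical usage sketch; the import path and all argument values are
# placeholders, inferred from the parameter names visible in the diff.
import tensorflow as tf
from official.vision.image_classification import callbacks as callbacks_lib

model = tf.keras.Sequential([tf.keras.layers.Dense(10, activation='softmax')])
model.compile(optimizer='sgd', loss='sparse_categorical_crossentropy')

train_callbacks = callbacks_lib.get_callbacks(
    model_checkpoint=True,
    include_tensorboard=True,
    time_history=True,
    track_lr=True,
    write_model_weights=False,
    initial_step=0,
    batch_size=128,
    log_steps=100,
    model_dir='/tmp/model_dir')

# Toy data so the sketch is self-contained.
features = tf.random.normal([256, 32])
labels = tf.random.uniform([256], maxval=10, dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(128)
model.fit(dataset, epochs=2, callbacks=train_callbacks)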
@@ -74,13 +77,14 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
   - Global learning rate

   Attributes:
-    log_dir: the path of the directory where to save the log files to be
-      parsed by TensorBoard.
+    log_dir: the path of the directory where to save the log files to be parsed
+      by TensorBoard.
     track_lr: `bool`, whether or not to track the global learning rate.
     initial_step: the initial step, used for preemption recovery.
-    **kwargs: Additional arguments for backwards compatibility. Possible key
-      is `period`.
+    **kwargs: Additional arguments for backwards compatibility. Possible key is
+      `period`.
   """

   # TODO(b/146499062): track params, flops, log lr, l2 loss,
   # classification loss
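
For the docstring above, a hypothetical instantiation matching the documented attributes; the values are placeholders, and `period` is the legacy key the docstring says is accepted through `**kwargs`:

# Hypothetical example; all argument values are placeholders.
tensorboard_callback = CustomTensorBoard(
    log_dir='/tmp/model_dir',  # where the TensorBoard log files are written
    track_lr=True,             # also log the global learning rate
    initial_step=0,            # starting step, e.g. after preemption recovery
    period=1)                  # legacy kwarg kept for backwards compatibility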
@@ -130,10 +134,8 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):

   def _calculate_lr(self) -> int:
     """Calculates the learning rate given the current step."""
-    lr = self._get_base_optimizer().lr
-    if callable(lr):
-      lr = lr(self.step)
-    return get_scalar_from_tensor(lr)
+    return get_scalar_from_tensor(
+        self._get_base_optimizer()._decayed_lr(var_dtype=tf.float32))

   def _get_base_optimizer(self) -> tf.keras.optimizers.Optimizer:
     """Get the base optimizer used by the current model."""