Commit b55c9da0 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Fix lr in callback

PiperOrigin-RevId: 304250237
parent f276d472
...@@ -42,19 +42,22 @@ def get_callbacks(model_checkpoint: bool = True, ...@@ -42,19 +42,22 @@ def get_callbacks(model_checkpoint: bool = True,
callbacks = [] callbacks = []
if model_checkpoint: if model_checkpoint:
ckpt_full_path = os.path.join(model_dir, 'model.ckpt-{epoch:04d}') ckpt_full_path = os.path.join(model_dir, 'model.ckpt-{epoch:04d}')
callbacks.append(tf.keras.callbacks.ModelCheckpoint( callbacks.append(
ckpt_full_path, save_weights_only=True, verbose=1)) tf.keras.callbacks.ModelCheckpoint(
ckpt_full_path, save_weights_only=True, verbose=1))
if include_tensorboard: if include_tensorboard:
callbacks.append(CustomTensorBoard( callbacks.append(
log_dir=model_dir, CustomTensorBoard(
track_lr=track_lr, log_dir=model_dir,
initial_step=initial_step, track_lr=track_lr,
write_images=write_model_weights)) initial_step=initial_step,
write_images=write_model_weights))
if time_history: if time_history:
callbacks.append(keras_utils.TimeHistory( callbacks.append(
batch_size, keras_utils.TimeHistory(
log_steps, batch_size,
logdir=model_dir if include_tensorboard else None)) log_steps,
logdir=model_dir if include_tensorboard else None))
return callbacks return callbacks
...@@ -74,13 +77,14 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard): ...@@ -74,13 +77,14 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
- Global learning rate - Global learning rate
Attributes: Attributes:
log_dir: the path of the directory where to save the log files to be log_dir: the path of the directory where to save the log files to be parsed
parsed by TensorBoard. by TensorBoard.
track_lr: `bool`, whether or not to track the global learning rate. track_lr: `bool`, whether or not to track the global learning rate.
initial_step: the initial step, used for preemption recovery. initial_step: the initial step, used for preemption recovery.
**kwargs: Additional arguments for backwards compatibility. Possible key **kwargs: Additional arguments for backwards compatibility. Possible key is
is `period`. `period`.
""" """
# TODO(b/146499062): track params, flops, log lr, l2 loss, # TODO(b/146499062): track params, flops, log lr, l2 loss,
# classification loss # classification loss
...@@ -130,10 +134,8 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard): ...@@ -130,10 +134,8 @@ class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
def _calculate_lr(self) -> int: def _calculate_lr(self) -> int:
"""Calculates the learning rate given the current step.""" """Calculates the learning rate given the current step."""
lr = self._get_base_optimizer().lr return get_scalar_from_tensor(
if callable(lr): self._get_base_optimizer()._decayed_lr(var_dtype=tf.float32))
lr = lr(self.step)
return get_scalar_from_tensor(lr)
def _get_base_optimizer(self) -> tf.keras.optimizers.Optimizer: def _get_base_optimizer(self) -> tf.keras.optimizers.Optimizer:
"""Get the base optimizer used by the current model.""" """Get the base optimizer used by the current model."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment