Commit 6e8f1284 authored by A. Unique TensorFlower
Browse files

add TPUStrategy to the support list of BackupAndRestore callback.

PiperOrigin-RevId: 332042245
parent 5dcfd2c5
......@@ -165,12 +165,15 @@ class CallbacksConfig(base_config.Config):
Attributes:
enable_checkpoint_and_export: Whether or not to enable checkpoints as a
Callback. Defaults to True.
enable_backup_and_restore: Whether or not to add BackupAndRestore
callback. Defaults to False.
enable_tensorboard: Whether or not to enable Tensorboard as a Callback.
Defaults to True.
enable_time_history: Whether or not to enable TimeHistory Callbacks.
Defaults to True.
"""
enable_checkpoint_and_export: bool = True
enable_backup_and_restore: bool = False
enable_tensorboard: bool = True
enable_time_history: bool = True
......
......@@ -29,16 +29,18 @@ from official.modeling import optimization
from official.utils.misc import keras_utils
def get_callbacks(model_checkpoint: bool = True,
include_tensorboard: bool = True,
time_history: bool = True,
track_lr: bool = True,
write_model_weights: bool = True,
apply_moving_average: bool = False,
initial_step: int = 0,
batch_size: int = 0,
log_steps: int = 0,
model_dir: str = None) -> List[tf.keras.callbacks.Callback]:
def get_callbacks(
model_checkpoint: bool = True,
include_tensorboard: bool = True,
time_history: bool = True,
track_lr: bool = True,
write_model_weights: bool = True,
apply_moving_average: bool = False,
initial_step: int = 0,
batch_size: int = 0,
log_steps: int = 0,
model_dir: str = None,
backup_and_restore: bool = False) -> List[tf.keras.callbacks.Callback]:
"""Get all callbacks."""
model_dir = model_dir or ''
callbacks = []
......@@ -47,6 +49,10 @@ def get_callbacks(model_checkpoint: bool = True,
callbacks.append(
tf.keras.callbacks.ModelCheckpoint(
ckpt_full_path, save_weights_only=True, verbose=1))
if backup_and_restore:
backup_dir = os.path.join(model_dir, 'tmp')
callbacks.append(
tf.keras.callbacks.experimental.BackupAndRestore(backup_dir))
if include_tensorboard:
callbacks.append(
CustomTensorBoard(
......
......@@ -368,7 +368,8 @@ def train_and_eval(
initial_step=initial_epoch * train_steps,
batch_size=train_builder.global_batch_size,
log_steps=params.train.time_history.log_steps,
model_dir=params.model_dir)
model_dir=params.model_dir,
backup_and_restore=params.train.callbacks.enable_backup_and_restore)
serialize_config(params=params, model_dir=params.model_dir)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment