Commit 6e8f1284 authored by A. Unique TensorFlower
Browse files

Add TPUStrategy to the list of strategies supported by the BackupAndRestore callback.

PiperOrigin-RevId: 332042245
parent 5dcfd2c5
...@@ -165,12 +165,15 @@ class CallbacksConfig(base_config.Config): ...@@ -165,12 +165,15 @@ class CallbacksConfig(base_config.Config):
Attributes: Attributes:
enable_checkpoint_and_export: Whether or not to enable checkpoints as a enable_checkpoint_and_export: Whether or not to enable checkpoints as a
Callback. Defaults to True. Callback. Defaults to True.
enable_backup_and_restore: Whether or not to add BackupAndRestore
callback. Defaults to False.
enable_tensorboard: Whether or not to enable Tensorboard as a Callback. enable_tensorboard: Whether or not to enable Tensorboard as a Callback.
Defaults to True. Defaults to True.
enable_time_history: Whether or not to enable TimeHistory Callbacks. enable_time_history: Whether or not to enable TimeHistory Callbacks.
Defaults to True. Defaults to True.
""" """
enable_checkpoint_and_export: bool = True enable_checkpoint_and_export: bool = True
enable_backup_and_restore: bool = False
enable_tensorboard: bool = True enable_tensorboard: bool = True
enable_time_history: bool = True enable_time_history: bool = True
......
...@@ -29,7 +29,8 @@ from official.modeling import optimization ...@@ -29,7 +29,8 @@ from official.modeling import optimization
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
def get_callbacks(model_checkpoint: bool = True, def get_callbacks(
model_checkpoint: bool = True,
include_tensorboard: bool = True, include_tensorboard: bool = True,
time_history: bool = True, time_history: bool = True,
track_lr: bool = True, track_lr: bool = True,
...@@ -38,7 +39,8 @@ def get_callbacks(model_checkpoint: bool = True, ...@@ -38,7 +39,8 @@ def get_callbacks(model_checkpoint: bool = True,
initial_step: int = 0, initial_step: int = 0,
batch_size: int = 0, batch_size: int = 0,
log_steps: int = 0, log_steps: int = 0,
model_dir: str = None) -> List[tf.keras.callbacks.Callback]: model_dir: str = None,
backup_and_restore: bool = False) -> List[tf.keras.callbacks.Callback]:
"""Get all callbacks.""" """Get all callbacks."""
model_dir = model_dir or '' model_dir = model_dir or ''
callbacks = [] callbacks = []
...@@ -47,6 +49,10 @@ def get_callbacks(model_checkpoint: bool = True, ...@@ -47,6 +49,10 @@ def get_callbacks(model_checkpoint: bool = True,
callbacks.append( callbacks.append(
tf.keras.callbacks.ModelCheckpoint( tf.keras.callbacks.ModelCheckpoint(
ckpt_full_path, save_weights_only=True, verbose=1)) ckpt_full_path, save_weights_only=True, verbose=1))
if backup_and_restore:
backup_dir = os.path.join(model_dir, 'tmp')
callbacks.append(
tf.keras.callbacks.experimental.BackupAndRestore(backup_dir))
if include_tensorboard: if include_tensorboard:
callbacks.append( callbacks.append(
CustomTensorBoard( CustomTensorBoard(
......
...@@ -368,7 +368,8 @@ def train_and_eval( ...@@ -368,7 +368,8 @@ def train_and_eval(
initial_step=initial_epoch * train_steps, initial_step=initial_epoch * train_steps,
batch_size=train_builder.global_batch_size, batch_size=train_builder.global_batch_size,
log_steps=params.train.time_history.log_steps, log_steps=params.train.time_history.log_steps,
model_dir=params.model_dir) model_dir=params.model_dir,
backup_and_restore=params.train.callbacks.enable_backup_and_restore)
serialize_config(params=params, model_dir=params.model_dir) serialize_config(params=params, model_dir=params.model_dir)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment