Commit 6da061c0 authored by Abdullah Rashwan, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 316797555
parent e4f04456
@@ -162,6 +162,21 @@ class CallbacksConfig(base_config.Config):
 @dataclasses.dataclass
 class TrainerConfig(base_config.Config):
"""Configuration for trainer.
Attributes:
optimizer_config: optimizer config, it includes optimizer, learning rate,
and warmup schedule configs.
train_tf_while_loop: whether or not to use tf while loop.
train_tf_function: whether or not to use tf_function for training loop.
eval_tf_function: whether or not to use tf_function for eval.
steps_per_loop: number of steps per loop.
summary_interval: number of steps between each summary.
checkpoint_intervals: number of steps between checkpoints.
max_to_keep: max checkpoints to keep.
continuous_eval_timeout: maximum number of seconds to wait between
checkpoints, if set to None, continuous eval will wait indefinetely.
"""
   optimizer_config: OptimizationConfig = OptimizationConfig()
   train_tf_while_loop: bool = True
   train_tf_function: bool = True
@@ -170,6 +185,7 @@ class TrainerConfig(base_config.Config):
   summary_interval: int = 1000
   checkpoint_interval: int = 1000
   max_to_keep: int = 5
+  continuous_eval_timeout: Optional[int] = None
 @dataclasses.dataclass
...
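As a usage note (not part of the commit): since `TrainerConfig` is a dataclass, the new field can be set like any other attribute. A minimal sketch, assuming the config classes above are importable; the override values are made up.

```python
# Sketch only: TrainerConfig and its field names come from the diff above;
# the concrete values here are illustrative.
trainer_config = TrainerConfig(
    steps_per_loop=100,
    checkpoint_interval=1000,
    max_to_keep=5,
    # New in this commit: stop waiting for a fresh checkpoint after an hour
    # of continuous eval; the default (None) waits indefinitely.
    continuous_eval_timeout=3600,
)
```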
@@ -218,7 +218,7 @@ def get_callbacks():
   time_callback = keras_utils.TimeHistory(
       FLAGS.batch_size,
       FLAGS.log_steps,
-      FLAGS.model_dir if FLAGS.enable_tensorboard else None)
+      logdir=FLAGS.model_dir if FLAGS.enable_tensorboard else None)
   callbacks.append(time_callback)
   if FLAGS.enable_tensorboard:
...
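The call-site change above guards against the signature change in the next file: `initial_step` is inserted before `logdir`, so a third positional argument would now bind to `initial_step` instead of `logdir`. A hypothetical example with made-up values:

```python
# With the new signature
#   TimeHistory(batch_size, log_steps, initial_step=0, logdir=None)
# a third positional argument no longer reaches logdir:
TimeHistory(256, 100, "/tmp/logs")         # wrong: initial_step="/tmp/logs"
TimeHistory(256, 100, logdir="/tmp/logs")  # correct: logdir passed by name
```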
@@ -41,12 +41,13 @@ class BatchTimestamp(object):
 class TimeHistory(tf.keras.callbacks.Callback):
   """Callback for Keras models."""

-  def __init__(self, batch_size, log_steps, logdir=None):
+  def __init__(self, batch_size, log_steps, initial_step=0, logdir=None):
     """Callback for logging performance.

     Args:
       batch_size: Total batch size.
       log_steps: Interval of steps between logging of batch level stats.
+      initial_step: Optional, initial step.
       logdir: Optional directory to write TensorBoard summaries.
     """
     # TODO(wcromar): remove this parameter and rely on `logs` parameter of
@@ -54,8 +55,8 @@ class TimeHistory(tf.keras.callbacks.Callback):
     self.batch_size = batch_size
     super(TimeHistory, self).__init__()
     self.log_steps = log_steps
-    self.last_log_step = 0
-    self.steps_before_epoch = 0
+    self.last_log_step = initial_step
+    self.steps_before_epoch = initial_step
     self.steps_in_epoch = 0
     self.start_time = None
...
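Seeding `last_log_step` and `steps_before_epoch` with `initial_step` lets the callback report throughput relative to where a resumed run actually starts. A minimal sketch of wiring it up when restoring from a checkpoint; the `ckpt-<step>` name parsing is an assumption for illustration, not part of this commit:

```python
import tensorflow as tf

# Assumed helper logic: recover the global step from the checkpoint name.
initial_step = 0
latest = tf.train.latest_checkpoint(FLAGS.model_dir)
if latest:
  initial_step = int(latest.split("-")[-1])  # e.g. ".../ckpt-12000" -> 12000

time_callback = keras_utils.TimeHistory(
    FLAGS.batch_size,
    FLAGS.log_steps,
    initial_step=initial_step,
    logdir=FLAGS.model_dir if FLAGS.enable_tensorboard else None)
```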