Commit 7c83a9d7 authored by Will Cromar, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 302921283
parent 9fc4fd08
@@ -239,7 +239,10 @@ def get_callbacks(steps_per_epoch):
   """Returns common callbacks."""
   callbacks = []
   if FLAGS.enable_time_history:
-    time_callback = keras_utils.TimeHistory(FLAGS.batch_size, FLAGS.log_steps)
+    time_callback = keras_utils.TimeHistory(
+        FLAGS.batch_size,
+        FLAGS.log_steps,
+        FLAGS.model_dir if FLAGS.enable_tensorboard else None)
     callbacks.append(time_callback)

   if FLAGS.enable_tensorboard:
...
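The new third argument threads FLAGS.model_dir into TimeHistory only when TensorBoard is enabled, so the throughput numbers the callback already computes can also be written as TensorBoard summaries. The real implementation is keras_utils.TimeHistory in the model garden (official/utils/misc/keras_utils.py); the sketch below is only a minimal illustration of the shape such a callback can take — SimpleTimeHistory and all of its details are assumptions, not the garden code:

import time
import tensorflow as tf

class SimpleTimeHistory(tf.keras.callbacks.Callback):
  """Logs examples/sec every `log_steps` batches, optionally to TensorBoard."""

  def __init__(self, batch_size, log_steps, logdir=None):
    super().__init__()
    self.batch_size = batch_size
    self.log_steps = log_steps
    self.global_steps = 0
    # Mirror the diff: only create a writer when a logdir is passed.
    self.summary_writer = (
        tf.summary.create_file_writer(logdir) if logdir else None)

  def on_batch_begin(self, batch, logs=None):
    if self.global_steps % self.log_steps == 0:
      self.start_time = time.time()

  def on_batch_end(self, batch, logs=None):
    self.global_steps += 1
    if self.global_steps % self.log_steps == 0:
      elapsed = time.time() - self.start_time
      examples_per_sec = self.batch_size * self.log_steps / elapsed
      tf.get_logger().info("step %d: %.1f examples/sec",
                           self.global_steps, examples_per_sec)
      if self.summary_writer:
        # The same mechanism the flag combination above enables.
        with self.summary_writer.as_default():
          tf.summary.scalar("examples_per_sec", examples_per_sec,
                            step=self.global_steps)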
@@ -246,6 +246,11 @@ class TransformerTask(object):
     callbacks = self._create_callbacks(flags_obj.model_dir, 0, params)

+    # Only TimeHistory callback is supported for CTL
+    if params["use_ctl"]:
+      callbacks = [cb for cb in callbacks
+                   if isinstance(cb, keras_utils.TimeHistory)]
+
     # TODO(b/139418525): Refactor the custom training loop logic.
     @tf.function
     def train_steps(iterator, steps):
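Everything except TimeHistory is dropped here because the other callbacks that _create_callbacks can return (for example a TensorBoard callback) expect model.fit to drive their hooks, while the custom training loop (CTL) below invokes TimeHistory's hooks by hand. A small illustration of the filter, assuming the model garden's official package is on PYTHONPATH:

import tensorflow as tf
from official.utils.misc import keras_utils

callbacks = [
    keras_utils.TimeHistory(64, 100),             # kept: CTL drives it by hand
    tf.keras.callbacks.TensorBoard("/tmp/logs"),  # dropped: needs model.fit
]
callbacks = [cb for cb in callbacks
             if isinstance(cb, keras_utils.TimeHistory)]
assert len(callbacks) == 1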
...
@@ -299,8 +304,13 @@ class TransformerTask(object):
         if not self.use_tpu:
           raise NotImplementedError(
               "Custom training loop on GPUs is not implemented.")
+
         # Runs training steps.
         with summary_writer.as_default():
+          for cb in callbacks:
+            cb.on_epoch_begin(current_iteration)
+            cb.on_batch_begin(0)
+
           train_steps(
               train_ds_iterator,
               tf.convert_to_tensor(train_steps_per_eval, dtype=tf.int32))
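Because train_steps is a single tf.function call that advances train_steps_per_eval optimizer steps at once, the loop cannot fire per-step hooks; it can only bracket the whole chunk, calling on_epoch_begin/on_batch_begin before the compiled call and (as the next hunk shows) on_batch_end/on_epoch_end after it. A minimal, self-contained sketch of that driving pattern — run_n_steps, N_STEPS, and the no-op body are illustrative stand-ins, not the source's code:

import tensorflow as tf

N_STEPS = 100  # stand-in for train_steps_per_eval

@tf.function
def run_n_steps(iterator, steps):
  # Stand-in for the real train_steps: an actual body would apply
  # gradients once per element pulled from the iterator.
  for _ in tf.range(steps):
    _ = next(iterator)

iterator = iter(tf.data.Dataset.from_tensors(tf.zeros([1])).repeat())
trace = tf.keras.callbacks.LambdaCallback(
    on_batch_begin=lambda batch, logs: print("chunk begins"),
    on_batch_end=lambda batch, logs: print("chunk ends"))

for iteration in range(2):         # one "epoch" per train/eval cycle
  trace.on_epoch_begin(iteration)
  trace.on_batch_begin(0)          # e.g. TimeHistory starts its timer here
  run_n_steps(iterator, tf.convert_to_tensor(N_STEPS, dtype=tf.int32))
  trace.on_batch_end(N_STEPS - 1)  # ...and stops it here
  trace.on_epoch_end(iteration)
trace.on_train_end()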
...
@@ -309,10 +319,18 @@ class TransformerTask(object):
           logging.info("Train Step: %d/%d / loss = %s", current_step,
                        flags_obj.train_steps, train_loss)

+          for cb in callbacks:
+            cb.on_batch_end(train_steps_per_eval - 1)
+            cb.on_epoch_end(current_iteration)
+
           if params["enable_tensorboard"]:
             for metric_obj in train_metrics:
               tf.compat.v2.summary.scalar(metric_obj.name, metric_obj.result(),
                                           current_step)
+              summary_writer.flush()

+        for cb in callbacks:
+          cb.on_train_end()
+
         if flags_obj.enable_checkpointing:
           # avoid check-pointing when running for benchmarking.
...
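Two details in this last hunk: summary_writer.flush() pushes the freshly written scalar summaries to disk on every train/eval cycle rather than whenever the writer's internal buffer happens to fill, and on_train_end gives the surviving TimeHistory callback a chance to emit its final statistics. A toy example of the flush behavior with the TF2 summary API (the log path is illustrative):

import tensorflow as tf

writer = tf.summary.create_file_writer("/tmp/ctl_demo")
with writer.as_default():
  for step in (100, 200):
    tf.summary.scalar("train_loss", 0.5, step=step)
    # Make events visible to TensorBoard now, not only when the
    # buffer fills or the writer is closed.
    writer.flush()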