"benchmark/git@developer.sourcefind.cn:change/sglang.git" did not exist on "22085081bb247cc57fe971c3d72eb66f053d77b6"
Commit 7c83a9d7 authored by Will Cromar's avatar Will Cromar Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 302921283
parent 9fc4fd08
......@@ -239,7 +239,10 @@ def get_callbacks(steps_per_epoch):
"""Returns common callbacks."""
callbacks = []
if FLAGS.enable_time_history:
time_callback = keras_utils.TimeHistory(FLAGS.batch_size, FLAGS.log_steps)
time_callback = keras_utils.TimeHistory(
FLAGS.batch_size,
FLAGS.log_steps,
FLAGS.model_dir if FLAGS.enable_tensorboard else None)
callbacks.append(time_callback)
if FLAGS.enable_tensorboard:
......
......@@ -246,6 +246,11 @@ class TransformerTask(object):
callbacks = self._create_callbacks(flags_obj.model_dir, 0, params)
# Only TimeHistory callback is supported for CTL
if params["use_ctl"]:
callbacks = [cb for cb in callbacks
if isinstance(cb, keras_utils.TimeHistory)]
# TODO(b/139418525): Refactor the custom training loop logic.
@tf.function
def train_steps(iterator, steps):
......@@ -299,8 +304,13 @@ class TransformerTask(object):
if not self.use_tpu:
raise NotImplementedError(
"Custom training loop on GPUs is not implemented.")
# Runs training steps.
with summary_writer.as_default():
for cb in callbacks:
cb.on_epoch_begin(current_iteration)
cb.on_batch_begin(0)
train_steps(
train_ds_iterator,
tf.convert_to_tensor(train_steps_per_eval, dtype=tf.int32))
......@@ -309,10 +319,18 @@ class TransformerTask(object):
logging.info("Train Step: %d/%d / loss = %s", current_step,
flags_obj.train_steps, train_loss)
for cb in callbacks:
cb.on_batch_end(train_steps_per_eval - 1)
cb.on_epoch_end(current_iteration)
if params["enable_tensorboard"]:
for metric_obj in train_metrics:
tf.compat.v2.summary.scalar(metric_obj.name, metric_obj.result(),
current_step)
summary_writer.flush()
for cb in callbacks:
cb.on_train_end()
if flags_obj.enable_checkpointing:
# avoid check-pointing when running for benchmarking.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment