Commit 00c7ea79 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 327984543
parent 19e8d0a0
...@@ -213,10 +213,10 @@ class TransformerTask(object): ...@@ -213,10 +213,10 @@ class TransformerTask(object):
train_loss_metric = tf.keras.metrics.Mean( train_loss_metric = tf.keras.metrics.Mean(
"training_loss", dtype=tf.float32) "training_loss", dtype=tf.float32)
if params["enable_tensorboard"]: if params["enable_tensorboard"]:
summary_writer = tf.compat.v2.summary.create_file_writer( summary_writer = tf.summary.create_file_writer(
flags_obj.model_dir) os.path.join(flags_obj.model_dir, "summary"))
else: else:
summary_writer = tf.compat.v2.summary.create_noop_writer() summary_writer = tf.summary.create_noop_writer()
train_metrics = [train_loss_metric] train_metrics = [train_loss_metric]
if params["enable_metrics_in_training"]: if params["enable_metrics_in_training"]:
train_metrics = train_metrics + model.metrics train_metrics = train_metrics + model.metrics
...@@ -322,8 +322,8 @@ class TransformerTask(object): ...@@ -322,8 +322,8 @@ class TransformerTask(object):
if params["enable_tensorboard"]: if params["enable_tensorboard"]:
for metric_obj in train_metrics: for metric_obj in train_metrics:
tf.compat.v2.summary.scalar(metric_obj.name, metric_obj.result(), tf.summary.scalar(metric_obj.name, metric_obj.result(),
current_step) current_step)
summary_writer.flush() summary_writer.flush()
for cb in callbacks: for cb in callbacks:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment