Commit 9a833e2c authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 274386468
parent d0600d40
......@@ -129,18 +129,22 @@ def train(
# pylint: disable=protected-access
train_iterator = data_utils._get_input_iterator(train_input_fn, strategy)
# pylint: enable=protected-access
train_summary_writer = None
eval_summary_writer = None
if not tf.io.gfile.exists(model_dir):
tf.io.gfile.mkdir(model_dir)
# Create summary writers
summary_dir = os.path.join(model_dir, "summaries")
if not tf.io.gfile.exists(summary_dir):
tf.io.gfile.mkdir(summary_dir)
train_summary_writer = None
eval_summary_writer = None
if test_input_fn:
eval_summary_writer = tf.summary.create_file_writer(
os.path.join(model_dir, "summaries/eval"))
os.path.join(summary_dir, "eval"))
if steps_per_loop >= _MIN_SUMMARY_STEPS:
# Only writes summary when the stats are collected sufficiently over
# enough steps.
train_summary_writer = tf.summary.create_file_writer(
os.path.join(model_dir, "summaries/train"))
os.path.join(summary_dir, "train"))
with strategy.scope():
model = model_fn()
......@@ -316,6 +320,6 @@ def train(
# eval_metric is supposed to be a float.
training_summary["eval_metrics"] = eval_metric
model_training_utils.write_txt_summary(training_summary, model_dir)
model_training_utils.write_txt_summary(training_summary, summary_dir)
return model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment