Commit 30165f86 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 282650416
parent 0b579232
@@ -147,7 +147,7 @@ def get_v1_distribution_strategy(params):
 def define_ncf_flags():
   """Add flags for running ncf_main."""
   # Add common flags
-  flags_core.define_base(clean=True, train_epochs=True,
+  flags_core.define_base(model_dir=True, clean=True, train_epochs=True,
                          epochs_between_evals=True, export_dir=False,
                          run_eagerly=True, stop_threshold=True, num_gpu=True,
                          hooks=True, distribution_strategy=True)
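The added model_dir=True makes flags_core.define_base register a --model_dir flag, which the new summary and checkpoint code in the hunks below reads. As a rough sketch of what that registration amounts to (flags_core is the models repo's thin wrapper over absl.flags; the default path here is a made-up placeholder):

from absl import app, flags

flags.DEFINE_string("model_dir", "/tmp/ncf",
                    "Directory where checkpoints and summaries are written.")
FLAGS = flags.FLAGS

def main(_):
  print("Artifacts go under", FLAGS.model_dir)

if __name__ == "__main__":
  app.run(main)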
@@ -206,8 +206,7 @@ def run_ncf(_):
     print("Setting tf seed")
     tf.random.set_seed(FLAGS.seed)

-  params = ncf_common.parse_flags(FLAGS)
-  model_helpers.apply_clean(flags.FLAGS)
+  model_helpers.apply_clean(FLAGS)

   if FLAGS.dtype == "fp16" and FLAGS.fp16_implementation == "keras":
     policy = tf.keras.mixed_precision.experimental.Policy(
@@ -219,6 +218,8 @@ def run_ncf(_):
       distribution_strategy=FLAGS.distribution_strategy,
       num_gpus=FLAGS.num_gpus,
       tpu_address=FLAGS.tpu)
+
+  params = ncf_common.parse_flags(FLAGS)
   params["distribute_strategy"] = strategy

   if not keras_utils.is_v2_0() and strategy is not None:
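With parse_flags moved after strategy creation, params is populated only once the strategy exists and the two are attached in one place. For readers unfamiliar with the helper, a simplified, hypothetical stand-in for distribution_utils.get_distribution_strategy might look like this (the real helper also handles TPU resolution and more strategy kinds):

import tensorflow as tf

def get_distribution_strategy_sketch(distribution_strategy="mirrored", num_gpus=1):
  # Hypothetical simplification; the real helper also honors tpu_address,
  # parameter-server setups, and other flag combinations.
  if distribution_strategy.lower() == "off":
    return None
  if num_gpus <= 1:
    device = "device:GPU:0" if num_gpus == 1 else "device:CPU:0"
    return tf.distribute.OneDeviceStrategy(device)
  return tf.distribute.MirroredStrategy()

strategy = get_distribution_strategy_sketch()
params = {"distribute_strategy": strategy}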
@@ -307,8 +308,17 @@ def run_ncf(_):
         run_eagerly=FLAGS.run_eagerly,
         experimental_run_tf_function=FLAGS.force_v2_in_keras_compile)
   else:
-    keras_model.compile(
-        optimizer=optimizer, run_eagerly=FLAGS.run_eagerly)
+    keras_model.compile(optimizer=optimizer, run_eagerly=FLAGS.run_eagerly)
+
+  if not FLAGS.ml_perf:
+    # Create Tensorboard summary and checkpoint callbacks.
+    summary_dir = os.path.join(FLAGS.model_dir, "summaries")
+    summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
+    checkpoint_path = os.path.join(FLAGS.model_dir, "checkpoint")
+    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
+        checkpoint_path, save_weights_only=True)
+    callbacks += [summary_callback, checkpoint_callback]

   history = keras_model.fit(
       train_input_dataset,
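The new branch wires a TensorBoard callback and a weights-only ModelCheckpoint into model.fit whenever the run is not under MLPerf rules. A self-contained sketch of the same wiring, with a toy model standing in for the NCF Keras model (the path and shapes are placeholders):

import os
import tensorflow as tf

model_dir = "/tmp/ncf"  # stands in for FLAGS.model_dir
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="adam", loss="mse")

callbacks = [
    tf.keras.callbacks.TensorBoard(os.path.join(model_dir, "summaries")),
    tf.keras.callbacks.ModelCheckpoint(
        os.path.join(model_dir, "checkpoint"), save_weights_only=True),
]
x, y = tf.random.normal((32, 4)), tf.random.normal((32, 1))
model.fit(x, y, epochs=2, callbacks=callbacks)

Because save_weights_only=True, ModelCheckpoint writes TF checkpoint files rather than a full SavedModel, which keeps it consistent with the tf.train.Checkpoint saving added to the custom-training path below.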
@@ -438,6 +448,16 @@ def run_ncf_custom_training(params,
   for callback in callbacks:
     callback.on_train_begin()

+  # Not writing tensorboard summaries if running in MLPerf.
+  if FLAGS.ml_perf:
+    eval_summary_writer, train_summary_writer = None, None
+  else:
+    summary_dir = os.path.join(FLAGS.model_dir, "summaries")
+    eval_summary_writer = tf.summary.create_file_writer(
+        os.path.join(summary_dir, "eval"))
+    train_summary_writer = tf.summary.create_file_writer(
+        os.path.join(summary_dir, "train"))
+
   train_loss = 0
   for epoch in range(FLAGS.train_epochs):
     for cb in callbacks:
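tf.summary.create_file_writer returns a TF 2.x summary writer that writes event files under the given directory; keeping "train" and "eval" in sibling subdirectories lets TensorBoard display them as separate runs of one experiment. A minimal sketch (the path is a placeholder):

import os
import tensorflow as tf

summary_dir = "/tmp/ncf/summaries"  # stands in for FLAGS.model_dir/summaries
eval_summary_writer = tf.summary.create_file_writer(os.path.join(summary_dir, "eval"))
train_summary_writer = tf.summary.create_file_writer(os.path.join(summary_dir, "train"))
# `tensorboard --logdir /tmp/ncf/summaries` then shows "train" and "eval" runs.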
@@ -460,6 +480,12 @@ def run_ncf_custom_training(params,
       train_loss += train_step(train_input_iterator)

+      # Write train loss once in every 100 steps.
+      if train_summary_writer and step % 100 == 0:
+        with train_summary_writer.as_default():
+          tf.summary.scalar("training_loss", train_loss/(step + 1),
+                            step=current_step)
+
       for c in callbacks:
         c.on_batch_end(current_step)
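Note that train_loss / (step + 1) is the running mean of the per-step losses so far in the epoch, not the raw step loss. A self-contained version of the same 100-step cadence, reusing the writer from the sketch above with a fake loss in place of train_step:

current_step = 0
train_loss = 0.0
for step in range(300):
  train_loss += 0.5  # stand-in for train_step(train_input_iterator)
  if train_summary_writer and step % 100 == 0:
    with train_summary_writer.as_default():
      tf.summary.scalar("training_loss", train_loss / (step + 1),
                        step=current_step)
  current_step += 1
train_summary_writer.flush()  # push any buffered events to disk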
@@ -476,7 +502,11 @@ def run_ncf_custom_training(params,
       hr_sum += step_hr_sum
       hr_count += step_hr_count

-    logging.info("Done eval epoch %s, hr=%s.", epoch + 1, hr_sum / hr_count)
+    logging.info("Done eval epoch %s, hit_rate=%s.", epoch + 1,
+                 hr_sum / hr_count)
+
+    if eval_summary_writer:
+      with eval_summary_writer.as_default():
+        tf.summary.scalar("hit_rate", hr_sum / hr_count, step=current_step)

     if (FLAGS.early_stopping and
         float(hr_sum / hr_count) > params["hr_threshold"]):
@@ -485,6 +515,13 @@ def run_ncf_custom_training(params,
   for c in callbacks:
     c.on_train_end()

+  # Saving the model at the end of training.
+  if not FLAGS.ml_perf:
+    checkpoint = tf.train.Checkpoint(model=keras_model, optimizer=optimizer)
+    checkpoint_path = os.path.join(FLAGS.model_dir, "ctl_checkpoint")
+    checkpoint.save(checkpoint_path)
+    logging.info("Saving model as TF checkpoint: %s", checkpoint_path)
+
   return train_loss, [None, hr_sum / hr_count]
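tf.train.Checkpoint tracks the model and optimizer together, so a later restore recovers both the weights and the optimizer state. checkpoint.save appends a counter to the given prefix (ctl_checkpoint-1, -2, ...) and maintains an index file next to the data shards. A sketch of the save/restore round trip with a toy model (names and paths are placeholders):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
optimizer = tf.keras.optimizers.Adam()
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
save_path = checkpoint.save("/tmp/ncf/ctl_checkpoint")  # e.g. .../ctl_checkpoint-1

restored = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
tf.train.Checkpoint(model=restored,
                    optimizer=tf.keras.optimizers.Adam()).restore(save_path)
# Restoration is lazy: values are matched to variables as they are created.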