Commit 6b695ca6 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Disabling Tensorboard profiling for NCF.

PiperOrigin-RevId: 371256980
parent e96371f0
......@@ -258,10 +258,9 @@ def run_ncf(_):
"val_HR_METRIC", desired_value=FLAGS.hr_threshold)
callbacks.append(early_stopping_callback)
(train_input_dataset, eval_input_dataset,
num_train_steps, num_eval_steps) = \
(ncf_input_pipeline.create_ncf_input_data(
params, producer, input_meta_data, strategy))
(train_input_dataset, eval_input_dataset, num_train_steps,
num_eval_steps) = ncf_input_pipeline.create_ncf_input_data(
params, producer, input_meta_data, strategy)
steps_per_epoch = None if generate_input_online else num_train_steps
with distribute_utils.get_strategy_scope(strategy):
......@@ -307,7 +306,8 @@ def run_ncf(_):
if not FLAGS.ml_perf:
# Create Tensorboard summary and checkpoint callbacks.
summary_dir = os.path.join(FLAGS.model_dir, "summaries")
summary_callback = tf.keras.callbacks.TensorBoard(summary_dir)
summary_callback = tf.keras.callbacks.TensorBoard(
summary_dir, profile_batch=0)
checkpoint_path = os.path.join(FLAGS.model_dir, "checkpoint")
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
checkpoint_path, save_weights_only=True)
......
......@@ -38,11 +38,13 @@ class NcfTest(tf.test.TestCase):
ncf_common.define_ncf_flags()
def setUp(self):
super().setUp()
self.top_k_old = rconst.TOP_K
self.num_eval_negatives_old = rconst.NUM_EVAL_NEGATIVES
rconst.NUM_EVAL_NEGATIVES = 2
def tearDown(self):
super().tearDown()
rconst.NUM_EVAL_NEGATIVES = self.num_eval_negatives_old
rconst.TOP_K = self.top_k_old
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment