Commit eff1856a authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Disabling Tensorboard profiling for NCF.

PiperOrigin-RevId: 371256980
parent 4dac0570
...@@ -258,10 +258,9 @@ def run_ncf(_): ...@@ -258,10 +258,9 @@ def run_ncf(_):
"val_HR_METRIC", desired_value=FLAGS.hr_threshold) "val_HR_METRIC", desired_value=FLAGS.hr_threshold)
callbacks.append(early_stopping_callback) callbacks.append(early_stopping_callback)
(train_input_dataset, eval_input_dataset, (train_input_dataset, eval_input_dataset, num_train_steps,
num_train_steps, num_eval_steps) = \ num_eval_steps) = ncf_input_pipeline.create_ncf_input_data(
(ncf_input_pipeline.create_ncf_input_data( params, producer, input_meta_data, strategy)
params, producer, input_meta_data, strategy))
steps_per_epoch = None if generate_input_online else num_train_steps steps_per_epoch = None if generate_input_online else num_train_steps
with distribute_utils.get_strategy_scope(strategy): with distribute_utils.get_strategy_scope(strategy):
...@@ -307,7 +306,8 @@ def run_ncf(_): ...@@ -307,7 +306,8 @@ def run_ncf(_):
if not FLAGS.ml_perf: if not FLAGS.ml_perf:
# Create Tensorboard summary and checkpoint callbacks. # Create Tensorboard summary and checkpoint callbacks.
summary_dir = os.path.join(FLAGS.model_dir, "summaries") summary_dir = os.path.join(FLAGS.model_dir, "summaries")
summary_callback = tf.keras.callbacks.TensorBoard(summary_dir) summary_callback = tf.keras.callbacks.TensorBoard(
summary_dir, profile_batch=0)
checkpoint_path = os.path.join(FLAGS.model_dir, "checkpoint") checkpoint_path = os.path.join(FLAGS.model_dir, "checkpoint")
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
checkpoint_path, save_weights_only=True) checkpoint_path, save_weights_only=True)
......
...@@ -38,11 +38,13 @@ class NcfTest(tf.test.TestCase): ...@@ -38,11 +38,13 @@ class NcfTest(tf.test.TestCase):
ncf_common.define_ncf_flags() ncf_common.define_ncf_flags()
def setUp(self): def setUp(self):
super().setUp()
self.top_k_old = rconst.TOP_K self.top_k_old = rconst.TOP_K
self.num_eval_negatives_old = rconst.NUM_EVAL_NEGATIVES self.num_eval_negatives_old = rconst.NUM_EVAL_NEGATIVES
rconst.NUM_EVAL_NEGATIVES = 2 rconst.NUM_EVAL_NEGATIVES = 2
def tearDown(self): def tearDown(self):
super().tearDown()
rconst.NUM_EVAL_NEGATIVES = self.num_eval_negatives_old rconst.NUM_EVAL_NEGATIVES = self.num_eval_negatives_old
rconst.TOP_K = self.top_k_old rconst.TOP_K = self.top_k_old
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment