Commit 5366f605 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 323948101
parent ccc60760
@@ -32,7 +32,8 @@ from official.recommendation import movielens
 def create_dataset_from_tf_record_files(input_file_pattern,
                                         pre_batch_size,
                                         batch_size,
-                                        is_training=True):
+                                        is_training=True,
+                                        rebatch=False):
   """Creates dataset from (tf)records files for training/evaluation."""
   if pre_batch_size != batch_size:
     raise ValueError("Pre-batch ({}) size is not equal to batch "
@@ -51,6 +52,12 @@ def create_dataset_from_tf_record_files(input_file_pattern,
   dataset = dataset.map(
       decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+  if rebatch:
+    # A workaround for TPU Pod evaluation dataset.
+    # TODO (b/162341937) remove once it's fixed.
+    dataset = dataset.unbatch()
+    dataset = dataset.batch(pre_batch_size)
   dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
   return dataset
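For context on the workaround above: the evaluation records arrive pre-batched, so `unbatch()` flattens them back into individual examples and `batch(pre_batch_size)` rebuilds uniform batches before prefetching. Below is a minimal sketch of the same pattern on a toy in-memory dataset; the sizes are made up for illustration and are not taken from the NCF pipeline.

```python
import tensorflow as tf

# Toy stand-in for an evaluation dataset that already arrives in batches.
pre_batch_size = 4
dataset = tf.data.Dataset.range(16).batch(8)  # two batches of 8 elements

# Flatten the existing batches, then re-batch at the desired size so every
# step sees a uniformly sized batch.
dataset = dataset.unbatch()
dataset = dataset.batch(pre_batch_size)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

for batch in dataset:
  print(batch.numpy())  # four batches of 4 elements each
```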
@@ -151,12 +158,18 @@ def create_ncf_input_data(params,
         params["train_dataset_path"],
         input_meta_data["train_prebatch_size"],
         params["batch_size"],
-        is_training=True)
+        is_training=True,
+        rebatch=False)
+    # Re-batch evaluation dataset for TPU Pods.
+    # TODO (b/162341937) remove once it's fixed.
+    eval_rebatch = (params["use_tpu"] and strategy.num_replicas_in_sync > 8)
     eval_dataset = create_dataset_from_tf_record_files(
         params["eval_dataset_path"],
         input_meta_data["eval_prebatch_size"],
         params["eval_batch_size"],
-        is_training=False)
+        is_training=False,
+        rebatch=eval_rebatch)
     num_train_steps = int(input_meta_data["num_train_steps"])
     num_eval_steps = int(input_meta_data["num_eval_steps"])
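The `> 8` threshold above reflects that a single (non-Pod) TPU device exposes 8 cores, so a larger replica count indicates a Pod slice. The sketch below shows how that check could be exercised; the cluster-resolver setup and the TPU name are assumptions for illustration only (they require TF 2.3+ and an actual TPU) and are not part of this change.

```python
import tensorflow as tf

# Hypothetical TPU name; on Cloud TPU this would be the node name or grpc address.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="my-tpu")
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

# 8 replicas -> single TPU device; more than 8 -> Pod slice, so re-batch eval data.
eval_rebatch = strategy.num_replicas_in_sync > 8
print("replicas:", strategy.num_replicas_in_sync, "rebatch eval:", eval_rebatch)
```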
@@ -235,6 +235,7 @@ def run_ncf(_):
   params = ncf_common.parse_flags(FLAGS)
   params["distribute_strategy"] = strategy
+  params["use_tpu"] = (FLAGS.distribution_strategy == "tpu")
   if params["use_tpu"] and not params["keras_use_ctl"]:
     logging.error("Custom training loop must be used when using TPUStrategy.")
@@ -491,7 +492,8 @@ def run_ncf_custom_training(params,
     logging.info("Done training epoch %s, epoch loss=%.3f", epoch + 1,
                  train_loss)
-    eval_input_iterator = iter(eval_input_dataset)
+    eval_input_iterator = iter(
+        strategy.experimental_distribute_dataset(eval_input_dataset))
     hr_sum = 0.0
     hr_count = 0.0
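For context on the last hunk: inside the custom training loop, evaluation batches are fed to `strategy.run`, which expects the per-replica values produced by `experimental_distribute_dataset`, hence the eval dataset is wrapped before building the iterator. Below is a hedged sketch of that pattern with a stand-in dataset and step function; none of the names are taken from the NCF code.

```python
import tensorflow as tf

strategy = tf.distribute.get_strategy()  # whichever strategy is in scope

# Stand-in evaluation dataset: 32 fake examples in batches of 8.
eval_input_dataset = tf.data.Dataset.from_tensor_slices(
    tf.random.uniform([32, 4])).batch(8)

@tf.function
def eval_step(batch):
  # Per-replica evaluation work (e.g. computing hit-rate terms) would go here.
  return tf.reduce_mean(batch)

# Distribute the dataset so each replica receives its own shard of every batch.
eval_input_iterator = iter(
    strategy.experimental_distribute_dataset(eval_input_dataset))

for _ in range(4):  # num_eval_steps would normally come from input_meta_data
  per_replica = strategy.run(eval_step, args=(next(eval_input_iterator),))
  result = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica, axis=None)
```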