Commit 5366f605 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 323948101
parent ccc60760
@@ -32,7 +32,8 @@ from official.recommendation import movielens
 def create_dataset_from_tf_record_files(input_file_pattern,
                                          pre_batch_size,
                                          batch_size,
-                                         is_training=True):
+                                         is_training=True,
+                                         rebatch=False):
   """Creates dataset from (tf)records files for training/evaluation."""
   if pre_batch_size != batch_size:
     raise ValueError("Pre-batch ({}) size is not equal to batch "
@@ -51,6 +52,12 @@ def create_dataset_from_tf_record_files(input_file_pattern,
   dataset = dataset.map(
       decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+  if rebatch:
+    # A workaround for TPU Pod evaluation dataset.
+    # TODO (b/162341937) remove once it's fixed.
+    dataset = dataset.unbatch()
+    dataset = dataset.batch(pre_batch_size)
   dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
   return dataset
@@ -151,12 +158,18 @@ def create_ncf_input_data(params,
       params["train_dataset_path"],
       input_meta_data["train_prebatch_size"],
       params["batch_size"],
-      is_training=True)
+      is_training=True,
+      rebatch=False)
+
+  # Re-batch evaluation dataset for TPU Pods.
+  # TODO (b/162341937) remove once it's fixed.
+  eval_rebatch = (params["use_tpu"] and strategy.num_replicas_in_sync > 8)
   eval_dataset = create_dataset_from_tf_record_files(
       params["eval_dataset_path"],
       input_meta_data["eval_prebatch_size"],
       params["eval_batch_size"],
-      is_training=False)
+      is_training=False,
+      rebatch=eval_rebatch)
   num_train_steps = int(input_meta_data["num_train_steps"])
   num_eval_steps = int(input_meta_data["num_eval_steps"])
...
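Note on the change above: the rebatch workaround simply flattens the already pre-batched evaluation records and batches them again at `pre_batch_size`, so every host in a TPU Pod receives uniformly shaped batches. Below is a minimal, self-contained sketch of the same tf.data pattern; the helper name `rebatch_for_tpu_pods` and the toy range dataset are illustrative only, not part of the NCF pipeline.

```python
import tensorflow as tf


def rebatch_for_tpu_pods(dataset: tf.data.Dataset,
                         pre_batch_size: int) -> tf.data.Dataset:
  """Flattens pre-batched elements and re-batches them uniformly.

  Mirrors the workaround in the diff: unbatch() splits each pre-batched
  element back into individual examples, and batch(pre_batch_size)
  regroups them so every emitted batch has the same leading dimension.
  """
  dataset = dataset.unbatch()
  return dataset.batch(pre_batch_size)


# Toy usage: elements that arrive as batches of 4 are re-batched into 8s.
toy = tf.data.Dataset.range(32).batch(4)
rebatched = rebatch_for_tpu_pods(toy, pre_batch_size=8)
for batch in rebatched:
  print(batch.shape)  # (8,)
```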
@@ -235,6 +235,7 @@ def run_ncf(_):
   params = ncf_common.parse_flags(FLAGS)
   params["distribute_strategy"] = strategy
+  params["use_tpu"] = (FLAGS.distribution_strategy == "tpu")
   if params["use_tpu"] and not params["keras_use_ctl"]:
     logging.error("Custom training loop must be used when using TPUStrategy.")
@@ -491,7 +492,8 @@ def run_ncf_custom_training(params,
     logging.info("Done training epoch %s, epoch loss=%.3f", epoch + 1,
                  train_loss)
-    eval_input_iterator = iter(eval_input_dataset)
+    eval_input_iterator = iter(
+        strategy.experimental_distribute_dataset(eval_input_dataset))
     hr_sum = 0.0
     hr_count = 0.0
...
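Note on the second file's change: the evaluation dataset is now wrapped with `strategy.experimental_distribute_dataset` before the iterator is built, so the custom training loop feeds each replica its own slice during evaluation. A rough sketch of that pattern under a generic tf.distribute strategy follows; the `MirroredStrategy` stand-in, the toy dataset, and the placeholder `eval_step` are assumptions for illustration, while the diff itself targets TPUStrategy and NCF's real evaluation step.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # stand-in; the diff targets TPUStrategy

# Illustrative eval dataset of (features, labels) pairs: 64 examples, batches of 16.
eval_input_dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.uniform([64, 4]), tf.random.uniform([64, 1]))).batch(16)


@tf.function
def eval_step(iterator):
  """Runs one distributed evaluation step and returns a summed metric."""

  def step_fn(inputs):
    features, labels = inputs
    # Placeholder "metric": mean absolute label value per example, summed.
    per_example = tf.reduce_mean(tf.abs(labels), axis=-1)
    return tf.reduce_sum(per_example)

  per_replica = strategy.run(step_fn, args=(next(iterator),))
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)


# As in the diff: distribute the dataset, then iterate it in the eval loop.
eval_input_iterator = iter(
    strategy.experimental_distribute_dataset(eval_input_dataset))
total = 0.0
for _ in range(4):  # 64 examples / batch size 16
  total += eval_step(eval_input_iterator)
print(float(total))
```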