Commit 9caaba0c authored by A. Unique TensorFlower
Browse files

[ranking] Passing `experimental_prefetch_to_device=False` option for...

[ranking] Passing `experimental_fetch_to_device=False` option (called `experimental_prefetch_to_device` in the original commit message) for distributed dataset of Ranking models.

PiperOrigin-RevId: 384792951
parent 552a3baa
......@@ -20,7 +20,6 @@ from absl import app
from absl import flags
from absl import logging
import orbit
import tensorflow as tf
from official.common import distribute_utils
......@@ -95,6 +94,21 @@ def main(_) -> None:
with strategy.scope():
model = task.build_model()
def get_dataset_fn(params):
  """Returns a one-argument dataset function bound to `params`.

  The returned callable accepts an `input_context` and forwards it, together
  with the captured `params`, to `task.build_inputs`. `task` is resolved from
  the enclosing scope at call time.

  Args:
    params: Data configuration passed as the first argument to
      `task.build_inputs`.

  Returns:
    A callable `fn(input_context)` returning whatever
    `task.build_inputs(params, input_context)` returns.
  """

  def dataset_fn(input_context):
    return task.build_inputs(params, input_context)

  return dataset_fn
train_dataset = None
if 'train' in mode:
train_dataset = strategy.distribute_datasets_from_function(
get_dataset_fn(params.task.train_data),
options=tf.distribute.InputOptions(experimental_fetch_to_device=False))
validation_dataset = None
if 'eval' in mode:
validation_dataset = strategy.distribute_datasets_from_function(
get_dataset_fn(params.task.validation_data),
options=tf.distribute.InputOptions(experimental_fetch_to_device=False))
if params.trainer.use_orbit:
with strategy.scope():
checkpoint_exporter = train_utils.maybe_create_best_ckpt_exporter(
......@@ -106,6 +120,8 @@ def main(_) -> None:
optimizer=model.optimizer,
train='train' in mode,
evaluate='eval' in mode,
train_dataset=train_dataset,
validation_dataset=validation_dataset,
checkpoint_exporter=checkpoint_exporter)
train_lib.run_experiment(
......@@ -117,16 +133,6 @@ def main(_) -> None:
trainer=trainer)
else: # Compile/fit
train_dataset = None
if 'train' in mode:
train_dataset = orbit.utils.make_distributed_dataset(
strategy, task.build_inputs, params.task.train_data)
eval_dataset = None
if 'eval' in mode:
eval_dataset = orbit.utils.make_distributed_dataset(
strategy, task.build_inputs, params.task.validation_data)
checkpoint = tf.train.Checkpoint(model=model, optimizer=model.optimizer)
latest_checkpoint = tf.train.latest_checkpoint(model_dir)
......@@ -169,7 +175,7 @@ def main(_) -> None:
initial_epoch=initial_epoch,
epochs=num_epochs,
steps_per_epoch=params.trainer.validation_interval,
validation_data=eval_dataset,
validation_data=validation_dataset,
validation_steps=eval_steps,
callbacks=callbacks,
)
......@@ -177,7 +183,7 @@ def main(_) -> None:
logging.info('Train history: %s', history.history)
elif mode == 'eval':
logging.info('Evaluation started')
validation_output = model.evaluate(eval_dataset, steps=eval_steps)
validation_output = model.evaluate(validation_dataset, steps=eval_steps)
logging.info('Evaluation output: %s', validation_output)
else:
raise NotImplementedError('The mode is not implemented: %s' % mode)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment