Commit 9caaba0c authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

[ranking] Passing `experimental_prefetch_to_device=False` option for distributed dataset of Ranking models.

[ranking] Passing `experimental_prefetch_to_device=False` option for distributed dataset of Ranking models.

PiperOrigin-RevId: 384792951
parent 552a3baa
...@@ -20,7 +20,6 @@ from absl import app ...@@ -20,7 +20,6 @@ from absl import app
from absl import flags
from absl import logging
import orbit
import tensorflow as tf
from official.common import distribute_utils
...@@ -95,6 +94,21 @@ def main(_) -> None: ...@@ -95,6 +94,21 @@ def main(_) -> None:
with strategy.scope():
  model = task.build_model()
def get_dataset_fn(params):
  """Makes a dataset factory for `strategy.distribute_datasets_from_function`.

  Args:
    params: Data config forwarded to `task.build_inputs`.

  Returns:
    A callable mapping a `tf.distribute.InputContext` to the dataset built by
    `task.build_inputs(params, input_context)`.
  """
  def dataset_fn(input_context):
    return task.build_inputs(params, input_context)

  return dataset_fn
# Build the distributed training dataset only when training is requested;
# otherwise leave it as None for the trainer to handle.
# NOTE(review): `experimental_fetch_to_device=False` keeps batches on the
# host rather than copying them to the accelerator — per the commit subject
# this is for Ranking models' input handling; confirm against the task.
# The commit message uses the older option spelling
# `experimental_prefetch_to_device`; the code uses the renamed
# `experimental_fetch_to_device`.
train_dataset = None
if 'train' in mode:
  train_dataset = strategy.distribute_datasets_from_function(
      get_dataset_fn(params.task.train_data),
      options=tf.distribute.InputOptions(experimental_fetch_to_device=False))
# Same construction for the evaluation dataset: distribute per-replica via
# the task's input pipeline, with device prefetching disabled.
validation_dataset = None
if 'eval' in mode:
  validation_dataset = strategy.distribute_datasets_from_function(
      get_dataset_fn(params.task.validation_data),
      options=tf.distribute.InputOptions(experimental_fetch_to_device=False))
if params.trainer.use_orbit:
  with strategy.scope():
    checkpoint_exporter = train_utils.maybe_create_best_ckpt_exporter(
...@@ -106,6 +120,8 @@ def main(_) -> None: ...@@ -106,6 +120,8 @@ def main(_) -> None:
optimizer=model.optimizer,
train='train' in mode,
evaluate='eval' in mode,
train_dataset=train_dataset,
validation_dataset=validation_dataset,
checkpoint_exporter=checkpoint_exporter)
train_lib.run_experiment(
...@@ -117,16 +133,6 @@ def main(_) -> None: ...@@ -117,16 +133,6 @@ def main(_) -> None:
trainer=trainer)
else:  # Compile/fit
# NOTE(review): these are the lines this commit REMOVES — the old Orbit-based
# dataset construction (and the `eval_dataset` name) replaced by
# `strategy.distribute_datasets_from_function` with
# `experimental_fetch_to_device=False` earlier in `main`.
train_dataset = None
if 'train' in mode:
  train_dataset = orbit.utils.make_distributed_dataset(
      strategy, task.build_inputs, params.task.train_data)
eval_dataset = None
if 'eval' in mode:
  eval_dataset = orbit.utils.make_distributed_dataset(
      strategy, task.build_inputs, params.task.validation_data)
checkpoint = tf.train.Checkpoint(model=model, optimizer=model.optimizer)
latest_checkpoint = tf.train.latest_checkpoint(model_dir)
...@@ -169,7 +175,7 @@ def main(_) -> None: ...@@ -169,7 +175,7 @@ def main(_) -> None:
initial_epoch=initial_epoch,
epochs=num_epochs,
steps_per_epoch=params.trainer.validation_interval,
validation_data=validation_dataset,
validation_steps=eval_steps,
callbacks=callbacks,
)
...@@ -177,7 +183,7 @@ def main(_) -> None: ...@@ -177,7 +183,7 @@ def main(_) -> None:
logging.info('Train history: %s', history.history)
elif mode == 'eval':
  logging.info('Evaluation started')
  validation_output = model.evaluate(validation_dataset, steps=eval_steps)
  logging.info('Evaluation output: %s', validation_output)
else:
  raise NotImplementedError('The mode is not implemented: %s' % mode)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.