Commit 4cac5784 authored by Dinghua Li's avatar Dinghua Li Committed by A. Unique TensorFlower
Browse files

Allow passing a `tf.distribute.InputOptions` to...

Allow passing a `tf.distribute.InputOptions` to `orbit.make_distributed_dataset` through the `input_options` key word argument.

PiperOrigin-RevId: 404597006
parent 558c31ff
......@@ -62,8 +62,11 @@ def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs):
if strategy is None:
strategy = tf.distribute.get_strategy()
input_options = kwargs.get("input_options", None)
if isinstance(dataset_or_fn, tf.data.Dataset):
return strategy.experimental_distribute_dataset(dataset_or_fn)
return strategy.experimental_distribute_dataset(dataset_or_fn,
input_options)
if not callable(dataset_or_fn):
raise ValueError("`dataset_or_fn` should be either callable or an instance "
......@@ -82,7 +85,7 @@ def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs):
kwargs["input_context"] = input_context
return dataset_or_fn(*args, **kwargs)
return strategy.distribute_datasets_from_function(dataset_fn)
return strategy.distribute_datasets_from_function(dataset_fn, input_options)
def get_value(x):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment