Unverified Commit 272a2baa authored by rxsang's avatar rxsang Committed by GitHub
Browse files

Add enable_get_next_as_optional flag. (#6858)

* Add enable_get_next_as_optional flag.

* Set enable_get_next_as_optional to strategy.

* Add comments to explain the flag.

* Remove trailing whitespace.

* Remove trailing space.
parent 3a97b68c
...@@ -372,6 +372,9 @@ def define_keras_flags(): ...@@ -372,6 +372,9 @@ def define_keras_flags():
name='clone_model_in_keras_dist_strat', default=None, name='clone_model_in_keras_dist_strat', default=None,
help='If False, then the experimental code path is used that doesn\'t ' help='If False, then the experimental code path is used that doesn\'t '
'clone models for distribution.') 'clone models for distribution.')
flags.DEFINE_boolean(
name='enable_get_next_as_optional', default=False,
help='Enable get_next_as_optional behavior in DistributedIterator.')
def get_synth_input_fn(height, width, num_channels, num_classes, def get_synth_input_fn(height, width, num_channels, num_classes,
......
...@@ -129,6 +129,13 @@ def run(flags_obj): ...@@ -129,6 +129,13 @@ def run(flags_obj):
all_reduce_alg=flags_obj.all_reduce_alg, all_reduce_alg=flags_obj.all_reduce_alg,
num_packs=flags_obj.num_packs) num_packs=flags_obj.num_packs)
# flags_obj.enable_get_next_as_optional controls whether enabling
# get_next_as_optional behavior in DistributedIterator. If true, last partial
# batch can be supported.
strategy.extended.experimental_enable_get_next_as_optional = (
flags_obj.enable_get_next_as_optional
)
strategy_scope = distribution_utils.get_strategy_scope(strategy) strategy_scope = distribution_utils.get_strategy_scope(strategy)
# pylint: disable=protected-access # pylint: disable=protected-access
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment