Commit e91c41c2 authored by Reed Wanderman-Milne's avatar Reed Wanderman-Milne Committed by A. Unique TensorFlower
Browse files

Unexpose some flags from models which do not use them.

--stop_threshold, --num_gpu, --hooks, --export_dir, and --distribution_strategy have been unexposed from models which do not use them

PiperOrigin-RevId: 268032080
parent bb328876
...@@ -88,8 +88,11 @@ def create_model(data_format): ...@@ -88,8 +88,11 @@ def create_model(data_format):
def define_mnist_flags(): def define_mnist_flags():
"""Defines flags for mnist."""
flags_core.define_base(clean=True, train_epochs=True, flags_core.define_base(clean=True, train_epochs=True,
epochs_between_evals=True) epochs_between_evals=True, stop_threshold=True,
num_gpu=True, hooks=True, export_dir=True,
distribution_strategy=True)
flags_core.define_performance(inter_op=True, intra_op=True, flags_core.define_performance(inter_op=True, intra_op=True,
num_parallel_calls=False, num_parallel_calls=False,
all_reduce_alg=True) all_reduce_alg=True)
......
...@@ -169,7 +169,8 @@ def run_mnist_eager(flags_obj): ...@@ -169,7 +169,8 @@ def run_mnist_eager(flags_obj):
def define_mnist_eager_flags(): def define_mnist_eager_flags():
"""Defined flags and defaults for MNIST in eager mode.""" """Defined flags and defaults for MNIST in eager mode."""
flags_core.define_base_eager(clean=True, train_epochs=True) flags_core.define_base(clean=True, train_epochs=True, export_dir=True,
distribution_strategy=True)
flags_core.define_image() flags_core.define_image()
flags.adopt_module_key_flags(flags_core) flags.adopt_module_key_flags(flags_core)
......
...@@ -260,7 +260,7 @@ def main(_): ...@@ -260,7 +260,7 @@ def main(_):
def define_train_higgs_flags(): def define_train_higgs_flags():
"""Add tree related flags as well as training/eval configuration.""" """Add tree related flags as well as training/eval configuration."""
flags_core.define_base(clean=False, stop_threshold=False, batch_size=False, flags_core.define_base(clean=False, stop_threshold=False, batch_size=False,
num_gpu=False) num_gpu=False, export_dir=True)
flags_core.define_benchmark() flags_core.define_benchmark()
flags.adopt_module_key_flags(flags_core) flags.adopt_module_key_flags(flags_core)
......
...@@ -724,7 +724,9 @@ def define_resnet_flags(resnet_size_choices=None, dynamic_loss_scale=False, ...@@ -724,7 +724,9 @@ def define_resnet_flags(resnet_size_choices=None, dynamic_loss_scale=False,
fp16_implementation=False): fp16_implementation=False):
"""Add flags and validators for ResNet.""" """Add flags and validators for ResNet."""
flags_core.define_base(clean=True, train_epochs=True, flags_core.define_base(clean=True, train_epochs=True,
epochs_between_evals=True) epochs_between_evals=True, stop_threshold=True,
num_gpu=True, hooks=True, export_dir=True,
distribution_strategy=True)
flags_core.define_performance(num_parallel_calls=False, flags_core.define_performance(num_parallel_calls=False,
inter_op=True, inter_op=True,
intra_op=True, intra_op=True,
......
...@@ -37,7 +37,8 @@ LOSS_PREFIX = {'wide': 'linear/', 'deep': 'dnn/'} ...@@ -37,7 +37,8 @@ LOSS_PREFIX = {'wide': 'linear/', 'deep': 'dnn/'}
def define_wide_deep_flags(): def define_wide_deep_flags():
"""Add supervised learning flags, as well as wide-deep model type.""" """Add supervised learning flags, as well as wide-deep model type."""
flags_core.define_base(clean=True, train_epochs=True, flags_core.define_base(clean=True, train_epochs=True,
epochs_between_evals=True) epochs_between_evals=True, stop_threshold=True,
hooks=True, export_dir=True)
flags_core.define_benchmark() flags_core.define_benchmark()
flags_core.define_performance( flags_core.define_performance(
num_parallel_calls=False, inter_op=True, intra_op=True, num_parallel_calls=False, inter_op=True, intra_op=True,
......
...@@ -149,7 +149,8 @@ def define_ncf_flags(): ...@@ -149,7 +149,8 @@ def define_ncf_flags():
# Add common flags # Add common flags
flags_core.define_base(clean=True, train_epochs=True, flags_core.define_base(clean=True, train_epochs=True,
epochs_between_evals=True, export_dir=False, epochs_between_evals=True, export_dir=False,
run_eagerly=True) run_eagerly=True, stop_threshold=True, num_gpu=True,
hooks=True, distribution_strategy=True)
flags_core.define_performance( flags_core.define_performance(
num_parallel_calls=False, num_parallel_calls=False,
inter_op=False, inter_op=False,
......
...@@ -46,9 +46,11 @@ def define_flags(): ...@@ -46,9 +46,11 @@ def define_flags():
train_epochs=True, train_epochs=True,
epochs_between_evals=False, epochs_between_evals=False,
stop_threshold=False, stop_threshold=False,
num_gpu=True,
hooks=False, hooks=False,
export_dir=False, export_dir=False,
run_eagerly=True) run_eagerly=True,
distribution_strategy=True)
flags_core.define_performance(num_parallel_calls=False, flags_core.define_performance(num_parallel_calls=False,
inter_op=False, inter_op=False,
......
...@@ -395,7 +395,9 @@ def define_transformer_flags(): ...@@ -395,7 +395,9 @@ def define_transformer_flags():
help=flags_core.help_wrap("Max length.")) help=flags_core.help_wrap("Max length."))
flags_core.define_base(clean=True, train_epochs=True, flags_core.define_base(clean=True, train_epochs=True,
epochs_between_evals=True) epochs_between_evals=True, stop_threshold=True,
num_gpu=True, hooks=True, export_dir=True,
distribution_strategy=True)
flags_core.define_performance( flags_core.define_performance(
num_parallel_calls=True, num_parallel_calls=True,
inter_op=False, inter_op=False,
......
...@@ -61,7 +61,7 @@ def get_model_params(param_set, num_gpus): ...@@ -61,7 +61,7 @@ def get_model_params(param_set, num_gpus):
def define_transformer_flags(): def define_transformer_flags():
"""Add flags and flag validators for running transformer_main.""" """Add flags and flag validators for running transformer_main."""
# Add common flags (data_dir, model_dir, etc.). # Add common flags (data_dir, model_dir, etc.).
flags_core.define_base() flags_core.define_base(num_gpu=True, distribution_strategy=True)
flags_core.define_performance( flags_core.define_performance(
num_parallel_calls=True, num_parallel_calls=True,
inter_op=False, inter_op=False,
...@@ -159,15 +159,13 @@ def define_transformer_flags(): ...@@ -159,15 +159,13 @@ def define_transformer_flags():
help=flags_core.help_wrap( help=flags_core.help_wrap(
'Path to source file containing text translate when calculating the ' 'Path to source file containing text translate when calculating the '
'official BLEU score. Both --bleu_source and --bleu_ref must be set. ' 'official BLEU score. Both --bleu_source and --bleu_ref must be set. '
'Use the flag --stop_threshold to stop the script based on the ' ))
'uncased BLEU score.'))
flags.DEFINE_string( flags.DEFINE_string(
name='bleu_ref', short_name='blr', default=None, name='bleu_ref', short_name='blr', default=None,
help=flags_core.help_wrap( help=flags_core.help_wrap(
'Path to source file containing text translate when calculating the ' 'Path to source file containing text translate when calculating the '
'official BLEU score. Both --bleu_source and --bleu_ref must be set. ' 'official BLEU score. Both --bleu_source and --bleu_ref must be set. '
'Use the flag --stop_threshold to stop the script based on the ' ))
'uncased BLEU score.'))
flags.DEFINE_string( flags.DEFINE_string(
name='vocab_file', short_name='vf', default=None, name='vocab_file', short_name='vf', default=None,
help=flags_core.help_wrap( help=flags_core.help_wrap(
...@@ -232,14 +230,6 @@ def define_transformer_flags(): ...@@ -232,14 +230,6 @@ def define_transformer_flags():
if flags_dict['bleu_source'] and flags_dict['bleu_ref']: if flags_dict['bleu_source'] and flags_dict['bleu_ref']:
return flags_dict['vocab_file'] is not None return flags_dict['vocab_file'] is not None
return True return True
@flags.multi_flags_validator(
['export_dir', 'vocab_file'],
message='--vocab_file must be defined if --export_dir is set.')
def _check_export_vocab_file(flags_dict):
if flags_dict['export_dir']:
return flags_dict['vocab_file'] is not None
return True
# pylint: enable=unused-variable # pylint: enable=unused-variable
......
...@@ -26,14 +26,15 @@ from official.utils.logs import hooks_helper ...@@ -26,14 +26,15 @@ from official.utils.logs import hooks_helper
def define_base(data_dir=True, model_dir=True, clean=False, train_epochs=False, def define_base(data_dir=True, model_dir=True, clean=False, train_epochs=False,
epochs_between_evals=False, stop_threshold=True, epochs_between_evals=False, stop_threshold=False,
batch_size=True, num_gpu=True, hooks=True, export_dir=True, batch_size=True, num_gpu=False, hooks=False, export_dir=False,
distribution_strategy=True, run_eagerly=False): distribution_strategy=False, run_eagerly=False):
"""Register base flags. """Register base flags.
Args: Args:
data_dir: Create a flag for specifying the input data directory. data_dir: Create a flag for specifying the input data directory.
model_dir: Create a flag for specifying the model file directory. model_dir: Create a flag for specifying the model file directory.
clean: Create a flag for removing the model_dir.
train_epochs: Create a flag to specify the number of training epochs. train_epochs: Create a flag to specify the number of training epochs.
epochs_between_evals: Create a flag to specify the frequency of testing. epochs_between_evals: Create a flag to specify the frequency of testing.
stop_threshold: Create a flag to specify a threshold accuracy or other stop_threshold: Create a flag to specify a threshold accuracy or other
......
...@@ -21,7 +21,6 @@ from __future__ import absolute_import ...@@ -21,7 +21,6 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import functools
import sys import sys
from six.moves import shlex_quote from six.moves import shlex_quote
...@@ -70,9 +69,9 @@ def register_key_flags_in_core(f): ...@@ -70,9 +69,9 @@ def register_key_flags_in_core(f):
define_base = register_key_flags_in_core(_base.define_base) define_base = register_key_flags_in_core(_base.define_base)
# Remove options not relevant for Eager from define_base(). # We have define_base_eager for compatibility, since it used to be a separate
define_base_eager = register_key_flags_in_core(functools.partial( # function from define_base.
_base.define_base, stop_threshold=False, hooks=False)) define_base_eager = define_base
define_benchmark = register_key_flags_in_core(_benchmark.define_benchmark) define_benchmark = register_key_flags_in_core(_benchmark.define_benchmark)
define_device = register_key_flags_in_core(_device.define_device) define_device = register_key_flags_in_core(_device.define_device)
define_image = register_key_flags_in_core(_misc.define_image) define_image = register_key_flags_in_core(_misc.define_image)
......
...@@ -22,7 +22,8 @@ from official.utils.flags import core as flags_core # pylint: disable=g-bad-imp ...@@ -22,7 +22,8 @@ from official.utils.flags import core as flags_core # pylint: disable=g-bad-imp
def define_flags(): def define_flags():
flags_core.define_base(clean=True, num_gpu=False, train_epochs=True, flags_core.define_base(clean=True, num_gpu=False, stop_threshold=True,
hooks=True, train_epochs=True,
epochs_between_evals=True) epochs_between_evals=True)
flags_core.define_performance( flags_core.define_performance(
num_parallel_calls=True, inter_op=True, intra_op=True, num_parallel_calls=True, inter_op=True, intra_op=True,
......
...@@ -283,8 +283,9 @@ def build_stats(history, eval_output, callbacks): ...@@ -283,8 +283,9 @@ def build_stats(history, eval_output, callbacks):
def define_keras_flags(dynamic_loss_scale=True): def define_keras_flags(dynamic_loss_scale=True):
"""Define flags for Keras models.""" """Define flags for Keras models."""
flags_core.define_base(clean=True, run_eagerly=True, train_epochs=True, flags_core.define_base(clean=True, num_gpu=True, run_eagerly=True,
epochs_between_evals=True) train_epochs=True, epochs_between_evals=True,
distribution_strategy=True)
flags_core.define_performance(num_parallel_calls=False, flags_core.define_performance(num_parallel_calls=False,
synthetic_data=True, synthetic_data=True,
dtype=True, dtype=True,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment