Commit e91c41c2 authored by Reed Wanderman-Milne, committed by A. Unique TensorFlower
Browse files

Unexpose some flags from models which do not use them.

--stop_threshold, --num_gpu, --hooks, --export_dir, and --distribution_strategy have been unexposed from models that do not use them.

PiperOrigin-RevId: 268032080
parent bb328876
......@@ -88,8 +88,11 @@ def create_model(data_format):
def define_mnist_flags():
"""Defines flags for mnist."""
flags_core.define_base(clean=True, train_epochs=True,
epochs_between_evals=True)
epochs_between_evals=True, stop_threshold=True,
num_gpu=True, hooks=True, export_dir=True,
distribution_strategy=True)
flags_core.define_performance(inter_op=True, intra_op=True,
num_parallel_calls=False,
all_reduce_alg=True)
......
......@@ -169,7 +169,8 @@ def run_mnist_eager(flags_obj):
def define_mnist_eager_flags():
"""Defined flags and defaults for MNIST in eager mode."""
flags_core.define_base_eager(clean=True, train_epochs=True)
flags_core.define_base(clean=True, train_epochs=True, export_dir=True,
distribution_strategy=True)
flags_core.define_image()
flags.adopt_module_key_flags(flags_core)
......
......@@ -260,7 +260,7 @@ def main(_):
def define_train_higgs_flags():
"""Add tree related flags as well as training/eval configuration."""
flags_core.define_base(clean=False, stop_threshold=False, batch_size=False,
num_gpu=False)
num_gpu=False, export_dir=True)
flags_core.define_benchmark()
flags.adopt_module_key_flags(flags_core)
......
......@@ -724,7 +724,9 @@ def define_resnet_flags(resnet_size_choices=None, dynamic_loss_scale=False,
fp16_implementation=False):
"""Add flags and validators for ResNet."""
flags_core.define_base(clean=True, train_epochs=True,
epochs_between_evals=True)
epochs_between_evals=True, stop_threshold=True,
num_gpu=True, hooks=True, export_dir=True,
distribution_strategy=True)
flags_core.define_performance(num_parallel_calls=False,
inter_op=True,
intra_op=True,
......
......@@ -37,7 +37,8 @@ LOSS_PREFIX = {'wide': 'linear/', 'deep': 'dnn/'}
def define_wide_deep_flags():
"""Add supervised learning flags, as well as wide-deep model type."""
flags_core.define_base(clean=True, train_epochs=True,
epochs_between_evals=True)
epochs_between_evals=True, stop_threshold=True,
hooks=True, export_dir=True)
flags_core.define_benchmark()
flags_core.define_performance(
num_parallel_calls=False, inter_op=True, intra_op=True,
......
......@@ -149,7 +149,8 @@ def define_ncf_flags():
# Add common flags
flags_core.define_base(clean=True, train_epochs=True,
epochs_between_evals=True, export_dir=False,
run_eagerly=True)
run_eagerly=True, stop_threshold=True, num_gpu=True,
hooks=True, distribution_strategy=True)
flags_core.define_performance(
num_parallel_calls=False,
inter_op=False,
......
......@@ -46,9 +46,11 @@ def define_flags():
train_epochs=True,
epochs_between_evals=False,
stop_threshold=False,
num_gpu=True,
hooks=False,
export_dir=False,
run_eagerly=True)
run_eagerly=True,
distribution_strategy=True)
flags_core.define_performance(num_parallel_calls=False,
inter_op=False,
......
......@@ -395,7 +395,9 @@ def define_transformer_flags():
help=flags_core.help_wrap("Max length."))
flags_core.define_base(clean=True, train_epochs=True,
epochs_between_evals=True)
epochs_between_evals=True, stop_threshold=True,
num_gpu=True, hooks=True, export_dir=True,
distribution_strategy=True)
flags_core.define_performance(
num_parallel_calls=True,
inter_op=False,
......
......@@ -61,7 +61,7 @@ def get_model_params(param_set, num_gpus):
def define_transformer_flags():
"""Add flags and flag validators for running transformer_main."""
# Add common flags (data_dir, model_dir, etc.).
flags_core.define_base()
flags_core.define_base(num_gpu=True, distribution_strategy=True)
flags_core.define_performance(
num_parallel_calls=True,
inter_op=False,
......@@ -159,15 +159,13 @@ def define_transformer_flags():
help=flags_core.help_wrap(
'Path to source file containing text translate when calculating the '
'official BLEU score. Both --bleu_source and --bleu_ref must be set. '
'Use the flag --stop_threshold to stop the script based on the '
'uncased BLEU score.'))
))
flags.DEFINE_string(
name='bleu_ref', short_name='blr', default=None,
help=flags_core.help_wrap(
'Path to source file containing text translate when calculating the '
'official BLEU score. Both --bleu_source and --bleu_ref must be set. '
'Use the flag --stop_threshold to stop the script based on the '
'uncased BLEU score.'))
))
flags.DEFINE_string(
name='vocab_file', short_name='vf', default=None,
help=flags_core.help_wrap(
......@@ -232,14 +230,6 @@ def define_transformer_flags():
if flags_dict['bleu_source'] and flags_dict['bleu_ref']:
return flags_dict['vocab_file'] is not None
return True
@flags.multi_flags_validator(
['export_dir', 'vocab_file'],
message='--vocab_file must be defined if --export_dir is set.')
def _check_export_vocab_file(flags_dict):
if flags_dict['export_dir']:
return flags_dict['vocab_file'] is not None
return True
# pylint: enable=unused-variable
......
......@@ -26,14 +26,15 @@ from official.utils.logs import hooks_helper
def define_base(data_dir=True, model_dir=True, clean=False, train_epochs=False,
epochs_between_evals=False, stop_threshold=True,
batch_size=True, num_gpu=True, hooks=True, export_dir=True,
distribution_strategy=True, run_eagerly=False):
epochs_between_evals=False, stop_threshold=False,
batch_size=True, num_gpu=False, hooks=False, export_dir=False,
distribution_strategy=False, run_eagerly=False):
"""Register base flags.
Args:
data_dir: Create a flag for specifying the input data directory.
model_dir: Create a flag for specifying the model file directory.
clean: Create a flag for removing the model_dir.
train_epochs: Create a flag to specify the number of training epochs.
epochs_between_evals: Create a flag to specify the frequency of testing.
stop_threshold: Create a flag to specify a threshold accuracy or other
......
......@@ -21,7 +21,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import sys
from six.moves import shlex_quote
......@@ -70,9 +69,9 @@ def register_key_flags_in_core(f):
define_base = register_key_flags_in_core(_base.define_base)
# Remove options not relevant for Eager from define_base().
define_base_eager = register_key_flags_in_core(functools.partial(
_base.define_base, stop_threshold=False, hooks=False))
# We have define_base_eager for compatibility, since it used to be a separate
# function from define_base.
define_base_eager = define_base
define_benchmark = register_key_flags_in_core(_benchmark.define_benchmark)
define_device = register_key_flags_in_core(_device.define_device)
define_image = register_key_flags_in_core(_misc.define_image)
......
......@@ -22,7 +22,8 @@ from official.utils.flags import core as flags_core # pylint: disable=g-bad-imp
def define_flags():
flags_core.define_base(clean=True, num_gpu=False, train_epochs=True,
flags_core.define_base(clean=True, num_gpu=False, stop_threshold=True,
hooks=True, train_epochs=True,
epochs_between_evals=True)
flags_core.define_performance(
num_parallel_calls=True, inter_op=True, intra_op=True,
......
......@@ -283,8 +283,9 @@ def build_stats(history, eval_output, callbacks):
def define_keras_flags(dynamic_loss_scale=True):
"""Define flags for Keras models."""
flags_core.define_base(clean=True, run_eagerly=True, train_epochs=True,
epochs_between_evals=True)
flags_core.define_base(clean=True, num_gpu=True, run_eagerly=True,
train_epochs=True, epochs_between_evals=True,
distribution_strategy=True)
flags_core.define_performance(num_parallel_calls=False,
synthetic_data=True,
dtype=True,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment