"tests/vscode:/vscode.git/clone" did not exist on "b2ac24560247f10cd36ef1de950a7cc2b2f37448"
Commit a85c40e3 authored by Reed Wanderman-Milne's avatar Reed Wanderman-Milne Committed by A. Unique TensorFlower
Browse files

Unexpose some flags from models which do not use them.

--clean, --train_epochs, and --epochs_between_evals have been unexposed from models that do not use them.

PiperOrigin-RevId: 267065651
parent 154e8c46
...@@ -88,7 +88,8 @@ def create_model(data_format): ...@@ -88,7 +88,8 @@ def create_model(data_format):
def define_mnist_flags(): def define_mnist_flags():
flags_core.define_base() flags_core.define_base(clean=True, train_epochs=True,
epochs_between_evals=True)
flags_core.define_performance(inter_op=True, intra_op=True, flags_core.define_performance(inter_op=True, intra_op=True,
num_parallel_calls=False, num_parallel_calls=False,
all_reduce_alg=True) all_reduce_alg=True)
......
...@@ -169,7 +169,7 @@ def run_mnist_eager(flags_obj): ...@@ -169,7 +169,7 @@ def run_mnist_eager(flags_obj):
def define_mnist_eager_flags(): def define_mnist_eager_flags():
"""Defined flags and defaults for MNIST in eager mode.""" """Defined flags and defaults for MNIST in eager mode."""
flags_core.define_base_eager() flags_core.define_base_eager(clean=True, train_epochs=True)
flags_core.define_image() flags_core.define_image()
flags.adopt_module_key_flags(flags_core) flags.adopt_module_key_flags(flags_core)
......
...@@ -133,7 +133,7 @@ class BaseTest(tf.test.TestCase): ...@@ -133,7 +133,7 @@ class BaseTest(tf.test.TestCase):
"--eval_start", "12", "--eval_start", "12",
"--eval_count", "8", "--eval_count", "8",
], ],
synth=False) synth=False, train_epochs=None, epochs_between_evals=None)
self.assertTrue(tf.gfile.Exists(os.path.join(model_dir, "checkpoint"))) self.assertTrue(tf.gfile.Exists(os.path.join(model_dir, "checkpoint")))
@unittest.skipIf(keras_utils.is_v2_0(), "TF 1.0 only test.") @unittest.skipIf(keras_utils.is_v2_0(), "TF 1.0 only test.")
...@@ -152,7 +152,7 @@ class BaseTest(tf.test.TestCase): ...@@ -152,7 +152,7 @@ class BaseTest(tf.test.TestCase):
"--eval_start", "12", "--eval_start", "12",
"--eval_count", "8", "--eval_count", "8",
], ],
synth=False) synth=False, train_epochs=None, epochs_between_evals=None)
self.assertTrue(tf.gfile.Exists(os.path.join(model_dir, "checkpoint"))) self.assertTrue(tf.gfile.Exists(os.path.join(model_dir, "checkpoint")))
self.assertTrue(tf.gfile.Exists(os.path.join(export_dir))) self.assertTrue(tf.gfile.Exists(os.path.join(export_dir)))
......
...@@ -723,7 +723,8 @@ def resnet_main( ...@@ -723,7 +723,8 @@ def resnet_main(
def define_resnet_flags(resnet_size_choices=None, dynamic_loss_scale=False, def define_resnet_flags(resnet_size_choices=None, dynamic_loss_scale=False,
fp16_implementation=False): fp16_implementation=False):
"""Add flags and validators for ResNet.""" """Add flags and validators for ResNet."""
flags_core.define_base() flags_core.define_base(clean=True, train_epochs=True,
epochs_between_evals=True)
flags_core.define_performance(num_parallel_calls=False, flags_core.define_performance(num_parallel_calls=False,
inter_op=True, inter_op=True,
intra_op=True, intra_op=True,
......
...@@ -36,7 +36,8 @@ LOSS_PREFIX = {'wide': 'linear/', 'deep': 'dnn/'} ...@@ -36,7 +36,8 @@ LOSS_PREFIX = {'wide': 'linear/', 'deep': 'dnn/'}
def define_wide_deep_flags(): def define_wide_deep_flags():
"""Add supervised learning flags, as well as wide-deep model type.""" """Add supervised learning flags, as well as wide-deep model type."""
flags_core.define_base() flags_core.define_base(clean=True, train_epochs=True,
epochs_between_evals=True)
flags_core.define_benchmark() flags_core.define_benchmark()
flags_core.define_performance( flags_core.define_performance(
num_parallel_calls=False, inter_op=True, intra_op=True, num_parallel_calls=False, inter_op=True, intra_op=True,
......
...@@ -147,7 +147,9 @@ def get_v1_distribution_strategy(params): ...@@ -147,7 +147,9 @@ def get_v1_distribution_strategy(params):
def define_ncf_flags(): def define_ncf_flags():
"""Add flags for running ncf_main.""" """Add flags for running ncf_main."""
# Add common flags # Add common flags
flags_core.define_base(export_dir=False, run_eagerly=True) flags_core.define_base(clean=True, train_epochs=True,
epochs_between_evals=True, export_dir=False,
run_eagerly=True)
flags_core.define_performance( flags_core.define_performance(
num_parallel_calls=False, num_parallel_calls=False,
inter_op=False, inter_op=False,
......
...@@ -394,7 +394,8 @@ def define_transformer_flags(): ...@@ -394,7 +394,8 @@ def define_transformer_flags():
name="max_length", short_name="ml", default=None, name="max_length", short_name="ml", default=None,
help=flags_core.help_wrap("Max length.")) help=flags_core.help_wrap("Max length."))
flags_core.define_base() flags_core.define_base(clean=True, train_epochs=True,
epochs_between_evals=True)
flags_core.define_performance( flags_core.define_performance(
num_parallel_calls=True, num_parallel_calls=True,
inter_op=False, inter_op=False,
......
...@@ -60,7 +60,7 @@ def get_model_params(param_set, num_gpus): ...@@ -60,7 +60,7 @@ def get_model_params(param_set, num_gpus):
def define_transformer_flags(): def define_transformer_flags():
"""Add flags and flag validators for running transformer_main.""" """Add flags and flag validators for running transformer_main."""
# Add common flags (data_dir, model_dir, train_epochs, etc.). # Add common flags (data_dir, model_dir, etc.).
flags_core.define_base() flags_core.define_base()
flags_core.define_performance( flags_core.define_performance(
num_parallel_calls=True, num_parallel_calls=True,
...@@ -214,18 +214,9 @@ def define_transformer_flags(): ...@@ -214,18 +214,9 @@ def define_transformer_flags():
flags_core.set_defaults(data_dir='/tmp/translate_ende', flags_core.set_defaults(data_dir='/tmp/translate_ende',
model_dir='/tmp/transformer_model', model_dir='/tmp/transformer_model',
batch_size=None, batch_size=None)
train_epochs=10)
# pylint: disable=unused-variable # pylint: disable=unused-variable
@flags.multi_flags_validator(
['mode', 'train_epochs'],
message='--train_epochs must be defined in train mode')
def _check_train_limits(flag_dict):
if flag_dict['mode'] == 'train':
return flag_dict['train_epochs'] is not None
return True
@flags.multi_flags_validator( @flags.multi_flags_validator(
['bleu_source', 'bleu_ref'], ['bleu_source', 'bleu_ref'],
message='Both or neither --bleu_source and --bleu_ref must be defined.') message='Both or neither --bleu_source and --bleu_ref must be defined.')
......
...@@ -25,9 +25,9 @@ from official.utils.flags._conventions import help_wrap ...@@ -25,9 +25,9 @@ from official.utils.flags._conventions import help_wrap
from official.utils.logs import hooks_helper from official.utils.logs import hooks_helper
def define_base(data_dir=True, model_dir=True, clean=True, train_epochs=True, def define_base(data_dir=True, model_dir=True, clean=False, train_epochs=False,
epochs_between_evals=True, stop_threshold=True, batch_size=True, epochs_between_evals=False, stop_threshold=True,
num_gpu=True, hooks=True, export_dir=True, batch_size=True, num_gpu=True, hooks=True, export_dir=True,
distribution_strategy=True, run_eagerly=False): distribution_strategy=True, run_eagerly=False):
"""Register base flags. """Register base flags.
......
...@@ -72,8 +72,7 @@ def register_key_flags_in_core(f): ...@@ -72,8 +72,7 @@ def register_key_flags_in_core(f):
define_base = register_key_flags_in_core(_base.define_base) define_base = register_key_flags_in_core(_base.define_base)
# Remove options not relevant for Eager from define_base(). # Remove options not relevant for Eager from define_base().
define_base_eager = register_key_flags_in_core(functools.partial( define_base_eager = register_key_flags_in_core(functools.partial(
_base.define_base, epochs_between_evals=False, stop_threshold=False, _base.define_base, stop_threshold=False, hooks=False))
hooks=False))
define_benchmark = register_key_flags_in_core(_benchmark.define_benchmark) define_benchmark = register_key_flags_in_core(_benchmark.define_benchmark)
define_device = register_key_flags_in_core(_device.define_device) define_device = register_key_flags_in_core(_device.define_device)
define_image = register_key_flags_in_core(_misc.define_image) define_image = register_key_flags_in_core(_misc.define_image)
......
...@@ -22,7 +22,8 @@ from official.utils.flags import core as flags_core # pylint: disable=g-bad-imp ...@@ -22,7 +22,8 @@ from official.utils.flags import core as flags_core # pylint: disable=g-bad-imp
def define_flags(): def define_flags():
flags_core.define_base(num_gpu=False) flags_core.define_base(clean=True, num_gpu=False, train_epochs=True,
epochs_between_evals=True)
flags_core.define_performance( flags_core.define_performance(
num_parallel_calls=True, inter_op=True, intra_op=True, num_parallel_calls=True, inter_op=True, intra_op=True,
dynamic_loss_scale=True, loss_scale=True, synthetic_data=True, dynamic_loss_scale=True, loss_scale=True, synthetic_data=True,
......
...@@ -29,7 +29,8 @@ from absl import flags ...@@ -29,7 +29,8 @@ from absl import flags
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
def run_synthetic(main, tmp_root, extra_flags=None, synth=True): def run_synthetic(main, tmp_root, extra_flags=None, synth=True, train_epochs=1,
epochs_between_evals=1):
"""Performs a minimal run of a model. """Performs a minimal run of a model.
This function is intended to test for syntax errors throughout a model. A This function is intended to test for syntax errors throughout a model. A
...@@ -41,18 +42,25 @@ def run_synthetic(main, tmp_root, extra_flags=None, synth=True): ...@@ -41,18 +42,25 @@ def run_synthetic(main, tmp_root, extra_flags=None, synth=True):
tmp_root: Root path for the temp directory created by the test class. tmp_root: Root path for the temp directory created by the test class.
extra_flags: Additional flags passed by the caller of this function. extra_flags: Additional flags passed by the caller of this function.
synth: Use synthetic data. synth: Use synthetic data.
train_epochs: Value of the --train_epochs flag.
epochs_between_evals: Value of the --epochs_between_evals flag.
""" """
extra_flags = [] if extra_flags is None else extra_flags extra_flags = [] if extra_flags is None else extra_flags
model_dir = tempfile.mkdtemp(dir=tmp_root) model_dir = tempfile.mkdtemp(dir=tmp_root)
args = [sys.argv[0], "--model_dir", model_dir, "--train_epochs", "1", args = [sys.argv[0], "--model_dir", model_dir] + extra_flags
"--epochs_between_evals", "1"] + extra_flags
if synth: if synth:
args.append("--use_synthetic_data") args.append("--use_synthetic_data")
if train_epochs is not None:
args.extend(["--train_epochs", str(train_epochs)])
if epochs_between_evals is not None:
args.extend(["--epochs_between_evals", str(epochs_between_evals)])
try: try:
flags_core.parse_flags(argv=args) flags_core.parse_flags(argv=args)
main(flags.FLAGS) main(flags.FLAGS)
......
...@@ -283,7 +283,8 @@ def build_stats(history, eval_output, callbacks): ...@@ -283,7 +283,8 @@ def build_stats(history, eval_output, callbacks):
def define_keras_flags(dynamic_loss_scale=True): def define_keras_flags(dynamic_loss_scale=True):
"""Define flags for Keras models.""" """Define flags for Keras models."""
flags_core.define_base(run_eagerly=True) flags_core.define_base(clean=True, run_eagerly=True, train_epochs=True,
epochs_between_evals=True)
flags_core.define_performance(num_parallel_calls=False, flags_core.define_performance(num_parallel_calls=False,
synthetic_data=True, synthetic_data=True,
dtype=True, dtype=True,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment