"awq/vscode:/vscode.git/clone" did not exist on "b190df3597604c4c97ef4446536b1c8247b26812"
Commit a85c40e3 authored by Reed Wanderman-Milne, committed by A. Unique TensorFlower
Browse files

Unexpose some flags from models which do not use them.

--clean, --train_epochs, and --epochs_between_evals have been unexposed from models which do not use them

PiperOrigin-RevId: 267065651
parent 154e8c46
......@@ -88,7 +88,8 @@ def create_model(data_format):
def define_mnist_flags():
flags_core.define_base()
flags_core.define_base(clean=True, train_epochs=True,
epochs_between_evals=True)
flags_core.define_performance(inter_op=True, intra_op=True,
num_parallel_calls=False,
all_reduce_alg=True)
......
......@@ -169,7 +169,7 @@ def run_mnist_eager(flags_obj):
def define_mnist_eager_flags():
"""Defined flags and defaults for MNIST in eager mode."""
flags_core.define_base_eager()
flags_core.define_base_eager(clean=True, train_epochs=True)
flags_core.define_image()
flags.adopt_module_key_flags(flags_core)
......
......@@ -133,7 +133,7 @@ class BaseTest(tf.test.TestCase):
"--eval_start", "12",
"--eval_count", "8",
],
synth=False)
synth=False, train_epochs=None, epochs_between_evals=None)
self.assertTrue(tf.gfile.Exists(os.path.join(model_dir, "checkpoint")))
@unittest.skipIf(keras_utils.is_v2_0(), "TF 1.0 only test.")
......@@ -152,7 +152,7 @@ class BaseTest(tf.test.TestCase):
"--eval_start", "12",
"--eval_count", "8",
],
synth=False)
synth=False, train_epochs=None, epochs_between_evals=None)
self.assertTrue(tf.gfile.Exists(os.path.join(model_dir, "checkpoint")))
self.assertTrue(tf.gfile.Exists(os.path.join(export_dir)))
......
......@@ -723,7 +723,8 @@ def resnet_main(
def define_resnet_flags(resnet_size_choices=None, dynamic_loss_scale=False,
fp16_implementation=False):
"""Add flags and validators for ResNet."""
flags_core.define_base()
flags_core.define_base(clean=True, train_epochs=True,
epochs_between_evals=True)
flags_core.define_performance(num_parallel_calls=False,
inter_op=True,
intra_op=True,
......
......@@ -36,7 +36,8 @@ LOSS_PREFIX = {'wide': 'linear/', 'deep': 'dnn/'}
def define_wide_deep_flags():
"""Add supervised learning flags, as well as wide-deep model type."""
flags_core.define_base()
flags_core.define_base(clean=True, train_epochs=True,
epochs_between_evals=True)
flags_core.define_benchmark()
flags_core.define_performance(
num_parallel_calls=False, inter_op=True, intra_op=True,
......
......@@ -147,7 +147,9 @@ def get_v1_distribution_strategy(params):
def define_ncf_flags():
"""Add flags for running ncf_main."""
# Add common flags
flags_core.define_base(export_dir=False, run_eagerly=True)
flags_core.define_base(clean=True, train_epochs=True,
epochs_between_evals=True, export_dir=False,
run_eagerly=True)
flags_core.define_performance(
num_parallel_calls=False,
inter_op=False,
......
......@@ -394,7 +394,8 @@ def define_transformer_flags():
name="max_length", short_name="ml", default=None,
help=flags_core.help_wrap("Max length."))
flags_core.define_base()
flags_core.define_base(clean=True, train_epochs=True,
epochs_between_evals=True)
flags_core.define_performance(
num_parallel_calls=True,
inter_op=False,
......
......@@ -60,7 +60,7 @@ def get_model_params(param_set, num_gpus):
def define_transformer_flags():
"""Add flags and flag validators for running transformer_main."""
# Add common flags (data_dir, model_dir, train_epochs, etc.).
# Add common flags (data_dir, model_dir, etc.).
flags_core.define_base()
flags_core.define_performance(
num_parallel_calls=True,
......@@ -214,18 +214,9 @@ def define_transformer_flags():
flags_core.set_defaults(data_dir='/tmp/translate_ende',
model_dir='/tmp/transformer_model',
batch_size=None,
train_epochs=10)
batch_size=None)
# pylint: disable=unused-variable
@flags.multi_flags_validator(
['mode', 'train_epochs'],
message='--train_epochs must be defined in train mode')
def _check_train_limits(flag_dict):
if flag_dict['mode'] == 'train':
return flag_dict['train_epochs'] is not None
return True
@flags.multi_flags_validator(
['bleu_source', 'bleu_ref'],
message='Both or neither --bleu_source and --bleu_ref must be defined.')
......
......@@ -25,9 +25,9 @@ from official.utils.flags._conventions import help_wrap
from official.utils.logs import hooks_helper
def define_base(data_dir=True, model_dir=True, clean=True, train_epochs=True,
epochs_between_evals=True, stop_threshold=True, batch_size=True,
num_gpu=True, hooks=True, export_dir=True,
def define_base(data_dir=True, model_dir=True, clean=False, train_epochs=False,
epochs_between_evals=False, stop_threshold=True,
batch_size=True, num_gpu=True, hooks=True, export_dir=True,
distribution_strategy=True, run_eagerly=False):
"""Register base flags.
......
......@@ -72,8 +72,7 @@ def register_key_flags_in_core(f):
define_base = register_key_flags_in_core(_base.define_base)
# Remove options not relevant for Eager from define_base().
define_base_eager = register_key_flags_in_core(functools.partial(
_base.define_base, epochs_between_evals=False, stop_threshold=False,
hooks=False))
_base.define_base, stop_threshold=False, hooks=False))
define_benchmark = register_key_flags_in_core(_benchmark.define_benchmark)
define_device = register_key_flags_in_core(_device.define_device)
define_image = register_key_flags_in_core(_misc.define_image)
......
......@@ -22,7 +22,8 @@ from official.utils.flags import core as flags_core # pylint: disable=g-bad-imp
def define_flags():
flags_core.define_base(num_gpu=False)
flags_core.define_base(clean=True, num_gpu=False, train_epochs=True,
epochs_between_evals=True)
flags_core.define_performance(
num_parallel_calls=True, inter_op=True, intra_op=True,
dynamic_loss_scale=True, loss_scale=True, synthetic_data=True,
......
......@@ -29,7 +29,8 @@ from absl import flags
from official.utils.flags import core as flags_core
def run_synthetic(main, tmp_root, extra_flags=None, synth=True):
def run_synthetic(main, tmp_root, extra_flags=None, synth=True, train_epochs=1,
epochs_between_evals=1):
"""Performs a minimal run of a model.
This function is intended to test for syntax errors throughout a model. A
......@@ -41,18 +42,25 @@ def run_synthetic(main, tmp_root, extra_flags=None, synth=True):
tmp_root: Root path for the temp directory created by the test class.
extra_flags: Additional flags passed by the caller of this function.
synth: Use synthetic data.
train_epochs: Value of the --train_epochs flag.
epochs_between_evals: Value of the --epochs_between_evals flag.
"""
extra_flags = [] if extra_flags is None else extra_flags
model_dir = tempfile.mkdtemp(dir=tmp_root)
args = [sys.argv[0], "--model_dir", model_dir, "--train_epochs", "1",
"--epochs_between_evals", "1"] + extra_flags
args = [sys.argv[0], "--model_dir", model_dir] + extra_flags
if synth:
args.append("--use_synthetic_data")
if train_epochs is not None:
args.extend(["--train_epochs", str(train_epochs)])
if epochs_between_evals is not None:
args.extend(["--epochs_between_evals", str(epochs_between_evals)])
try:
flags_core.parse_flags(argv=args)
main(flags.FLAGS)
......
......@@ -283,7 +283,8 @@ def build_stats(history, eval_output, callbacks):
def define_keras_flags(dynamic_loss_scale=True):
"""Define flags for Keras models."""
flags_core.define_base(run_eagerly=True)
flags_core.define_base(clean=True, run_eagerly=True, train_epochs=True,
epochs_between_evals=True)
flags_core.define_performance(num_parallel_calls=False,
synthetic_data=True,
dtype=True,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment