"vscode:/vscode.git/clone" did not exist on "387ae76df382c67f731da4f2fddd6352ae5eb832"
Commit 0a9026e4 authored by Reed Wanderman-Milne's avatar Reed Wanderman-Milne Committed by A. Unique TensorFlower
Browse files

Improve error message when certain flags are not specified.

In nlp/train.py and vision/beta/train.py, certain flags are marked as required. Additionally, in certain functions, error messages are improved if a necessary flag is not specified, which is a fallback in case a file calling define_flags() does not mark the necessary flags are required. Previously if any of these flags were not specified, it would crash with a cryptic error message, making it hard to tell what went wrong.

In a subsequent change, I will mark flags as required in more files which call define_flags().

PiperOrigin-RevId: 381066985
parent 9f9d07e9
...@@ -18,9 +18,27 @@ from absl import flags ...@@ -18,9 +18,27 @@ from absl import flags
def define_flags(): def define_flags():
"""Defines flags.""" """Defines flags.
All flags are defined as optional, but in practice most models use some of
these flags and so mark_flags_as_required() should be called after calling
this function. Typically, 'experiment', 'mode', and 'model_dir' are required.
For example:
```
from absl import flags
from official.common import flags as tfm_flags # pylint: disable=line-too-long
...
tfm_flags.define_flags()
flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
```
The reason all flags are optional is because unit tests often do not set or
use any of the flags.
"""
flags.DEFINE_string( flags.DEFINE_string(
'experiment', default=None, help='The experiment type registered.') 'experiment', default=None, help=
'The experiment type registered, specifying an ExperimentConfig.')
flags.DEFINE_enum( flags.DEFINE_enum(
'mode', 'mode',
......
...@@ -78,6 +78,8 @@ def run_experiment( ...@@ -78,6 +78,8 @@ def run_experiment(
params, model_dir)) params, model_dir))
if trainer.checkpoint: if trainer.checkpoint:
if model_dir is None:
raise ValueError('model_dir must be specified, but got None')
checkpoint_manager = tf.train.CheckpointManager( checkpoint_manager = tf.train.CheckpointManager(
trainer.checkpoint, trainer.checkpoint,
directory=model_dir, directory=model_dir,
......
...@@ -241,6 +241,9 @@ class ParseConfigOptions: ...@@ -241,6 +241,9 @@ class ParseConfigOptions:
def parse_configuration(flags_obj, lock_return=True, print_return=True): def parse_configuration(flags_obj, lock_return=True, print_return=True):
"""Parses ExperimentConfig from flags.""" """Parses ExperimentConfig from flags."""
if flags_obj.experiment is None:
raise ValueError('The flag --experiment must be specified.')
# 1. Get the default config from the registered experiment. # 1. Get the default config from the registered experiment.
params = exp_factory.get_exp_config(flags_obj.experiment) params = exp_factory.get_exp_config(flags_obj.experiment)
...@@ -285,7 +288,7 @@ def parse_configuration(flags_obj, lock_return=True, print_return=True): ...@@ -285,7 +288,7 @@ def parse_configuration(flags_obj, lock_return=True, print_return=True):
if print_return: if print_return:
pp = pprint.PrettyPrinter() pp = pprint.PrettyPrinter()
logging.info('Final experiment parameters: %s', logging.info('Final experiment parameters:\n%s',
pp.pformat(params.as_dict())) pp.pformat(params.as_dict()))
return params return params
...@@ -294,6 +297,8 @@ def parse_configuration(flags_obj, lock_return=True, print_return=True): ...@@ -294,6 +297,8 @@ def parse_configuration(flags_obj, lock_return=True, print_return=True):
def serialize_config(params: config_definitions.ExperimentConfig, def serialize_config(params: config_definitions.ExperimentConfig,
model_dir: str): model_dir: str):
"""Serializes and saves the experiment config.""" """Serializes and saves the experiment config."""
if model_dir is None:
raise ValueError('model_dir must be specified, but got None')
params_save_path = os.path.join(model_dir, 'params.yaml') params_save_path = os.path.join(model_dir, 'params.yaml')
logging.info('Saving experiment configuration to %s', params_save_path) logging.info('Saving experiment configuration to %s', params_save_path)
tf.io.gfile.makedirs(model_dir) tf.io.gfile.makedirs(model_dir)
......
...@@ -66,4 +66,5 @@ def main(_): ...@@ -66,4 +66,5 @@ def main(_):
if __name__ == '__main__': if __name__ == '__main__':
tfm_flags.define_flags() tfm_flags.define_flags()
flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
app.run(main) app.run(main)
...@@ -66,4 +66,5 @@ def main(_): ...@@ -66,4 +66,5 @@ def main(_):
if __name__ == '__main__': if __name__ == '__main__':
tfm_flags.define_flags() tfm_flags.define_flags()
flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
app.run(main) app.run(main)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment