"vscode:/vscode.git/clone" did not exist on "1189c66c1bcf37eb8b888034ec8f1a9551393ea3"
Commit 0a9026e4 authored by Reed Wanderman-Milne's avatar Reed Wanderman-Milne Committed by A. Unique TensorFlower
Browse files

Improve error message when certain flags are not specified.

In nlp/train.py and vision/beta/train.py, certain flags are marked as required. Additionally, in certain functions, error messages are improved if a necessary flag is not specified, which is a fallback in case a file calling define_flags() does not mark the necessary flags are required. Previously if any of these flags were not specified, it would crash with a cryptic error message, making it hard to tell what went wrong.

In a subsequent change, I will mark flags as required in more files which call define_flags().

PiperOrigin-RevId: 381066985
parent 9f9d07e9
......@@ -18,9 +18,27 @@ from absl import flags
def define_flags():
"""Defines flags."""
"""Defines flags.
All flags are defined as optional, but in practice most models use some of
these flags and so mark_flags_as_required() should be called after calling
this function. Typically, 'experiment', 'mode', and 'model_dir' are required.
For example:
```
from absl import flags
from official.common import flags as tfm_flags # pylint: disable=line-too-long
...
tfm_flags.define_flags()
flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
```
The reason all flags are optional is because unit tests often do not set or
use any of the flags.
"""
flags.DEFINE_string(
'experiment', default=None, help='The experiment type registered.')
'experiment', default=None, help=
'The experiment type registered, specifying an ExperimentConfig.')
flags.DEFINE_enum(
'mode',
......
......@@ -78,6 +78,8 @@ def run_experiment(
params, model_dir))
if trainer.checkpoint:
if model_dir is None:
raise ValueError('model_dir must be specified, but got None')
checkpoint_manager = tf.train.CheckpointManager(
trainer.checkpoint,
directory=model_dir,
......
......@@ -241,6 +241,9 @@ class ParseConfigOptions:
def parse_configuration(flags_obj, lock_return=True, print_return=True):
"""Parses ExperimentConfig from flags."""
if flags_obj.experiment is None:
raise ValueError('The flag --experiment must be specified.')
# 1. Get the default config from the registered experiment.
params = exp_factory.get_exp_config(flags_obj.experiment)
......@@ -285,7 +288,7 @@ def parse_configuration(flags_obj, lock_return=True, print_return=True):
if print_return:
pp = pprint.PrettyPrinter()
logging.info('Final experiment parameters: %s',
logging.info('Final experiment parameters:\n%s',
pp.pformat(params.as_dict()))
return params
......@@ -294,6 +297,8 @@ def parse_configuration(flags_obj, lock_return=True, print_return=True):
def serialize_config(params: config_definitions.ExperimentConfig,
model_dir: str):
"""Serializes and saves the experiment config."""
if model_dir is None:
raise ValueError('model_dir must be specified, but got None')
params_save_path = os.path.join(model_dir, 'params.yaml')
logging.info('Saving experiment configuration to %s', params_save_path)
tf.io.gfile.makedirs(model_dir)
......
......@@ -66,4 +66,5 @@ def main(_):
if __name__ == '__main__':
tfm_flags.define_flags()
flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
app.run(main)
......@@ -66,4 +66,5 @@ def main(_):
if __name__ == '__main__':
tfm_flags.define_flags()
flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
app.run(main)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment