Commit d93a1206 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Move gin flags to hyperparams_flags are they are in the same category and we...

Move gin flags to hyperparams_flags are they are in the same category and we will use them more widely.

PiperOrigin-RevId: 309779408
parent e91779b8
...@@ -17,17 +17,10 @@ ...@@ -17,17 +17,10 @@
from absl import flags from absl import flags
import tensorflow as tf import tensorflow as tf
from official.utils import hyperparams_flags
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
def define_gin_flags():
"""Define common gin configurable flags."""
flags.DEFINE_multi_string('gin_file', None,
'List of paths to the config files.')
flags.DEFINE_multi_string(
'gin_param', None, 'Newline separated list of Gin parameter bindings.')
def define_common_bert_flags(): def define_common_bert_flags():
"""Define common flags for BERT tasks.""" """Define common flags for BERT tasks."""
flags_core.define_base( flags_core.define_base(
...@@ -100,6 +93,9 @@ def define_common_bert_flags(): ...@@ -100,6 +93,9 @@ def define_common_bert_flags():
fp16_implementation=True, fp16_implementation=True,
) )
# Adds gin configuration flags.
hyperparams_flags.define_gin_flags()
def dtype(): def dtype():
return flags_core.get_tf_dtype(flags.FLAGS) return flags_core.get_tf_dtype(flags.FLAGS)
......
...@@ -57,7 +57,6 @@ flags.DEFINE_integer('train_batch_size', 32, 'Batch size for training.') ...@@ -57,7 +57,6 @@ flags.DEFINE_integer('train_batch_size', 32, 'Batch size for training.')
flags.DEFINE_integer('eval_batch_size', 32, 'Batch size for evaluation.') flags.DEFINE_integer('eval_batch_size', 32, 'Batch size for evaluation.')
common_flags.define_common_bert_flags() common_flags.define_common_bert_flags()
common_flags.define_gin_flags()
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
......
...@@ -51,7 +51,6 @@ flags.DEFINE_bool('use_next_sentence_label', True, ...@@ -51,7 +51,6 @@ flags.DEFINE_bool('use_next_sentence_label', True,
'Whether to use next sentence label to compute final loss.') 'Whether to use next sentence label to compute final loss.')
common_flags.define_common_bert_flags() common_flags.define_common_bert_flags()
common_flags.define_gin_flags()
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
......
...@@ -88,7 +88,6 @@ def define_common_squad_flags(): ...@@ -88,7 +88,6 @@ def define_common_squad_flags():
'another.') 'another.')
common_flags.define_common_bert_flags() common_flags.define_common_bert_flags()
common_flags.define_gin_flags()
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
......
...@@ -25,6 +25,14 @@ from official.utils.flags import core as flags_core ...@@ -25,6 +25,14 @@ from official.utils.flags import core as flags_core
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
def define_gin_flags():
"""Define common gin configurable flags."""
flags.DEFINE_multi_string('gin_file', None,
'List of paths to the config files.')
flags.DEFINE_multi_string(
'gin_param', None, 'Newline separated list of Gin parameter bindings.')
def define_common_hparams_flags(): def define_common_hparams_flags():
"""Define the common flags across models.""" """Define the common flags across models."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment