Commit d93a1206 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Move gin flags to hyperparams_flags are they are in the same category and we...

Move gin flags to hyperparams_flags are they are in the same category and we will use them more widely.

PiperOrigin-RevId: 309779408
parent e91779b8
......@@ -17,17 +17,10 @@
from absl import flags
import tensorflow as tf
from official.utils import hyperparams_flags
from official.utils.flags import core as flags_core
def define_gin_flags():
"""Define common gin configurable flags."""
flags.DEFINE_multi_string('gin_file', None,
'List of paths to the config files.')
flags.DEFINE_multi_string(
'gin_param', None, 'Newline separated list of Gin parameter bindings.')
def define_common_bert_flags():
"""Define common flags for BERT tasks."""
flags_core.define_base(
......@@ -100,6 +93,9 @@ def define_common_bert_flags():
fp16_implementation=True,
)
# Adds gin configuration flags.
hyperparams_flags.define_gin_flags()
def dtype():
return flags_core.get_tf_dtype(flags.FLAGS)
......
......@@ -57,7 +57,6 @@ flags.DEFINE_integer('train_batch_size', 32, 'Batch size for training.')
flags.DEFINE_integer('eval_batch_size', 32, 'Batch size for evaluation.')
common_flags.define_common_bert_flags()
common_flags.define_gin_flags()
FLAGS = flags.FLAGS
......
......@@ -51,7 +51,6 @@ flags.DEFINE_bool('use_next_sentence_label', True,
'Whether to use next sentence label to compute final loss.')
common_flags.define_common_bert_flags()
common_flags.define_gin_flags()
FLAGS = flags.FLAGS
......
......@@ -88,7 +88,6 @@ def define_common_squad_flags():
'another.')
common_flags.define_common_bert_flags()
common_flags.define_gin_flags()
FLAGS = flags.FLAGS
......
......@@ -25,6 +25,14 @@ from official.utils.flags import core as flags_core
FLAGS = flags.FLAGS
def define_gin_flags():
"""Define common gin configurable flags."""
flags.DEFINE_multi_string('gin_file', None,
'List of paths to the config files.')
flags.DEFINE_multi_string(
'gin_param', None, 'Newline separated list of Gin parameter bindings.')
def define_common_hparams_flags():
"""Define the common flags across models."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment