common_flags.py 3.08 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defining common flags used across all BERT models/applications."""

from absl import flags
18
19
20
import tensorflow as tf

from official.utils.flags import core as flags_core
21
22
23
24
25
26
27
28
29


def define_common_bert_flags():
  """Defines the command-line flags shared across BERT models/applications.

  Registers model/config paths, distribution-strategy selection, basic
  training hyperparameters and the mixed-precision related performance flags
  through `absl.flags` and `flags_core.define_performance`.
  """
  # Model configuration and output locations.
  flags.DEFINE_string('bert_config_file', None,
                      'Bert configuration file to define core bert layers.')
  flags.DEFINE_string(
      'model_dir', None,
      'The directory where the model weights and training/evaluation summaries '
      'are stored. If not specified, save to /tmp/bert20/.')
  flags.DEFINE_string(
      'model_export_path', None,
      'Path to the directory, where trained model will be '
      'exported.')

  # Hardware / distribution strategy selection.
  flags.DEFINE_string('tpu', '', 'TPU address to connect to.')
  flags.DEFINE_string(
      'init_checkpoint', None,
      'Initial checkpoint (usually from a pre-trained BERT model).')
  flags.DEFINE_enum(
      'strategy_type', 'mirror', ['tpu', 'mirror', 'multi_worker_mirror'],
      'Distribution Strategy type to use for training. `tpu` uses '
      'TPUStrategy for running on TPUs, `mirror` uses GPUs with single host, '
      '`multi_worker_mirror` uses CPUs or GPUs with multiple hosts.')

  # Training loop hyperparameters.
  flags.DEFINE_integer('num_train_epochs', 3,
                       'Total number of training epochs to perform.')
  flags.DEFINE_integer(
      'steps_per_loop', 200,
      'Number of steps per graph-mode loop. Only training step '
      'happens inside the loop. Callbacks will not be called '
      'inside.')
  flags.DEFINE_float('learning_rate', 5e-5,
                     'The initial learning rate for Adam.')
  flags.DEFINE_boolean(
      'run_eagerly', False,
      'Run the model op by op without building a model function.')
  flags.DEFINE_boolean(
      'scale_loss', False,
      'Whether to divide the loss by number of replica inside the per-replica '
      'loss function.')

  # Adds flags for mixed precision training. Each keyword is passed straight
  # to `flags_core.define_performance`; its exact semantics are defined there
  # (presumably True requests the corresponding flag -- confirm in flags_core).
  flags_core.define_performance(
      num_parallel_calls=False,
      inter_op=False,
      intra_op=False,
      synthetic_data=False,
      max_train_steps=False,
      dtype=True,
      dynamic_loss_scale=True,
      loss_scale=True,
      all_reduce_alg=False,
      num_packs=False,
      enable_xla=True,
      fp16_implementation=True,
  )


def use_float16():
  """Returns True when the dtype resolved from `flags.FLAGS` is `tf.float16`."""
  selected_dtype = flags_core.get_tf_dtype(flags.FLAGS)
  return selected_dtype == tf.float16


def get_loss_scale():
  """Returns the loss scale that `flags_core` derives from `flags.FLAGS`.

  Delegates to `flags_core.get_loss_scale`, passing 'dynamic' as the
  `default_for_fp16` argument.
  """
  parsed_flags = flags.FLAGS
  return flags_core.get_loss_scale(parsed_flags, default_for_fp16='dynamic')