_base.py 6.25 KB
Newer Older
Hongkun Yu's avatar
Hongkun Yu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
16
17
"""Flags which will be nearly universal across models."""

from absl import flags
18
import tensorflow as tf
19
20
21
from official.utils.flags._conventions import help_wrap


Hongkun Yu's avatar
Hongkun Yu committed
22
23
24
25
26
27
28
29
30
31
32
33
def define_base(data_dir=True,
                model_dir=True,
                clean=False,
                train_epochs=False,
                epochs_between_evals=False,
                stop_threshold=False,
                batch_size=True,
                num_gpu=False,
                hooks=False,
                export_dir=False,
                distribution_strategy=False,
                run_eagerly=False):
  """Register base flags.

  Each argument toggles whether the corresponding flag is defined via
  `absl.flags`; when an argument is falsy its flag is not defined at all.

  Args:
    data_dir: Create a flag for specifying the input data directory.
    model_dir: Create a flag for specifying the model file directory.
    clean: Create a flag for removing the model_dir.
    train_epochs: Create a flag to specify the number of training epochs.
    epochs_between_evals: Create a flag to specify the frequency of testing.
    stop_threshold: Create a flag to specify a threshold accuracy or other eval
      metric which should trigger the end of training.
    batch_size: Create a flag to specify the batch size.
    num_gpu: Create a flag to specify the number of GPUs used. Note that the
      flag itself is named `num_gpus`.
    hooks: Create a flag to specify hooks for logging.
    export_dir: Create a flag to specify where a SavedModel should be exported.
    distribution_strategy: Create a flag to specify which Distribution Strategy
      to use.
    run_eagerly: Create a flag to specify to run eagerly op by op.

  Returns:
    A list of flag names for core.py to mark as key flags. Only `data_dir`,
    `model_dir`, `clean`, `train_epochs`, `epochs_between_evals`,
    `batch_size`, `hooks` and `export_dir` can appear in it; the
    `stop_threshold`, `num_gpus`, `run_eagerly` and `distribution_strategy`
    flags are defined but never returned.
  """
  key_flags = []

  if data_dir:
    flags.DEFINE_string(
        name="data_dir",
        short_name="dd",
        default="/tmp",
        help=help_wrap("The location of the input data."))
    key_flags.append("data_dir")

  if model_dir:
    flags.DEFINE_string(
        name="model_dir",
        short_name="md",
        default="/tmp",
        help=help_wrap("The location of the model checkpoint files."))
    key_flags.append("model_dir")

  if clean:
    flags.DEFINE_boolean(
        name="clean",
        default=False,
        help=help_wrap("If set, model_dir will be removed if it exists."))
    key_flags.append("clean")

  if train_epochs:
    flags.DEFINE_integer(
        name="train_epochs",
        short_name="te",
        default=1,
        help=help_wrap("The number of epochs used to train."))
    key_flags.append("train_epochs")

  if epochs_between_evals:
    flags.DEFINE_integer(
        name="epochs_between_evals",
        short_name="ebe",
        default=1,
        help=help_wrap("The number of training epochs to run between "
                       "evaluations."))
    key_flags.append("epochs_between_evals")

  if stop_threshold:
    # NOTE(review): unlike most flags here, this one is not appended to
    # key_flags; confirm whether that is intentional.
    flags.DEFINE_float(
        name="stop_threshold",
        short_name="st",
        default=None,
        help=help_wrap("If passed, training will stop at the earlier of "
                       "train_epochs and when the evaluation metric is "
                       "greater than or equal to stop_threshold."))

  if batch_size:
    flags.DEFINE_integer(
        name="batch_size",
        short_name="bs",
        default=32,
        help=help_wrap("Batch size for training and evaluation. When using "
                       "multiple gpus, this is the global batch size for "
                       "all devices. For example, if the batch size is 32 "
                       "and there are 4 GPUs, each GPU will get 8 examples on "
                       "each step."))
    key_flags.append("batch_size")

  if num_gpu:
    # The argument is `num_gpu` but the defined flag is `num_gpus`.
    # NOTE(review): not appended to key_flags; confirm this is intentional.
    flags.DEFINE_integer(
        name="num_gpus",
        short_name="ng",
        default=1,
        help=help_wrap("How many GPUs to use at each worker with the "
                       "DistributionStrategies API. The default is 1."))

  if run_eagerly:
    # Wrapped with help_wrap for consistency with every other flag's help.
    flags.DEFINE_boolean(
        name="run_eagerly",
        default=False,
        help=help_wrap(
            "Run the model op by op without building a model function."))

  if hooks:
    flags.DEFINE_list(
        name="hooks",
        short_name="hk",
        default="LoggingTensorHook",
        help=help_wrap(
            u"A list of (case insensitive) strings to specify the names of "
            u"training hooks. Example: `--hooks ProfilerHook,"
            u"ExamplesPerSecondHook`\n See hooks_helper "
            u"for details."))
    key_flags.append("hooks")

  if export_dir:
    flags.DEFINE_string(
        name="export_dir",
        short_name="ed",
        default=None,
        help=help_wrap("If set, a SavedModel serialization of the model will "
                       "be exported to this directory at the end of training. "
                       "See the README for more details and relevant links."))
    key_flags.append("export_dir")

  if distribution_strategy:
    # NOTE(review): the help below explains a 'default' value that is not in
    # its own list of accepted values; verify against the code that parses
    # this flag. Also not appended to key_flags; confirm intent.
    flags.DEFINE_string(
        name="distribution_strategy",
        short_name="ds",
        default="mirrored",
        help=help_wrap("The Distribution Strategy to use for training. "
                       "Accepted values are 'off', 'one_device', "
                       "'mirrored', 'parameter_server', 'collective', "
                       "case insensitive. 'off' means not to use "
                       "Distribution Strategy; 'default' means to choose "
                       "from `MirroredStrategy` or `OneDeviceStrategy` "
                       "according to the number of GPUs."))

  return key_flags
168
169
170
171
172
173
174
175
176
177


def get_num_gpus(flags_obj):
  """Treat num_gpus=-1 as 'use all'.

  Any value of `flags_obj.num_gpus` other than -1 is returned unchanged. For
  -1, the devices reported by TensorFlow's `device_lib.list_local_devices()`
  are inspected, and those whose `device_type` is "GPU" are counted.

  Args:
    flags_obj: An object exposing a `num_gpus` attribute (presumably the
      parsed absl FLAGS object; any object with that attribute works).

  Returns:
    The number of GPUs to use.
  """
  if flags_obj.num_gpus != -1:
    return flags_obj.num_gpus

  # Imported lazily so the private device_lib module is only loaded when the
  # caller actually asked to use all GPUs.
  from tensorflow.python.client import device_lib  # pylint: disable=g-import-not-at-top
  local_device_protos = device_lib.list_local_devices()
  # Generator expression: no throwaway list is materialized just to be summed.
  return sum(1 for d in local_device_protos if d.device_type == "GPU")