# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flags which will be nearly universal across models."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import flags
import tensorflow as tf
from official.utils.flags._conventions import help_wrap


def define_base(data_dir=True,
                model_dir=True,
                clean=False,
                train_epochs=False,
                epochs_between_evals=False,
                stop_threshold=False,
                batch_size=True,
                num_gpu=False,
                hooks=False,
                export_dir=False,
                distribution_strategy=False,
                run_eagerly=False):
  """Register base flags.

  Args:
    data_dir: Create a flag for specifying the input data directory.
    model_dir: Create a flag for specifying the model file directory.
    clean: Create a flag for removing the model_dir.
    train_epochs: Create a flag to specify the number of training epochs.
    epochs_between_evals: Create a flag to specify the frequency of testing.
    stop_threshold: Create a flag to specify a threshold accuracy or other eval
      metric which should trigger the end of training.
    batch_size: Create a flag to specify the batch size.
    num_gpu: Create a flag to specify the number of GPUs used.
    hooks: Create a flag to specify hooks for logging.
    export_dir: Create a flag to specify where a SavedModel should be exported.
    distribution_strategy: Create a flag to specify which Distribution Strategy
      to use.
    run_eagerly: Create a flag to specify whether to run eagerly, op by op.

  Returns:
    A list of flags for core.py to mark as key flags.
  """
  key_flags = []

  if data_dir:
    flags.DEFINE_string(
        name="data_dir",
        short_name="dd",
        default="/tmp",
        help=help_wrap("The location of the input data."))
    key_flags.append("data_dir")

  if model_dir:
    flags.DEFINE_string(
        name="model_dir",
        short_name="md",
        default="/tmp",
        help=help_wrap("The location of the model checkpoint files."))
    key_flags.append("model_dir")

  if clean:
    flags.DEFINE_boolean(
        name="clean",
        default=False,
        help=help_wrap("If set, model_dir will be removed if it exists."))
    key_flags.append("clean")

  if train_epochs:
    flags.DEFINE_integer(
        name="train_epochs",
        short_name="te",
        default=1,
        help=help_wrap("The number of epochs used to train."))
    key_flags.append("train_epochs")

  if epochs_between_evals:
    flags.DEFINE_integer(
        name="epochs_between_evals",
        short_name="ebe",
        default=1,
        help=help_wrap("The number of training epochs to run between "
                       "evaluations."))
    key_flags.append("epochs_between_evals")

  if stop_threshold:
    flags.DEFINE_float(
        name="stop_threshold",
        short_name="st",
        default=None,
        help=help_wrap("If passed, training will stop at the earlier of "
                       "train_epochs and when the evaluation metric is  "
                       "greater than or equal to stop_threshold."))

  if batch_size:
    flags.DEFINE_integer(
        name="batch_size",
        short_name="bs",
        default=32,
        help=help_wrap("Batch size for training and evaluation. When using "
                       "multiple gpus, this is the global batch size for "
                       "all devices. For example, if the batch size is 32 "
                       "and there are 4 GPUs, each GPU will get 8 examples on "
                       "each step."))
    key_flags.append("batch_size")

  if num_gpu:
    flags.DEFINE_integer(
        name="num_gpus",
        short_name="ng",
        default=1,
        help=help_wrap("How many GPUs to use at each worker with the "
                       "DistributionStrategies API. The default is 1."))

  if run_eagerly:
    flags.DEFINE_boolean(
        name="run_eagerly",
        default=False,
        help="Run the model op by op without building a model function.")

  if hooks:
    flags.DEFINE_list(
        name="hooks",
        short_name="hk",
        default="LoggingTensorHook",
        help=help_wrap(
            u"A list of (case insensitive) strings to specify the names of "
            u"training hooks. Example: `--hooks ProfilerHook,"
            u"ExamplesPerSecondHook`\n See hooks_helper "
            u"for details."))
    key_flags.append("hooks")

  if export_dir:
    flags.DEFINE_string(
        name="export_dir",
        short_name="ed",
        default=None,
        help=help_wrap("If set, a SavedModel serialization of the model will "
                       "be exported to this directory at the end of training. "
                       "See the README for more details and relevant links."))
    key_flags.append("export_dir")

  if distribution_strategy:
    flags.DEFINE_string(
        name="distribution_strategy",
        short_name="ds",
        default="mirrored",
        help=help_wrap("The Distribution Strategy to use for training. "
                       "Accepted values are 'off', 'one_device', "
                       "'mirrored', 'parameter_server', 'collective', "
                       "case insensitive. 'off' means not to use "
                       "Distribution Strategy; 'default' means to choose "
                       "from `MirroredStrategy` or `OneDeviceStrategy` "
                       "according to the number of GPUs."))

  return key_flags
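
# A minimal usage sketch (illustration only, not part of this module): a
# model's flag-definition module typically calls define_base() and then
# declares the returned names as key flags, as core.py does. The module and
# function names below are hypothetical placeholders.
#
#   from absl import app, flags
#   from official.utils.flags import _base
#
#   def define_my_model_flags():
#     key_flags = _base.define_base(train_epochs=True, num_gpu=True)
#     for name in key_flags:
#       flags.declare_key_flag(name)
#
#   def main(_):
#     flags_obj = flags.FLAGS
#     print("Training for %d epochs with global batch size %d"
#           % (flags_obj.train_epochs, flags_obj.batch_size))
#
#   define_my_model_flags()
#   app.run(main)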


def get_num_gpus(flags_obj):
  """Treat num_gpus=-1 as 'use all'."""
  if flags_obj.num_gpus != -1:
    return flags_obj.num_gpus

  from tensorflow.python.client import device_lib  # pylint: disable=g-import-not-at-top
  local_device_protos = device_lib.list_local_devices()
  return sum([1 for d in local_device_protos if d.device_type == "GPU"])
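

# A hedged, self-contained demo (an illustration, not part of the library's
# surface): running this file directly registers the base flags, parses the
# command line, and prints the GPU count that get_num_gpus resolves. Passing
# --num_gpus=-1 exercises the "use all local GPUs" branch above.
if __name__ == "__main__":
  from absl import app  # pylint: disable=g-import-not-at-top

  def _demo(_):
    flags_obj = flags.FLAGS
    print("batch_size=%d, resolved num_gpus=%d"
          % (flags_obj.batch_size, get_num_gpus(flags_obj)))

  define_base(num_gpu=True)
  app.run(_demo)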