main.py 9.17 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Main function to train various object detection models."""

from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

import functools
import pprint
24
25

# pylint: disable=g-bad-import-order
Hongkun Yu's avatar
Hongkun Yu committed
26
# Import libraries
27
import tensorflow as tf
28

29
30
31
32
33
from absl import app
from absl import flags
from absl import logging
# pylint: enable=g-bad-import-order

34
35
from official.modeling.hyperparams import params_dict
from official.modeling.training import distributed_executor as executor
Allen Wang's avatar
Allen Wang committed
36
from official.utils import hyperparams_flags
37
38
39
from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
40
41
42
43
44
45
from official.vision.detection.configs import factory as config_factory
from official.vision.detection.dataloader import input_reader
from official.vision.detection.dataloader import mode_keys as ModeKeys
from official.vision.detection.executor.detection_executor import DetectionDistributedExecutor
from official.vision.detection.modeling import factory as model_factory

# Common hyperparameter flags. run() reads FLAGS.config_file,
# FLAGS.params_override, FLAGS.strategy_type and FLAGS.model_dir, none of which
# are defined in this module -- presumably they come from this call (confirm in
# official.utils.hyperparams_flags).
hyperparams_flags.initialize_common_flags()
# Presumably defines --log_steps, which run() reads to decide whether to add a
# TimeHistory callback.
flags_core.define_log_steps()

flags.DEFINE_bool('enable_xla', default=False, help='Enable XLA for GPU')

flags.DEFINE_string(
    'mode', default='train',
    help='Mode to run: `train`, `eval` or `eval_once`.')

flags.DEFINE_string(
    'model', default='retinanet',
    help='Model to run: `retinanet`, `mask_rcnn` or `shapemask`.')

# Both file-pattern flags take precedence over the corresponding
# train/eval file pattern in the model config (see run()).
flags.DEFINE_string('training_file_pattern', None,
                    'Location of the train data.')

flags.DEFINE_string('eval_file_pattern', None, 'Location of the eval data.')

flags.DEFINE_string(
    'checkpoint_path', None,
    'The checkpoint path to eval. If a directory is given, its latest '
    'checkpoint is used. Only used in eval_once mode.')

FLAGS = flags.FLAGS


70
def run_executor(params,
                 mode,
                 checkpoint_path=None,
                 train_input_fn=None,
                 eval_input_fn=None,
                 callbacks=None,
                 prebuilt_strategy=None):
  """Runs the object detection model on distribution strategy defined by the user.

  Args:
    params: the model config. Read here: `architecture.use_bfloat16`,
      `strategy_type`, `strategy_config`, `model_dir`, and the `train` /
      `eval` sections.
    mode: 'train', 'eval' (evaluate checkpoints found in `params.model_dir`)
      or 'eval_once' (evaluate a single checkpoint).
    checkpoint_path: checkpoint to evaluate in 'eval_once' mode. If it is a
      directory, the latest checkpoint inside it is used.
    train_input_fn: input function used in 'train' mode.
    eval_input_fn: input function used in the 'eval' and 'eval_once' modes.
    callbacks: optional custom callbacks; only passed to training.
    prebuilt_strategy: optional distribution strategy to use instead of
      building one from `params.strategy_type` / `params.strategy_config`.

  Returns:
    In 'train' mode, the value returned by `DetectionDistributedExecutor.train`.
    In the eval modes, the mapping of eval metric names to values.

  Raises:
    ValueError: if `mode` is unknown, the input function that `mode` needs is
      missing, `checkpoint_path` is empty in 'eval_once' mode, or a checkpoint
      directory contains no checkpoint.
  """
  # Validate arguments before touching global state (mixed-precision policy,
  # cluster configuration) or building a distribution strategy.
  if mode not in ('train', 'eval', 'eval_once'):
    raise ValueError('Mode not found: %s.' % mode)
  if mode == 'train' and train_input_fn is None:
    raise ValueError('train_input_fn is required in train mode.')
  if mode != 'train' and eval_input_fn is None:
    raise ValueError('eval_input_fn is required in %s mode.' % mode)
  if mode == 'eval_once' and not checkpoint_path:
    raise ValueError('checkpoint_path cannot be empty.')

  if params.architecture.use_bfloat16:
    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
        'mixed_bfloat16')
    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

  model_builder = model_factory.model_generator(params)

  if prebuilt_strategy is not None:
    strategy = prebuilt_strategy
  else:
    strategy_config = params.strategy_config
    distribution_utils.configure_cluster(strategy_config.worker_hosts,
                                         strategy_config.task_index)
    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=params.strategy_type,
        num_gpus=strategy_config.num_gpus,
        all_reduce_alg=strategy_config.all_reduce_alg,
        num_packs=strategy_config.num_packs,
        tpu_address=strategy_config.tpu)

  # Replicas are counted 8 per host (rounded up). With two or more hosts, the
  # input functions below get per-replica batch sizes.
  # NOTE(review): 8 replicas per host presumably matches a TPU host; confirm
  # this is also the intent for GPU clusters.
  num_workers = (int(strategy.num_replicas_in_sync) + 7) // 8
  is_multi_host = num_workers >= 2

  if mode == 'train':

    def _train_model_fn(params):
      return model_builder.build_model(params, mode=ModeKeys.TRAIN)

    logging.info(
        'Train num_replicas_in_sync %d num_workers %d is_multi_host %s',
        strategy.num_replicas_in_sync, num_workers, is_multi_host)

    dist_executor = _make_detection_executor(
        strategy, params, _train_model_fn, model_builder, is_multi_host)

    if is_multi_host:
      # Multi-host: pass the per-replica share of the global batch size.
      train_input_fn = functools.partial(
          train_input_fn,
          batch_size=params.train.batch_size // strategy.num_replicas_in_sync)

    return dist_executor.train(
        train_input_fn=train_input_fn,
        model_dir=params.model_dir,
        iterations_per_loop=params.train.iterations_per_loop,
        total_steps=params.train.total_steps,
        init_checkpoint=model_builder.make_restore_checkpoint_fn(),
        custom_callbacks=callbacks,
        save_config=True)

  # Remaining modes: 'eval' and 'eval_once'.
  def _eval_model_fn(params):
    return model_builder.build_model(params, mode=ModeKeys.PREDICT_WITH_GT)

  logging.info('Eval num_replicas_in_sync %d num_workers %d is_multi_host %s',
               strategy.num_replicas_in_sync, num_workers, is_multi_host)

  if is_multi_host:
    # Multi-host: pass the per-replica share of the global batch size.
    eval_input_fn = functools.partial(
        eval_input_fn,
        batch_size=params.eval.batch_size // strategy.num_replicas_in_sync)

  dist_executor = _make_detection_executor(
      strategy, params, _eval_model_fn, model_builder, is_multi_host)

  if mode == 'eval':
    results = dist_executor.evaluate_from_model_dir(
        model_dir=params.model_dir,
        eval_input_fn=eval_input_fn,
        eval_metric_fn=model_builder.eval_metrics,
        eval_timeout=params.eval.eval_timeout,
        min_eval_interval=params.eval.min_eval_interval,
        total_steps=params.train.total_steps)
  else:
    # Run evaluation once for a single checkpoint.
    if tf.io.gfile.isdir(checkpoint_path):
      latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
      # latest_checkpoint returns None for a directory without checkpoints;
      # fail clearly instead of passing None on to the executor.
      if latest_checkpoint is None:
        raise ValueError('No checkpoint found in %s.' % checkpoint_path)
      checkpoint_path = latest_checkpoint
    summary_writer = executor.SummaryWriter(params.model_dir, 'eval')
    results, _ = dist_executor.evaluate_checkpoint(
        checkpoint_path=checkpoint_path,
        eval_input_fn=eval_input_fn,
        eval_metric_fn=model_builder.eval_metrics,
        summary_writer=summary_writer)

  for k, v in results.items():
    logging.info('Final eval metric %s: %f', k, v)
  return results


def _make_detection_executor(strategy, params, model_fn, model_builder,
                             is_multi_host):
  """Builds the DetectionDistributedExecutor shared by train and eval modes.

  The two modes differ only in `model_fn`. Every other constructor argument
  comes from `model_builder` or is passed through unchanged.
  """
  return DetectionDistributedExecutor(
      strategy=strategy,
      params=params,
      model_fn=model_fn,
      loss_fn=model_builder.build_loss_fn,
      is_multi_host=is_multi_host,
      predict_post_process_fn=model_builder.post_processing,
      trainable_variables_filter=model_builder
      .make_filter_trainable_variables_fn())


184
def run(callbacks=None):
  """Builds params and input functions from flags, then calls `run_executor`.

  Args:
    callbacks: optional iterable of extra training callbacks. It is copied
      before a `TimeHistory` callback is appended, so the caller's list is
      left unchanged.

  Returns:
    Whatever `run_executor` returns for `FLAGS.mode`.

  Raises:
    ValueError: if neither a training nor an eval file pattern is set, either
      through flags or through the model config.
  """
  keras_utils.set_session_config(enable_xla=FLAGS.enable_xla)

  # Each step overrides the previous one: model defaults, then --config_file,
  # then --params_override, then the strategy / model_dir flags.
  params = config_factory.config_generator(FLAGS.model)

  params = params_dict.override_params_dict(
      params, FLAGS.config_file, is_strict=True)

  params = params_dict.override_params_dict(
      params, FLAGS.params_override, is_strict=True)
  params.override(
      {
          'strategy_type': FLAGS.strategy_type,
          'model_dir': FLAGS.model_dir,
          'strategy_config': executor.strategy_flags_dict(),
      },
      is_strict=False)

  # Make sure use_tpu and strategy_type are in sync.
  params.use_tpu = (params.strategy_type == 'tpu')

  if not params.use_tpu:
    # Off TPU, force bfloat16 and synchronized batch norm off.
    params.override({
        'architecture': {
            'use_bfloat16': False,
        },
        'norm_activation': {
            'use_sync_bn': False,
        },
    }, is_strict=True)

  params.validate()
  params.lock()
  pp = pprint.PrettyPrinter()
  params_str = pp.pformat(params.as_dict())
  logging.info('Model Parameters: %s', params_str)

  train_input_fn = None
  eval_input_fn = None
  # Flags take precedence over the file patterns stored in the config.
  training_file_pattern = (
      FLAGS.training_file_pattern or params.train.train_file_pattern)
  eval_file_pattern = FLAGS.eval_file_pattern or params.eval.eval_file_pattern
  if not training_file_pattern and not eval_file_pattern:
    raise ValueError('Must provide at least one of training_file_pattern and '
                     'eval_file_pattern.')

  if training_file_pattern:
    # Use global batch size for single host.
    train_input_fn = input_reader.InputFn(
        file_pattern=training_file_pattern,
        params=params,
        mode=input_reader.ModeKeys.TRAIN,
        batch_size=params.train.batch_size)

  if eval_file_pattern:
    eval_input_fn = input_reader.InputFn(
        file_pattern=eval_file_pattern,
        params=params,
        mode=input_reader.ModeKeys.PREDICT_WITH_GT,
        batch_size=params.eval.batch_size,
        num_examples=params.eval.eval_samples)

  # Copy so that appending below never mutates a list owned by the caller.
  callbacks = list(callbacks) if callbacks is not None else []

  if FLAGS.log_steps:
    callbacks.append(
        keras_utils.TimeHistory(
            batch_size=params.train.batch_size,
            log_steps=FLAGS.log_steps,
        ))

  return run_executor(
      params,
      FLAGS.mode,
      checkpoint_path=FLAGS.checkpoint_path,
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn,
      callbacks=callbacks)


def main(argv):
  """absl `app.run` entry point; all configuration is read from FLAGS."""
  _ = argv  # Positional arguments left after flag parsing are not used.
  run()


if __name__ == '__main__':
  # Let TensorFlow place an op on a different device than the one requested
  # (e.g. CPU when no GPU kernel exists) instead of raising a placement error.
  tf.config.set_soft_device_placement(enabled=True)
  app.run(main)