main.py 9.15 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Main function to train various object detection models."""

from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

import functools
import pprint
24
25

# pylint: disable=g-bad-import-order
26
import tensorflow as tf
27

28
29
30
31
32
from absl import app
from absl import flags
from absl import logging
# pylint: enable=g-bad-import-order

33
34
from official.modeling.hyperparams import params_dict
from official.modeling.training import distributed_executor as executor
Allen Wang's avatar
Allen Wang committed
35
from official.utils import hyperparams_flags
36
37
38
from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
39
40
41
42
43
44
from official.vision.detection.configs import factory as config_factory
from official.vision.detection.dataloader import input_reader
from official.vision.detection.dataloader import mode_keys as ModeKeys
from official.vision.detection.executor.detection_executor import DetectionDistributedExecutor
from official.vision.detection.modeling import factory as model_factory

# Registers the flags shared across Model Garden binaries. `run` reads
# FLAGS.config_file, FLAGS.params_override, FLAGS.strategy_type and
# FLAGS.model_dir, which are not defined in this file — presumably they come
# from here (TODO confirm against hyperparams_flags).
hyperparams_flags.initialize_common_flags()
# Defines --log_steps, which `run` uses to attach a TimeHistory callback.
flags_core.define_log_steps()

flags.DEFINE_bool('enable_xla', default=False, help='Enable XLA for GPU')

flags.DEFINE_string(
    'mode', default='train',
    help='Mode to run: `train`, `eval` or `eval_once`.')

flags.DEFINE_string(
    'model', default='retinanet',
    help='Model to run: `retinanet`, `mask_rcnn` or `shapemask`.')

flags.DEFINE_string(
    'training_file_pattern', None,
    'Location of the train data. Overrides params.train.train_file_pattern.')

flags.DEFINE_string(
    'eval_file_pattern', None,
    'Location of the eval data. Overrides params.eval.eval_file_pattern.')

flags.DEFINE_string(
    'checkpoint_path', None,
    'The checkpoint path to eval, or a directory whose latest checkpoint is '
    'used. Only used in eval_once mode.')

FLAGS = flags.FLAGS


def run_executor(params,
                 mode,
                 checkpoint_path=None,
                 train_input_fn=None,
                 eval_input_fn=None,
                 callbacks=None,
                 prebuilt_strategy=None):
  """Runs the object detection model under a distribution strategy.

  Args:
    params: the experiment configuration (when called from `run`, the locked
      `params_dict` object it builds). Reads the `architecture`,
      `strategy_type`, `strategy_config`, `model_dir`, `train` and `eval`
      entries.
    mode: one of `train`, `eval` or `eval_once`.
    checkpoint_path: in `eval_once` mode, a checkpoint file or a directory
      whose latest checkpoint is evaluated. Ignored in the other modes.
    train_input_fn: callable producing the training input; required in
      `train` mode.
    eval_input_fn: callable producing the evaluation input; required in
      `eval` and `eval_once` modes.
    callbacks: optional custom callbacks forwarded to training.
    prebuilt_strategy: optional distribution strategy used as-is; when
      `None`, one is created from `params.strategy_type` and
      `params.strategy_config`.

  Returns:
    In `train` mode, the value returned by
    `DetectionDistributedExecutor.train`. In `eval` and `eval_once` modes,
    the dict of evaluation metrics.

  Raises:
    ValueError: if `mode` is unknown, the input function required by `mode`
      is missing, or `eval_once` is given an empty checkpoint path or a
      directory that contains no checkpoint.
  """
  if mode not in ('train', 'eval', 'eval_once'):
    raise ValueError('Mode not found: %s.' % mode)
  # Fail fast with a clear message rather than an opaque error from
  # `functools.partial(None, ...)` or from deep inside the executor.
  if mode == 'train' and train_input_fn is None:
    raise ValueError('train_input_fn must be provided in `train` mode.')
  if mode != 'train' and eval_input_fn is None:
    raise ValueError('eval_input_fn must be provided in `%s` mode.' % mode)
  if mode == 'eval_once' and not checkpoint_path:
    raise ValueError('checkpoint_path cannot be empty.')

  if params.architecture.use_bfloat16:
    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
        'mixed_bfloat16')
    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

  model_builder = model_factory.model_generator(params)

  if prebuilt_strategy is not None:
    strategy = prebuilt_strategy
  else:
    strategy_config = params.strategy_config
    distribution_utils.configure_cluster(strategy_config.worker_hosts,
                                         strategy_config.task_index)
    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=params.strategy_type,
        num_gpus=strategy_config.num_gpus,
        all_reduce_alg=strategy_config.all_reduce_alg,
        num_packs=strategy_config.num_packs,
        tpu_address=strategy_config.tpu)

  # Number of hosts, rounding up, assuming 8 replicas per host.
  # NOTE(review): the 8-replicas-per-host figure is hard-coded (it matches a
  # TPU host) — confirm for other topologies.
  num_workers = int(strategy.num_replicas_in_sync + 7) // 8
  is_multi_host = num_workers >= 2

  is_training = mode == 'train'
  model_mode = ModeKeys.TRAIN if is_training else ModeKeys.PREDICT_WITH_GT

  def _model_fn(params):
    return model_builder.build_model(params, mode=model_mode)

  logging.info('%s num_replicas_in_sync %d num_workers %d is_multi_host %s',
               'Train' if is_training else 'Eval',
               strategy.num_replicas_in_sync, num_workers, is_multi_host)

  # Training and evaluation use identically configured executors; only the
  # model mode captured by `_model_fn` differs.
  dist_executor = DetectionDistributedExecutor(
      strategy=strategy,
      params=params,
      model_fn=_model_fn,
      loss_fn=model_builder.build_loss_fn,
      is_multi_host=is_multi_host,
      predict_post_process_fn=model_builder.post_processing,
      trainable_variables_filter=(
          model_builder.make_filter_trainable_variables_fn()))

  if is_training:
    if is_multi_host:
      # Multi-host input pipelines get the per-replica batch size instead of
      # the global one.
      train_input_fn = functools.partial(
          train_input_fn,
          batch_size=params.train.batch_size // strategy.num_replicas_in_sync)
    return dist_executor.train(
        train_input_fn=train_input_fn,
        model_dir=params.model_dir,
        iterations_per_loop=params.train.iterations_per_loop,
        total_steps=params.train.total_steps,
        init_checkpoint=model_builder.make_restore_checkpoint_fn(),
        custom_callbacks=callbacks,
        save_config=True)

  if is_multi_host:
    eval_input_fn = functools.partial(
        eval_input_fn,
        batch_size=params.eval.batch_size // strategy.num_replicas_in_sync)

  if mode == 'eval':
    # Evaluates checkpoints from params.model_dir (presumably polling until
    # eval_timeout / total_steps — see the executor).
    results = dist_executor.evaluate_from_model_dir(
        model_dir=params.model_dir,
        eval_input_fn=eval_input_fn,
        eval_metric_fn=model_builder.eval_metrics,
        eval_timeout=params.eval.eval_timeout,
        min_eval_interval=params.eval.min_eval_interval,
        total_steps=params.train.total_steps)
  else:
    # Run evaluation once for a single checkpoint.
    if tf.io.gfile.isdir(checkpoint_path):
      latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
      # latest_checkpoint returns None when the directory holds no
      # checkpoint; don't pass that on silently.
      if latest_checkpoint is None:
        raise ValueError('No checkpoint found in %s.' % checkpoint_path)
      checkpoint_path = latest_checkpoint
    summary_writer = executor.SummaryWriter(params.model_dir, 'eval')
    results, _ = dist_executor.evaluate_checkpoint(
        checkpoint_path=checkpoint_path,
        eval_input_fn=eval_input_fn,
        eval_metric_fn=model_builder.eval_metrics,
        summary_writer=summary_writer)

  for metric_name, metric_value in results.items():
    logging.info('Final eval metric %s: %f', metric_name, metric_value)
  return results


def _build_params():
  """Builds the experiment params from FLAGS, then validates and locks them."""
  params = config_factory.config_generator(FLAGS.model)

  params = params_dict.override_params_dict(
      params, FLAGS.config_file, is_strict=True)

  params = params_dict.override_params_dict(
      params, FLAGS.params_override, is_strict=True)
  params.override(
      {
          'strategy_type': FLAGS.strategy_type,
          'model_dir': FLAGS.model_dir,
          'strategy_config': executor.strategy_flags_dict(),
      },
      is_strict=False)

  # Make sure use_tpu and strategy_type are in sync.
  params.use_tpu = (params.strategy_type == 'tpu')

  if not params.use_tpu:
    # bfloat16 and synchronized batch norm are only kept on TPU.
    params.override({
        'architecture': {
            'use_bfloat16': False,
        },
        'norm_activation': {
            'use_sync_bn': False,
        },
    }, is_strict=True)

  params.validate()
  params.lock()
  return params


def _build_input_fns(params):
  """Returns `(train_input_fn, eval_input_fn)`; each is None if unconfigured."""
  # Command-line flags take precedence over the patterns stored in params.
  training_file_pattern = (
      FLAGS.training_file_pattern or params.train.train_file_pattern)
  eval_file_pattern = FLAGS.eval_file_pattern or params.eval.eval_file_pattern
  if not training_file_pattern and not eval_file_pattern:
    raise ValueError('Must provide at least one of training_file_pattern and '
                     'eval_file_pattern.')

  train_input_fn = None
  if training_file_pattern:
    # Use global batch size for single host.
    train_input_fn = input_reader.InputFn(
        file_pattern=training_file_pattern,
        params=params,
        mode=input_reader.ModeKeys.TRAIN,
        batch_size=params.train.batch_size)

  eval_input_fn = None
  if eval_file_pattern:
    eval_input_fn = input_reader.InputFn(
        file_pattern=eval_file_pattern,
        params=params,
        mode=input_reader.ModeKeys.PREDICT_WITH_GT,
        batch_size=params.eval.batch_size,
        num_examples=params.eval.eval_samples)

  return train_input_fn, eval_input_fn


def run(callbacks=None):
  """Builds params and input functions from FLAGS and runs the executor.

  Args:
    callbacks: optional iterable of custom training callbacks. It is copied,
      never mutated; when --log_steps is set a `keras_utils.TimeHistory`
      callback is appended to the copy.

  Returns:
    Whatever `run_executor` returns for `FLAGS.mode`.

  Raises:
    ValueError: if neither a training nor an evaluation file pattern is
      configured (param validation may also raise on a bad config).
  """
  keras_utils.set_session_config(enable_xla=FLAGS.enable_xla)

  params = _build_params()
  logging.info('Model Parameters: %s', pprint.pformat(params.as_dict()))

  train_input_fn, eval_input_fn = _build_input_fns(params)

  # Copy so that appending TimeHistory never mutates the caller's list (and
  # repeated calls don't accumulate callbacks).
  callbacks = list(callbacks) if callbacks else []

  if FLAGS.log_steps:
    callbacks.append(
        keras_utils.TimeHistory(
            batch_size=params.train.batch_size,
            log_steps=FLAGS.log_steps,
        ))

  return run_executor(
      params,
      FLAGS.mode,
      checkpoint_path=FLAGS.checkpoint_path,
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn,
      callbacks=callbacks)


def main(argv):
  """absl entry point; all configuration is read from FLAGS by `run`."""
  del argv  # Unused.
  run()


if __name__ == '__main__':
  # Soft placement lets TensorFlow place an op on a different device than the
  # one explicitly assigned, instead of failing.
  tf.config.set_soft_device_placement(True)
  app.run(main)