"src/targets/gpu/jit/mlir.cpp" did not exist on "2783c6494854e5fd3e70b87fd6f64256a4cd37de"
train.py 19.1 KB
Newer Older
yukun's avatar
yukun committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training script for the DeepLab model.

See model.py for more details and usage.
"""

import six
import tensorflow as tf
from tensorflow.python.ops import math_ops
from deeplab import common
from deeplab import model
from deeplab.datasets import data_generator
from deeplab.utils import train_utils

flags = tf.app.flags
FLAGS = flags.FLAGS

# Settings for multi-GPUs/multi-replicas training.

flags.DEFINE_integer('num_clones', 1, 'Number of clones to deploy.')

flags.DEFINE_boolean('clone_on_cpu', False, 'Use CPUs to deploy clones.')

flags.DEFINE_integer('num_replicas', 1, 'Number of worker replicas.')

flags.DEFINE_integer('startup_delay_steps', 15,
                     'Number of training steps between replica startups.')

flags.DEFINE_integer(
    'num_ps_tasks', 0,
    'The number of parameter servers. If the value is 0, then '
    'the parameters are handled locally by the worker.')

flags.DEFINE_string('master', '', 'BNS name of the tensorflow server')

flags.DEFINE_integer('task', 0, 'The task ID.')

# Settings for logging.

flags.DEFINE_string('train_logdir', None,
                    'Where the checkpoint and logs are stored.')

flags.DEFINE_integer('log_steps', 10,
                     'Display logging information at every log_steps.')

flags.DEFINE_integer('save_interval_secs', 1200,
                     'How often, in seconds, we save the model to disk.')

flags.DEFINE_integer('save_summaries_secs', 600,
                     'How often, in seconds, we compute the summaries.')

flags.DEFINE_boolean(
    'save_summaries_images', False,
    'Save sample inputs, labels, and semantic predictions as '
    'images to summary.')

# Settings for profiling.

flags.DEFINE_string('profile_logdir', None,
                    'Where the profile files are stored.')

# Settings for training strategy.

flags.DEFINE_enum('learning_policy', 'poly', ['poly', 'step'],
                  'Learning rate policy for training.')

# Use 0.007 when training on PASCAL augmented training set, train_aug. When
# fine-tuning on PASCAL trainval set, use learning rate=0.0001.
flags.DEFINE_float('base_learning_rate', .0001,
                   'The base learning rate for model training.')

flags.DEFINE_float('learning_rate_decay_factor', 0.1,
                   'The rate to decay the base learning rate.')

flags.DEFINE_integer('learning_rate_decay_step', 2000,
                     'Decay the base learning rate at a fixed step.')

flags.DEFINE_float('learning_power', 0.9,
                   'The power value used in the poly learning policy.')
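# Note: under the 'poly' policy the rate roughly follows
#   base_learning_rate * (1 - global_step / training_number_of_steps)
#     ** learning_power,
# while the 'step' policy multiplies the rate by learning_rate_decay_factor
# every learning_rate_decay_step steps; see
# train_utils.get_model_learning_rate for the exact schedule.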

flags.DEFINE_integer('training_number_of_steps', 30000,
                     'The number of steps used for training')

flags.DEFINE_float('momentum', 0.9, 'The momentum value to use')

# When fine_tune_batch_norm=True, use at least batch size larger than 12
# (batch size more than 16 is better). Otherwise, one could use smaller batch
# size and set fine_tune_batch_norm=False.
flags.DEFINE_integer('train_batch_size', 8,
                     'The number of images in each batch during training.')

# For weight_decay, use 0.00004 for MobileNet-V2 or Xception model variants.
# Use 0.0001 for ResNet model variants.
flags.DEFINE_float('weight_decay', 0.00004,
                   'The value of the weight decay for training.')

flags.DEFINE_list('train_crop_size', '513,513',
                  'Image crop size [height, width] during training.')

flags.DEFINE_float(
    'last_layer_gradient_multiplier', 1.0,
    'The gradient multiplier for last layers, which is used to '
    'boost the gradient of last layers if the value > 1.')

flags.DEFINE_boolean('upsample_logits', True,
                     'Upsample logits during training.')

# Hyper-parameters for NAS training strategy.

flags.DEFINE_float(
    'drop_path_keep_prob', 1.0,
    'Probability to keep each path in the NAS cell when training.')

# Settings for fine-tuning the network.

flags.DEFINE_string('tf_initial_checkpoint', None,
                    'The initial checkpoint in tensorflow format.')

# Set to False if one does not want to re-use the trained classifier weights.
flags.DEFINE_boolean('initialize_last_layer', True,
                     'Initialize the last layer.')

flags.DEFINE_boolean('last_layers_contain_logits_only', False,
                     'Whether to consider only the logits as the last layers.')

flags.DEFINE_integer('slow_start_step', 0,
                     'Train the model with a small learning rate for this '
                     'number of initial steps.')

flags.DEFINE_float('slow_start_learning_rate', 1e-4,
                   'Learning rate employed during slow start.')

# Set to True if one wants to fine-tune the batch norm parameters in DeepLabv3.
# Set to False and use small batch size to save GPU memory.
flags.DEFINE_boolean('fine_tune_batch_norm', True,
                     'Fine tune the batch norm parameters or not.')

flags.DEFINE_float('min_scale_factor', 0.5,
                   'Minimum scale factor for data augmentation.')

flags.DEFINE_float('max_scale_factor', 2.,
                   'Maximum scale factor for data augmentation.')

flags.DEFINE_float('scale_factor_step_size', 0.25,
                   'Scale factor step size for data augmentation.')

# For `xception_65`, use atrous_rates = [12, 24, 36] if output_stride = 8, or
# rates = [6, 12, 18] if output_stride = 16. For `mobilenet_v2`, use None. Note
# one could use different atrous_rates/output_stride during training/evaluation.
flags.DEFINE_multi_integer('atrous_rates', None,
                           'Atrous rates for atrous spatial pyramid pooling.')

flags.DEFINE_integer('output_stride', 16,
                     'The ratio of input to output spatial resolution.')

# Hard example mining related flags.
flags.DEFINE_integer(
    'hard_example_mining_step', 0,
    'The training step in which exact hard example mining kicks off. Note we '
    'gradually reduce the mining percent to the specified '
    'top_k_percent_pixels. For example, if hard_example_mining_step=100K and '
    'top_k_percent_pixels=0.25, then mining percent will gradually reduce from '
    '100% to 25% until 100K steps after which we only mine top 25% pixels.')


flags.DEFINE_float(
    'top_k_percent_pixels', 1.0,
    'The top k percent pixels (in terms of the loss values) used to compute '
    'loss during training. This is useful for hard pixel mining.')

# Quantization setting.
flags.DEFINE_integer(
    'quantize_delay_step', -1,
    'Steps to start quantized training. If < 0, will not quantize model.')

# Dataset settings.
flags.DEFINE_string('dataset', 'pascal_voc_seg',
                    'Name of the segmentation dataset.')

flags.DEFINE_string('train_split', 'train',
                    'Which split of the dataset to use for training.')

flags.DEFINE_string('dataset_dir', None, 'Where the dataset resides.')


def _build_deeplab(iterator, outputs_to_num_classes, ignore_label):
  """Builds a clone of DeepLab.

  Args:
    iterator: An iterator of type tf.data.Iterator for images and labels.
    outputs_to_num_classes: A map from output type to the number of classes. For
      example, for the task of semantic segmentation with 21 semantic classes,
      we would have outputs_to_num_classes['semantic'] = 21.
    ignore_label: Ignore label.
  """
  samples = iterator.get_next()

  # Add name to input and label nodes so we can add to summary.
  samples[common.IMAGE] = tf.identity(samples[common.IMAGE], name=common.IMAGE)
  samples[common.LABEL] = tf.identity(samples[common.LABEL], name=common.LABEL)

  model_options = common.ModelOptions(
      outputs_to_num_classes=outputs_to_num_classes,
      crop_size=[int(sz) for sz in FLAGS.train_crop_size],
      atrous_rates=FLAGS.atrous_rates,
      output_stride=FLAGS.output_stride)

  outputs_to_scales_to_logits = model.multi_scale_logits(
      samples[common.IMAGE],
      model_options=model_options,
      image_pyramid=FLAGS.image_pyramid,
      weight_decay=FLAGS.weight_decay,
      is_training=True,
      fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
      nas_training_hyper_parameters={
          'drop_path_keep_prob': FLAGS.drop_path_keep_prob,
          'total_training_steps': FLAGS.training_number_of_steps,
      })

  # Add name to graph node so we can add to summary.
  output_type_dict = outputs_to_scales_to_logits[common.OUTPUT_TYPE]
  output_type_dict[model.MERGED_LOGITS_SCOPE] = tf.identity(
      output_type_dict[model.MERGED_LOGITS_SCOPE], name=common.OUTPUT_TYPE)

  for output, num_classes in six.iteritems(outputs_to_num_classes):
    train_utils.add_softmax_cross_entropy_loss_for_each_scale(
        outputs_to_scales_to_logits[output],
        samples[common.LABEL],
        num_classes,
        ignore_label,
        loss_weight=1.0,
        upsample_logits=FLAGS.upsample_logits,
        hard_example_mining_step=FLAGS.hard_example_mining_step,
        top_k_percent_pixels=FLAGS.top_k_percent_pixels,
        scope=output)

    # Log the summary
    _log_summaries(samples[common.IMAGE], samples[common.LABEL], num_classes,
                   output_type_dict[model.MERGED_LOGITS_SCOPE])


def _tower_loss(iterator, num_of_classes, ignore_label, scope, reuse_variable):
  """Calculates the total loss on a single tower running the deeplab model.

  Args:
    iterator: An iterator of type tf.data.Iterator for images and labels.
    num_of_classes: Number of classes for the dataset.
    ignore_label: Ignore label for the dataset.
    scope: Unique prefix string identifying the deeplab tower.
    reuse_variable: Whether the model variables should be reused.

  Returns:
     The total loss for a batch of data.
  """
  with tf.variable_scope(
      tf.get_variable_scope(), reuse=True if reuse_variable else None):
    _build_deeplab(iterator, {common.OUTPUT_TYPE: num_of_classes}, ignore_label)

  losses = tf.losses.get_losses(scope=scope)
  for loss in losses:
    tf.summary.scalar('Losses/%s' % loss.op.name, loss)

  regularization_loss = tf.losses.get_regularization_loss(scope=scope)
  tf.summary.scalar('Losses/%s' % regularization_loss.op.name,
                    regularization_loss)

  total_loss = tf.add_n([tf.add_n(losses), regularization_loss])
  return total_loss


def _average_gradients(tower_grads):
  """Calculates average of gradient for each shared variable across all towers.

  Note that this function provides a synchronization point across all towers.

  Args:
    tower_grads: List of lists of (gradient, variable) tuples. The outer list is
      over individual gradients. The inner list is over the gradient calculation
      for each tower.

  Returns:
     List of pairs of (gradient, variable) where the gradient has been summed
       across all towers.
  """
  average_grads = []
  for grad_and_vars in zip(*tower_grads):
    # Note that each grad_and_vars looks like the following:
    #   ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
    grads, variables = zip(*grad_and_vars)
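    # Stack the per-tower gradients along a new leading axis and average over
    # that axis to get one gradient per variable.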
    grad = tf.reduce_mean(tf.stack(grads, axis=0), axis=0)

    # All towers share the same variables; use the one from the first tower.
    average_grads.append((grad, variables[0]))

  return average_grads


def _log_summaries(input_image, label, num_of_classes, output):
  """Logs the summaries for the model.

  Args:
    input_image: Input image of the model. Its shape is [batch_size, height,
      width, channel].
    label: Label of the image. Its shape is [batch_size, height, width].
    num_of_classes: The number of classes of the dataset.
    output: Output of the model. Its shape is [batch_size, height, width,
      num_classes].
  """
  # Add summaries for model variables.
  for model_var in tf.model_variables():
    tf.summary.histogram(model_var.op.name, model_var)

  # Add summaries for images, labels, semantic predictions.
  if FLAGS.save_summaries_images:
    tf.summary.image('samples/%s' % common.IMAGE, input_image)

    # Scale up summary image pixel values for better visualization.
    pixel_scaling = max(1, 255 // num_of_classes)
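    # For example, with 21 classes pixel_scaling is 255 // 21 = 12, so label
    # values 0..20 map to grayscale intensities 0..240.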
    summary_label = tf.cast(label * pixel_scaling, tf.uint8)
    tf.summary.image('samples/%s' % common.LABEL, summary_label)

    predictions = tf.expand_dims(tf.argmax(output, 3), -1)
    summary_predictions = tf.cast(predictions * pixel_scaling, tf.uint8)
    tf.summary.image('samples/%s' % common.OUTPUT_TYPE, summary_predictions)


def _train_deeplab_model(iterator, num_of_classes, ignore_label):
  """Trains the deeplab model.

  Args:
    iterator: An iterator of type tf.data.Iterator for images and labels.
    num_of_classes: Number of classes for the dataset.
    ignore_label: Ignore label for the dataset.

  Returns:
    train_tensor: A tensor to update the model variables.
    summary_op: An operation to log the summaries.
  """
  global_step = tf.train.get_or_create_global_step()

  learning_rate = train_utils.get_model_learning_rate(
      FLAGS.learning_policy, FLAGS.base_learning_rate,
      FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
      FLAGS.training_number_of_steps, FLAGS.learning_power,
      FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
  tf.summary.scalar('learning_rate', learning_rate)

  optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)

  tower_losses = []
  tower_grads = []
  for i in range(FLAGS.num_clones):
    with tf.device('/gpu:%d' % i):
      # First tower has default name scope.
      name_scope = ('clone_%d' % i) if i else ''
      with tf.name_scope(name_scope) as scope:
        loss = _tower_loss(
            iterator=iterator,
            num_of_classes=num_of_classes,
            ignore_label=ignore_label,
            scope=scope,
            reuse_variable=(i != 0))
        tower_losses.append(loss)

  if FLAGS.quantize_delay_step >= 0:
    if FLAGS.num_clones > 1:
      raise ValueError('Quantization doesn\'t support multi-clone yet.')
    tf.contrib.quantize.create_training_graph(
        quant_delay=FLAGS.quantize_delay_step)
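    # create_training_graph rewrites the default graph in place, inserting
    # fake-quantization ops whose effect starts after quant_delay steps.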

  for i in range(FLAGS.num_clones):
    with tf.device('/gpu:%d' % i):
      name_scope = ('clone_%d' % i) if i else ''
      with tf.name_scope(name_scope) as scope:
        grads = optimizer.compute_gradients(tower_losses[i])
        tower_grads.append(grads)

  with tf.device('/cpu:0'):
    grads_and_vars = _average_gradients(tower_grads)

    # Modify the gradients for biases and last layer variables.
    last_layers = model.get_extra_layer_scopes(
        FLAGS.last_layers_contain_logits_only)
    grad_mult = train_utils.get_model_gradient_multipliers(
        last_layers, FLAGS.last_layer_gradient_multiplier)
    if grad_mult:
      grads_and_vars = tf.contrib.training.multiply_gradients(
          grads_and_vars, grad_mult)

    # Create gradient update op.
    grad_updates = optimizer.apply_gradients(
        grads_and_vars, global_step=global_step)

    # Gather update_ops. These contain, for example,
    # the updates for the batch_norm variables created by model_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    update_ops.append(grad_updates)
    update_op = tf.group(*update_ops)

    total_loss = tf.losses.get_total_loss(add_regularization_losses=True)

    # Print total loss to the terminal.
    # This implementation is mirrored from tf.slim.summaries.
    should_log = math_ops.equal(math_ops.mod(global_step, FLAGS.log_steps), 0)
    total_loss = tf.cond(
        should_log,
        lambda: tf.Print(total_loss, [total_loss], 'Total loss is :'),
        lambda: total_loss)

    tf.summary.scalar('total_loss', total_loss)
    with tf.control_dependencies([update_op]):
      train_tensor = tf.identity(total_loss, name='train_op')

    # Excludes summaries from towers other than the first one.
    summary_op = tf.summary.merge_all(scope='(?!clone_)')
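    # merge_all treats `scope` as a regex, so the negative lookahead
    # '(?!clone_)' matches only summaries created outside the 'clone_<i>'
    # name scopes of the non-default towers.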

  return train_tensor, summary_op


def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)

  tf.gfile.MakeDirs(FLAGS.train_logdir)
  tf.logging.info('Training on %s set', FLAGS.train_split)

  graph = tf.Graph()
  with graph.as_default():
    with tf.device(tf.train.replica_device_setter(ps_tasks=FLAGS.num_ps_tasks)):
      assert FLAGS.train_batch_size % FLAGS.num_clones == 0, (
          'Training batch size not divisible by number of clones (GPUs).')
      clone_batch_size = FLAGS.train_batch_size // FLAGS.num_clones
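      # For example, train_batch_size=8 with num_clones=2 yields a per-clone
      # (per-GPU) batch of 4 images.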

      dataset = data_generator.Dataset(
          dataset_name=FLAGS.dataset,
          split_name=FLAGS.train_split,
          dataset_dir=FLAGS.dataset_dir,
          batch_size=clone_batch_size,
          crop_size=[int(sz) for sz in FLAGS.train_crop_size],
          min_resize_value=FLAGS.min_resize_value,
          max_resize_value=FLAGS.max_resize_value,
          resize_factor=FLAGS.resize_factor,
          min_scale_factor=FLAGS.min_scale_factor,
          max_scale_factor=FLAGS.max_scale_factor,
          scale_factor_step_size=FLAGS.scale_factor_step_size,
          model_variant=FLAGS.model_variant,
          num_readers=2,
          is_training=True,
          should_shuffle=True,
          should_repeat=True)

      train_tensor, summary_op = _train_deeplab_model(
          dataset.get_one_shot_iterator(), dataset.num_of_classes,
          dataset.ignore_label)

      # Soft placement allows ops that have no GPU implementation to be placed
      # on the CPU.
      session_config = tf.ConfigProto(
          allow_soft_placement=True, log_device_placement=False)

      last_layers = model.get_extra_layer_scopes(
          FLAGS.last_layers_contain_logits_only)
      init_fn = None
      if FLAGS.tf_initial_checkpoint:
        init_fn = train_utils.get_model_init_fn(
            FLAGS.train_logdir,
            FLAGS.tf_initial_checkpoint,
            FLAGS.initialize_last_layer,
            last_layers,
            ignore_missing_vars=True)

      scaffold = tf.train.Scaffold(
          init_fn=init_fn,
          summary_op=summary_op,
      )

      stop_hook = tf.train.StopAtStepHook(
          last_step=FLAGS.training_number_of_steps)

      profile_dir = FLAGS.profile_logdir
      if profile_dir is not None:
        tf.gfile.MakeDirs(profile_dir)

      with tf.contrib.tfprof.ProfileContext(
          enabled=profile_dir is not None, profile_dir=profile_dir):
        with tf.train.MonitoredTrainingSession(
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            config=session_config,
            scaffold=scaffold,
            checkpoint_dir=FLAGS.train_logdir,
            summary_dir=FLAGS.train_logdir,
            log_step_count_steps=FLAGS.log_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_checkpoint_secs=FLAGS.save_interval_secs,
            hooks=[stop_hook]) as sess:
          while not sess.should_stop():
            sess.run([train_tensor])


if __name__ == '__main__':
  flags.mark_flag_as_required('train_logdir')
  flags.mark_flag_as_required('dataset_dir')
  tf.app.run()