# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Detection model trainer.

This file provides a generic training method that can be used to train a
DetectionModel.
"""

import functools

import tensorflow as tf

from object_detection.builders import optimizer_builder
from object_detection.builders import preprocessor_builder
from object_detection.core import batcher
from object_detection.core import preprocessor
from object_detection.core import standard_fields as fields
from object_detection.utils import ops as util_ops
from object_detection.utils import variables_helper
from deployment import model_deploy

# Module-level shorthand for tf.contrib.slim, used by the training code below.
slim = tf.contrib.slim


def create_input_queue(batch_size_per_clone, create_tensor_dict_fn,
                       batch_queue_capacity, num_batch_queue_threads,
                       prefetch_queue_capacity, data_augmentation_options):
  """Builds the input tensors, augments them and wraps them in a BatchQueue.

  The image entry of the tensor dictionary gets a new leading dimension and is
  converted to float before any augmentation is applied.

  Args:
    batch_size_per_clone: batch size to use per clone.
    create_tensor_dict_fn: function to create tensor dictionary.
    batch_queue_capacity: maximum number of elements to store within a queue.
    num_batch_queue_threads: number of threads to use for batching.
    prefetch_queue_capacity: maximum capacity of the queue used to prefetch
      assembled batches.
    data_augmentation_options: a list of tuples, where each tuple contains a
      data augmentation function and a dictionary containing arguments and their
      values (see preprocessor.py). When empty, no augmentation is applied.

  Returns:
    input queue: a batcher.BatchQueue object holding enqueued tensor_dicts
      (which hold images, boxes and targets).  To get a batch of tensor_dicts,
      call input_queue.dequeue().
  """
  input_fields = fields.InputDataFields
  tensor_dict = create_tensor_dict_fn()

  # Insert a new axis at position 0 of the image, then convert it to float.
  tensor_dict[input_fields.image] = tf.to_float(
      tf.expand_dims(tensor_dict[input_fields.image], 0))

  if data_augmentation_options:
    # The argument map must know which optional groundtruth fields are
    # present in the tensor dictionary.
    func_arg_map = preprocessor.get_default_func_arg_map(
        include_multiclass_scores=(
            input_fields.multiclass_scores in tensor_dict),
        include_instance_masks=(
            input_fields.groundtruth_instance_masks in tensor_dict),
        include_keypoints=(
            input_fields.groundtruth_keypoints in tensor_dict))
    tensor_dict = preprocessor.preprocess(
        tensor_dict, data_augmentation_options, func_arg_map=func_arg_map)

  return batcher.BatchQueue(
      tensor_dict,
      batch_size=batch_size_per_clone,
      batch_queue_capacity=batch_queue_capacity,
      num_batch_queue_threads=num_batch_queue_threads,
      prefetch_queue_capacity=prefetch_queue_capacity)


def get_inputs(input_queue,
               num_classes,
               merge_multiple_label_boxes=False,
               use_multiclass_scores=False):
  """Dequeues batch and constructs inputs to object detection model.

  Args:
    input_queue: BatchQueue object holding enqueued tensor_dicts.
    num_classes: Number of classes.
    merge_multiple_label_boxes: Whether to merge boxes with multiple labels
      or not. Defaults to false. Merged boxes are represented with a single
      box and a k-hot encoding of the multiple labels associated with the
      boxes.
    use_multiclass_scores: Whether to use multiclass scores instead of
      groundtruth_classes.

  Returns:
    An iterable yielding the seven sequences below, in this order. Each
    sequence has one entry per dequeued tensor_dict.

    images: a sequence of float image tensors.
    image_keys: a sequence of string keys for the images ('' when the
      tensor_dict has no source_id).
    locations_list: a sequence of tensors of shape [num_boxes, 4]
      containing the corners of the groundtruth boxes.
    classes_list: a sequence of padded one-hot tensors containing target
      classes (k-hot when boxes are merged, multiclass scores when
      use_multiclass_scores is set).
    masks_list: a sequence of 3-D float tensors of shape [num_boxes,
      image_height, image_width] containing instance masks for objects if
      present in the input_queue. Entries are None otherwise.
    keypoints_list: a sequence of 3-D float tensors of shape [num_boxes,
      num_keypoints, 2] containing keypoints for objects if present in the
      input queue. Entries are None otherwise.
    weights_lists: a sequence of 1-D float32 tensors of shape [num_boxes]
      containing groundtruth weight for each box. Entries are None when the
      tensor_dict has no groundtruth weights.

  Raises:
    ValueError: if both merge_multiple_label_boxes and use_multiclass_scores
      are set.
    NotImplementedError: if merge_multiple_label_boxes is set and a dequeued
      tensor_dict contains instance masks or keypoints.
  """
  # The two options are mutually exclusive; reject the combination up front,
  # before any dequeue op is added to the graph.
  if merge_multiple_label_boxes and use_multiclass_scores:
    raise ValueError(
        'Using both merge_multiple_label_boxes and use_multiclass_scores is '
        'not supported')

  read_data_list = input_queue.dequeue()
  # Groundtruth class ids are shifted down by this offset before encoding
  # (presumably the input labels are 1-based -- confirm against the decoder).
  label_id_offset = 1

  def extract_images_and_targets(read_data):
    """Extract images and targets from the input dict."""
    image = read_data[fields.InputDataFields.image]
    key = ''
    if fields.InputDataFields.source_id in read_data:
      key = read_data[fields.InputDataFields.source_id]
    location_gt = read_data[fields.InputDataFields.groundtruth_boxes]
    classes_gt = tf.cast(read_data[fields.InputDataFields.groundtruth_classes],
                         tf.int32)
    classes_gt -= label_id_offset

    # Optional fields: absent keys yield None.
    masks_gt = read_data.get(fields.InputDataFields.groundtruth_instance_masks)
    keypoints_gt = read_data.get(fields.InputDataFields.groundtruth_keypoints)
    weights_gt = read_data.get(fields.InputDataFields.groundtruth_weights)

    # Merging only rewrites boxes and classes, so masks/keypoints would no
    # longer line up with the merged boxes.
    if (merge_multiple_label_boxes and (
        masks_gt is not None or keypoints_gt is not None)):
      raise NotImplementedError('Multi-label support is only for boxes.')

    if merge_multiple_label_boxes:
      location_gt, classes_gt, _ = util_ops.merge_boxes_with_multiple_labels(
          location_gt, classes_gt, num_classes)
    elif use_multiclass_scores:
      classes_gt = tf.cast(read_data[fields.InputDataFields.multiclass_scores],
                           tf.float32)
    else:
      classes_gt = util_ops.padded_one_hot_encoding(
          indices=classes_gt, depth=num_classes, left_pad=0)

    return (image, key, location_gt, classes_gt, masks_gt, keypoints_gt,
            weights_gt)

  # Transpose the per-example tuples into one sequence per field.
  return zip(*map(extract_images_and_targets, read_data_list))


def _create_losses(input_queue, create_model_fn, train_config):
  """Builds a DetectionModel on dequeued inputs and registers its losses.

  Every tensor in the dictionary returned by the model's loss() method is
  added through tf.losses.add_loss; nothing is returned.

  Args:
    input_queue: BatchQueue object holding enqueued tensor_dicts.
    create_model_fn: A function to create the DetectionModel.
    train_config: a train_pb2.TrainConfig protobuf.
  """
  model = create_model_fn()
  model_inputs = get_inputs(
      input_queue,
      model.num_classes,
      train_config.merge_multiple_label_boxes,
      train_config.use_multiclass_scores)
  # Image keys and groundtruth weights are not used here.
  (raw_images, _, boxes_list, classes_list, masks_list, keypoints_list,
   _) = model_inputs

  # Preprocess each image on its own, then join the results along axis 0.
  preprocessed = [model.preprocess(raw_image) for raw_image in raw_images]
  batched_images = tf.concat([resized for resized, _ in preprocessed], 0)
  batched_true_shapes = tf.concat([shape for _, shape in preprocessed], 0)

  # An optional groundtruth field is only passed on when every example has it.
  if not all(mask is not None for mask in masks_list):
    masks_list = None
  if not all(keypoints is not None for keypoints in keypoints_list):
    keypoints_list = None

  model.provide_groundtruth(boxes_list, classes_list, masks_list,
                            keypoints_list)
  prediction_dict = model.predict(batched_images, batched_true_shapes)

  for loss_tensor in model.loss(prediction_dict, batched_true_shapes).values():
    tf.losses.add_loss(loss_tensor)


def train(create_tensor_dict_fn,
          create_model_fn,
          train_config,
          master,
          task,
          num_clones,
          worker_replicas,
          clone_on_cpu,
          ps_tasks,
          worker_job_name,
          is_chief,
          train_dir,
          graph_hook_fn=None):
  """Training function for detection models.

  Builds the full training graph (input queue, model clones, optimizer,
  gradient post-processing, summaries, savers) in a fresh tf.Graph and then
  hands control to slim.learning.train.

  Args:
    create_tensor_dict_fn: a function to create a tensor input dictionary.
    create_model_fn: a function that creates a DetectionModel and generates
                     losses.
    train_config: a train_pb2.TrainConfig protobuf. Note that when
      fine_tune_checkpoint is set and fine_tune_checkpoint_type is empty, the
      latter is filled in on this object.
    master: BNS name of the TensorFlow master to use.
    task: The task id of this training instance.
    num_clones: The number of clones to run per machine.
    worker_replicas: The number of work replicas to train with.
    clone_on_cpu: True if clones should be forced to run on CPU.
    ps_tasks: Number of parameter server tasks.
    worker_job_name: Name of the worker job.
    is_chief: Whether this replica is the chief replica.
    train_dir: Directory to write checkpoints and training summaries to.
    graph_hook_fn: Optional function that is called after the training graph is
      completely built. This is helpful to perform additional changes to the
      training graph such as optimizing batchnorm. The function should modify
      the default graph.
  """

  # This instance is only used further down to obtain the checkpoint restore
  # map; each clone creates its own model through create_model_fn inside
  # _create_losses.
  detection_model = create_model_fn()
  data_augmentation_options = [
      preprocessor_builder.build(step)
      for step in train_config.data_augmentation_options]

  with tf.Graph().as_default():
    # Build a configuration specifying multi-GPU and multi-replicas.
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=num_clones,
        clone_on_cpu=clone_on_cpu,
        replica_id=task,
        num_replicas=worker_replicas,
        num_ps_tasks=ps_tasks,
        worker_job_name=worker_job_name)

    # Place the global step on the device storing the variables.
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    with tf.device(deploy_config.inputs_device()):
      # Integer division: the configured batch size is split across clones and
      # any remainder is dropped.
      input_queue = create_input_queue(
          train_config.batch_size // num_clones, create_tensor_dict_fn,
          train_config.batch_queue_capacity,
          train_config.num_batch_queue_threads,
          train_config.prefetch_queue_capacity, data_augmentation_options)

    # Gather initial summaries.
    # TODO(rathodv): See if summaries can be added/extracted from global tf
    # collections so that they don't have to be passed around.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
    global_summaries = set()

    model_fn = functools.partial(_create_losses,
                                 create_model_fn=create_model_fn,
                                 train_config=train_config)
    clones = model_deploy.create_clones(deploy_config, model_fn, [input_queue])
    first_clone_scope = clones[0].scope

    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by model_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    with tf.device(deploy_config.optimizer_device()):
      training_optimizer, optimizer_summary_vars = optimizer_builder.build(
          train_config.optimizer)
      for var in optimizer_summary_vars:
        tf.summary.scalar(var.op.name, var, family='LearningRate')

    # When synchronous replicas are requested, wrap the optimizer and keep a
    # handle to the wrapper so it can be passed to slim.learning.train.
    sync_optimizer = None
    if train_config.sync_replicas:
      training_optimizer = tf.train.SyncReplicasOptimizer(
          training_optimizer,
          replicas_to_aggregate=train_config.replicas_to_aggregate,
          total_num_replicas=worker_replicas)
      sync_optimizer = training_optimizer

    with tf.device(deploy_config.optimizer_device()):
      # An empty list disables regularization losses. None presumably lets
      # optimize_clones collect them itself -- confirm against model_deploy.
      regularization_losses = (None if train_config.add_regularization_loss
                               else [])
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, training_optimizer,
          regularization_losses=regularization_losses)
      total_loss = tf.check_numerics(total_loss, 'LossTensor is inf or nan.')

      # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
      if train_config.bias_grad_multiplier:
        biases_regex_list = ['.*/biases']
        grads_and_vars = variables_helper.multiply_gradients_matching_regex(
            grads_and_vars,
            biases_regex_list,
            multiplier=train_config.bias_grad_multiplier)

      # Optionally freeze some layers by setting their gradients to be zero.
      if train_config.freeze_variables:
        grads_and_vars = variables_helper.freeze_gradients_matching_regex(
            grads_and_vars, train_config.freeze_variables)

      # Optionally clip gradients
      if train_config.gradient_clipping_by_norm > 0:
        with tf.name_scope('clip_grads'):
          grads_and_vars = slim.learning.clip_gradient_norms(
              grads_and_vars, train_config.gradient_clipping_by_norm)

      # Create gradient updates.
      grad_updates = training_optimizer.apply_gradients(grads_and_vars,
                                                        global_step=global_step)
      update_ops.append(grad_updates)
      update_op = tf.group(*update_ops, name='update_barrier')
      # Evaluating train_tensor forces all update ops to run first.
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')

    if graph_hook_fn:
      with tf.device(deploy_config.variables_device()):
        graph_hook_fn()

    # Add summaries.
    for model_var in slim.get_model_variables():
      global_summaries.add(tf.summary.histogram('ModelVars/' +
                                                model_var.op.name, model_var))
    for loss_tensor in tf.losses.get_losses():
      global_summaries.add(tf.summary.scalar('Losses/' + loss_tensor.op.name,
                                             loss_tensor))
    global_summaries.add(
        tf.summary.scalar('Losses/TotalLoss', tf.losses.get_total_loss()))

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))
    summaries |= global_summaries

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    # Soft placement allows placing on CPU ops without GPU implementation.
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)

    # Save checkpoints regularly.
    keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
    saver = tf.train.Saver(
        keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

    # Create ops required to initialize the model from a given checkpoint.
    init_fn = _build_checkpoint_init_fn(detection_model, train_config)

    slim.learning.train(
        train_tensor,
        logdir=train_dir,
        master=master,
        is_chief=is_chief,
        session_config=session_config,
        startup_delay_steps=train_config.startup_delay_steps,
        init_fn=init_fn,
        summary_op=summary_op,
        # A num_steps of 0 (unset) is passed on as None.
        number_of_steps=train_config.num_steps or None,
        save_summaries_secs=120,
        sync_optimizer=sync_optimizer,
        saver=saver)


def _build_checkpoint_init_fn(detection_model, train_config):
  """Returns a function restoring variables from the fine-tune checkpoint.

  Must be called with the training graph as the default graph, since it
  creates a tf.train.Saver.

  Args:
    detection_model: the DetectionModel whose restore_map() selects the
      variables to restore.
    train_config: a train_pb2.TrainConfig protobuf. If fine_tune_checkpoint is
      set and fine_tune_checkpoint_type is empty, fine_tune_checkpoint_type is
      set on this object from the deprecated from_detection_checkpoint field.

  Returns:
    A function taking a session and restoring the available checkpoint
    variables into it, or None when train_config.fine_tune_checkpoint is unset.
  """
  if not train_config.fine_tune_checkpoint:
    return None

  if not train_config.fine_tune_checkpoint_type:
    # train_config.from_detection_checkpoint field is deprecated. For
    # backward compatibility, fine_tune_checkpoint_type is set based on
    # from_detection_checkpoint.
    if train_config.from_detection_checkpoint:
      train_config.fine_tune_checkpoint_type = 'detection'
    else:
      train_config.fine_tune_checkpoint_type = 'classification'

  var_map = detection_model.restore_map(
      fine_tune_checkpoint_type=train_config.fine_tune_checkpoint_type,
      load_all_detection_checkpoint_vars=(
          train_config.load_all_detection_checkpoint_vars))
  # Narrow the map using the checkpoint contents before building the saver.
  available_var_map = (
      variables_helper.get_variables_available_in_checkpoint(
          var_map, train_config.fine_tune_checkpoint))
  init_saver = tf.train.Saver(available_var_map)

  def initializer_fn(sess):
    init_saver.restore(sess, train_config.fine_tune_checkpoint)

  return initializer_fn