# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Constructs model, inputs, and training environment."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import functools
import os

import tensorflow as tf

from tensorflow.python.util import function_utils
from object_detection import eval_util
from object_detection import exporter as exporter_lib
from object_detection import inputs
from object_detection.builders import graph_rewriter_builder
from object_detection.builders import model_builder
from object_detection.builders import optimizer_builder
from object_detection.core import standard_fields as fields
from object_detection.utils import config_util
from object_detection.utils import label_map_util
from object_detection.utils import ops
from object_detection.utils import shape_utils
from object_detection.utils import variables_helper
from object_detection.utils import visualization_utils as vis_utils

# A map of names to methods that help build the model.
MODEL_BUILD_UTIL_MAP = {
    'get_configs_from_pipeline_file':
        config_util.get_configs_from_pipeline_file,
    'create_pipeline_proto_from_configs':
        config_util.create_pipeline_proto_from_configs,
    'merge_external_params_with_configs':
        config_util.merge_external_params_with_configs,
    'create_train_input_fn':
        inputs.create_train_input_fn,
    'create_eval_input_fn':
        inputs.create_eval_input_fn,
    'create_predict_input_fn':
        inputs.create_predict_input_fn,
    'detection_model_fn_base': model_builder.build,
}
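
# Entries in this map can be swapped out to inject stubs in tests without
# patching the modules above. A minimal sketch (`_fake_model_builder` is a
# hypothetical stand-in):
#
#   MODEL_BUILD_UTIL_MAP['detection_model_fn_base'] = _fake_model_builder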


def _prepare_groundtruth_for_eval(detection_model, class_agnostic,
                                  max_number_of_boxes):
  """Extracts groundtruth data from detection_model and prepares it for eval.

  Args:
    detection_model: A `DetectionModel` object.
    class_agnostic: Whether the detections are class_agnostic.
    max_number_of_boxes: Max number of groundtruth boxes.

  Returns:
    A dictionary with the following fields:
      'groundtruth_boxes': [batch_size, num_boxes, 4] float32 tensor of boxes,
        in normalized coordinates.
      'groundtruth_classes': [batch_size, num_boxes] int64 tensor of 1-indexed
        classes.
      'groundtruth_masks': 4D float32 tensor of instance masks (if provided in
        groundtruth)
      'groundtruth_is_crowd': [batch_size, num_boxes] bool tensor indicating
        is_crowd annotations (if provided in groundtruth).
      'num_groundtruth_boxes': [batch_size] tensor containing the maximum
        number of groundtruth boxes per image.
  """
  input_data_fields = fields.InputDataFields()
  groundtruth_boxes = tf.stack(
      detection_model.groundtruth_lists(fields.BoxListFields.boxes))
  groundtruth_boxes_shape = tf.shape(groundtruth_boxes)
  # For class-agnostic models, groundtruth one-hot encodings collapse to all
  # ones.
  if class_agnostic:
    groundtruth_classes_one_hot = tf.ones(
        [groundtruth_boxes_shape[0], groundtruth_boxes_shape[1], 1])
  else:
    groundtruth_classes_one_hot = tf.stack(
        detection_model.groundtruth_lists(fields.BoxListFields.classes))
  label_id_offset = 1  # Applying label id offset (b/63711816)
  groundtruth_classes = (
      tf.argmax(groundtruth_classes_one_hot, axis=2) + label_id_offset)
  groundtruth = {
      input_data_fields.groundtruth_boxes: groundtruth_boxes,
      input_data_fields.groundtruth_classes: groundtruth_classes
  }
  if detection_model.groundtruth_has_field(fields.BoxListFields.masks):
    groundtruth[input_data_fields.groundtruth_instance_masks] = tf.stack(
        detection_model.groundtruth_lists(fields.BoxListFields.masks))

  if detection_model.groundtruth_has_field(fields.BoxListFields.is_crowd):
    groundtruth[input_data_fields.groundtruth_is_crowd] = tf.stack(
        detection_model.groundtruth_lists(fields.BoxListFields.is_crowd))

  groundtruth[input_data_fields.num_groundtruth_boxes] = (
      tf.tile([max_number_of_boxes], multiples=[groundtruth_boxes_shape[0]]))
  return groundtruth
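
# A minimal usage sketch for the helper above (illustrative only; assumes a
# `DetectionModel` that has already been fed groundtruth via
# `provide_groundtruth`):
#
#   groundtruth = _prepare_groundtruth_for_eval(
#       detection_model, class_agnostic=False, max_number_of_boxes=100)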


def unstack_batch(tensor_dict, unpad_groundtruth_tensors=True):
  """Unstacks all tensors in `tensor_dict` along 0th dimension.

  Unstacks tensor from the tensor dict along 0th dimension and returns a
120
  tensor_dict containing values that are lists of unstacked, unpadded tensors.

  Tensors in the `tensor_dict` are expected to be of one of the three shapes:
  1. [batch_size]
  2. [batch_size, height, width, channels]
  3. [batch_size, num_boxes, d1, d2, ... dn]

  When unpad_groundtruth_tensors is set to True, unstacked tensors of form 3
  above are sliced along the `num_boxes` dimension using the value in tensor
  fields.InputDataFields.num_groundtruth_boxes.

  Note that this function has a static list of input data fields and has to be
  kept in sync with the InputDataFields defined in core/standard_fields.py

  Args:
    tensor_dict: A dictionary of batched groundtruth tensors.
    unpad_groundtruth_tensors: Whether to remove padding along `num_boxes`
      dimension of the groundtruth tensors.

  Returns:
    A dictionary where the keys are from fields.InputDataFields and values are
    a list of unstacked (optionally unpadded) tensors.

  Raises:
    ValueError: If unpad_tensors is True and `tensor_dict` does not contain
      `num_groundtruth_boxes` tensor.
  """
  unbatched_tensor_dict = {
      key: tf.unstack(tensor) for key, tensor in tensor_dict.items()
  }
  if unpad_groundtruth_tensors:
    if (fields.InputDataFields.num_groundtruth_boxes not in
        unbatched_tensor_dict):
      raise ValueError('`num_groundtruth_boxes` not found in tensor_dict. '
                       'Keys available: {}'.format(
                           unbatched_tensor_dict.keys()))
    unbatched_unpadded_tensor_dict = {}
    unpad_keys = set([
        # List of input data fields that are padded along the num_boxes
        # dimension. This list has to be kept in sync with InputDataFields in
        # standard_fields.py.
        fields.InputDataFields.groundtruth_instance_masks,
        fields.InputDataFields.groundtruth_classes,
        fields.InputDataFields.groundtruth_boxes,
        fields.InputDataFields.groundtruth_keypoints,
        fields.InputDataFields.groundtruth_group_of,
        fields.InputDataFields.groundtruth_difficult,
        fields.InputDataFields.groundtruth_is_crowd,
        fields.InputDataFields.groundtruth_area,
        fields.InputDataFields.groundtruth_weights
    ]).intersection(set(unbatched_tensor_dict.keys()))

    for key in unpad_keys:
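      # Slice each padded tensor down to its first `num_gt` rows along the
      # num_boxes dimension; a slice size of -1 keeps the full extent of any
      # trailing dimension whose static shape is unknown.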
      unpadded_tensor_list = []
      for num_gt, padded_tensor in zip(
          unbatched_tensor_dict[fields.InputDataFields.num_groundtruth_boxes],
          unbatched_tensor_dict[key]):
        tensor_shape = shape_utils.combined_static_and_dynamic_shape(
            padded_tensor)
        slice_begin = tf.zeros([len(tensor_shape)], dtype=tf.int32)
        slice_size = tf.stack(
            [num_gt] + [-1 if dim is None else dim for dim in tensor_shape[1:]])
        unpadded_tensor = tf.slice(padded_tensor, slice_begin, slice_size)
        unpadded_tensor_list.append(unpadded_tensor)
      unbatched_unpadded_tensor_dict[key] = unpadded_tensor_list
    unbatched_tensor_dict.update(unbatched_unpadded_tensor_dict)

  return unbatched_tensor_dict
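
# A minimal usage sketch (illustrative shapes; two images padded to 20 boxes):
#
#   tensor_dict = {
#       fields.InputDataFields.groundtruth_boxes: tf.zeros([2, 20, 4]),
#       fields.InputDataFields.num_groundtruth_boxes: tf.constant([3, 5]),
#   }
#   unbatched = unstack_batch(tensor_dict)
#   # unbatched[fields.InputDataFields.groundtruth_boxes] is a list of two
#   # tensors with shapes [3, 4] and [5, 4].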


def create_model_fn(detection_model_fn, configs, hparams, use_tpu=False,
                    postprocess_on_cpu=False):
  """Creates a model function for `Estimator`.

  Args:
    detection_model_fn: Function that returns a `DetectionModel` instance.
    configs: Dictionary of pipeline config objects.
    hparams: `HParams` object.
    use_tpu: Boolean indicating whether model should be constructed for
        use on TPU.
    postprocess_on_cpu: When use_tpu and postprocess_on_cpu are true,
        postprocessing is scheduled on the host CPU.

  Returns:
    `model_fn` for `Estimator`.
  """
  train_config = configs['train_config']
  eval_input_config = configs['eval_input_config']
  eval_config = configs['eval_config']

  def model_fn(features, labels, mode, params=None):
    """Constructs the object detection model.

    Args:
      features: Dictionary of feature tensors, returned from `input_fn`.
      labels: Dictionary of groundtruth tensors if mode is TRAIN or EVAL,
        otherwise None.
      mode: Mode key from tf.estimator.ModeKeys.
      params: Parameter dictionary passed from the estimator.

    Returns:
      An `EstimatorSpec` that encapsulates the model and its serving
        configurations.
    """
    params = params or {}
    total_loss, train_op, detections, export_outputs = None, None, None, None
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    # Make sure to set the Keras learning phase. True during training,
    # False for inference.
    tf.keras.backend.set_learning_phase(is_training)
    detection_model = detection_model_fn(
        is_training=is_training, add_summaries=(not use_tpu))
    scaffold_fn = None

    if mode == tf.estimator.ModeKeys.TRAIN:
      labels = unstack_batch(
          labels,
          unpad_groundtruth_tensors=train_config.unpad_groundtruth_tensors)
    elif mode == tf.estimator.ModeKeys.EVAL:
      # When evaluating on train data, it is necessary to check whether
      # groundtruth must be unpadded.
      boxes_shape = (
          labels[fields.InputDataFields.groundtruth_boxes].get_shape()
          .as_list())
      unpad_groundtruth_tensors = boxes_shape[1] is not None and not use_tpu
      labels = unstack_batch(
          labels, unpad_groundtruth_tensors=unpad_groundtruth_tensors)

    if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
      gt_boxes_list = labels[fields.InputDataFields.groundtruth_boxes]
      gt_classes_list = labels[fields.InputDataFields.groundtruth_classes]
      gt_masks_list = None
      if fields.InputDataFields.groundtruth_instance_masks in labels:
        gt_masks_list = labels[
            fields.InputDataFields.groundtruth_instance_masks]
      gt_keypoints_list = None
      if fields.InputDataFields.groundtruth_keypoints in labels:
        gt_keypoints_list = labels[fields.InputDataFields.groundtruth_keypoints]
      gt_weights_list = None
      if fields.InputDataFields.groundtruth_weights in labels:
        gt_weights_list = labels[fields.InputDataFields.groundtruth_weights]
      gt_confidences_list = None
      if fields.InputDataFields.groundtruth_confidences in labels:
        gt_confidences_list = labels[
            fields.InputDataFields.groundtruth_confidences]
      gt_is_crowd_list = None
      if fields.InputDataFields.groundtruth_is_crowd in labels:
        gt_is_crowd_list = labels[fields.InputDataFields.groundtruth_is_crowd]
      detection_model.provide_groundtruth(
          groundtruth_boxes_list=gt_boxes_list,
          groundtruth_classes_list=gt_classes_list,
          groundtruth_confidences_list=gt_confidences_list,
          groundtruth_masks_list=gt_masks_list,
          groundtruth_keypoints_list=gt_keypoints_list,
          groundtruth_weights_list=gt_weights_list,
          groundtruth_is_crowd_list=gt_is_crowd_list)

    preprocessed_images = features[fields.InputDataFields.image]
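    # On TPU, optionally run the forward pass under a bfloat16 scope, then
    # cast the resulting prediction tensors back to float32.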
    if use_tpu and train_config.use_bfloat16:
      with tf.contrib.tpu.bfloat16_scope():
        prediction_dict = detection_model.predict(
            preprocessed_images,
            features[fields.InputDataFields.true_image_shape])
        prediction_dict = ops.bfloat16_to_float32_nested(prediction_dict)
    else:
      prediction_dict = detection_model.predict(
          preprocessed_images,
          features[fields.InputDataFields.true_image_shape])

    def postprocess_wrapper(args):
      return detection_model.postprocess(args[0], args[1])

    if mode in (tf.estimator.ModeKeys.EVAL, tf.estimator.ModeKeys.PREDICT):
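      # Postprocessing can optionally be scheduled on the host CPU instead of
      # the TPU (see the `postprocess_on_cpu` argument).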
      if use_tpu and postprocess_on_cpu:
        detections = tf.contrib.tpu.outside_compilation(
            postprocess_wrapper,
            (prediction_dict,
             features[fields.InputDataFields.true_image_shape]))
      else:
        detections = postprocess_wrapper((
            prediction_dict,
            features[fields.InputDataFields.true_image_shape]))

    if mode == tf.estimator.ModeKeys.TRAIN:
      if train_config.fine_tune_checkpoint and hparams.load_pretrained:
        if not train_config.fine_tune_checkpoint_type:
          # train_config.from_detection_checkpoint field is deprecated. For
          # backward compatibility, set train_config.fine_tune_checkpoint_type
          # based on train_config.from_detection_checkpoint.
          if train_config.from_detection_checkpoint:
            train_config.fine_tune_checkpoint_type = 'detection'
          else:
            train_config.fine_tune_checkpoint_type = 'classification'
        asg_map = detection_model.restore_map(
            fine_tune_checkpoint_type=train_config.fine_tune_checkpoint_type,
            load_all_detection_checkpoint_vars=(
                train_config.load_all_detection_checkpoint_vars))
        available_var_map = (
            variables_helper.get_variables_available_in_checkpoint(
                asg_map,
                train_config.fine_tune_checkpoint,
                include_global_step=False))
        if use_tpu:

          def tpu_scaffold():
            tf.train.init_from_checkpoint(train_config.fine_tune_checkpoint,
                                          available_var_map)
            return tf.train.Scaffold()

          scaffold_fn = tpu_scaffold
        else:
          tf.train.init_from_checkpoint(train_config.fine_tune_checkpoint,
                                        available_var_map)

    if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
      losses_dict = detection_model.loss(
          prediction_dict, features[fields.InputDataFields.true_image_shape])
      losses = list(losses_dict.values())
      if train_config.add_regularization_loss:
        regularization_losses = detection_model.regularization_losses()
        if use_tpu and train_config.use_bfloat16:
          regularization_losses = ops.bfloat16_to_float32_nested(
              regularization_losses)
        if regularization_losses:
          regularization_loss = tf.add_n(
              regularization_losses, name='regularization_loss')
          losses.append(regularization_loss)
          losses_dict['Loss/regularization_loss'] = regularization_loss
      total_loss = tf.add_n(losses, name='total_loss')
      losses_dict['Loss/total_loss'] = total_loss

      if 'graph_rewriter_config' in configs:
        graph_rewriter_fn = graph_rewriter_builder.build(
            configs['graph_rewriter_config'], is_training=is_training)
        graph_rewriter_fn()

      # TODO(rathodv): Stop creating optimizer summary vars in EVAL mode once we
      # can write learning rate summaries on TPU without host calls.
      global_step = tf.train.get_or_create_global_step()
      training_optimizer, optimizer_summary_vars = optimizer_builder.build(
          train_config.optimizer)

    if mode == tf.estimator.ModeKeys.TRAIN:
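      # On TPU, wrap the optimizer so that gradients are aggregated across
      # all shards before being applied.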
      if use_tpu:
        training_optimizer = tf.contrib.tpu.CrossShardOptimizer(
            training_optimizer)

      # Optionally freeze some layers by setting their gradients to be zero.
      trainable_variables = None
      include_variables = (
          train_config.update_trainable_variables
          if train_config.update_trainable_variables else None)
      exclude_variables = (
          train_config.freeze_variables
          if train_config.freeze_variables else None)
      trainable_variables = tf.contrib.framework.filter_variables(
          tf.trainable_variables(),
          include_patterns=include_variables,
          exclude_patterns=exclude_variables)

      clip_gradients_value = None
      if train_config.gradient_clipping_by_norm > 0:
        clip_gradients_value = train_config.gradient_clipping_by_norm

      if not use_tpu:
        for var in optimizer_summary_vars:
          tf.summary.scalar(var.op.name, var)
      summaries = [] if use_tpu else None
      if train_config.summarize_gradients:
        summaries = ['gradients', 'gradient_norm', 'global_gradient_norm']
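      # optimize_loss wires together gradient computation, optional clipping,
      # the requested summaries, and the model's update ops into one train op.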
      train_op = tf.contrib.layers.optimize_loss(
          loss=total_loss,
          global_step=global_step,
          learning_rate=None,
          clip_gradients=clip_gradients_value,
          optimizer=training_optimizer,
          update_ops=detection_model.updates(),
          variables=trainable_variables,
          summaries=summaries,
          name='')  # Preventing scope prefix on all variables.

    if mode == tf.estimator.ModeKeys.PREDICT:
      exported_output = exporter_lib.add_output_tensor_nodes(detections)
      export_outputs = {
          tf.saved_model.signature_constants.PREDICT_METHOD_NAME:
              tf.estimator.export.PredictOutput(exported_output)
      }

    eval_metric_ops = None
    scaffold = None
    if mode == tf.estimator.ModeKeys.EVAL:
      class_agnostic = (
          fields.DetectionResultFields.detection_classes not in detections)
      groundtruth = _prepare_groundtruth_for_eval(
          detection_model, class_agnostic,
          eval_input_config.max_number_of_boxes)
      use_original_images = fields.InputDataFields.original_image in features
      if use_original_images:
        eval_images = features[fields.InputDataFields.original_image]
        true_image_shapes = tf.slice(
            features[fields.InputDataFields.true_image_shape], [0, 0], [-1, 3])
        original_image_spatial_shapes = features[fields.InputDataFields
                                                 .original_image_spatial_shape]
      else:
        eval_images = features[fields.InputDataFields.image]
        true_image_shapes = None
        original_image_spatial_shapes = None

      eval_dict = eval_util.result_dict_for_batched_example(
          eval_images,
          features[inputs.HASH_KEY],
          detections,
          groundtruth,
          class_agnostic=class_agnostic,
          scale_to_absolute=True,
          original_image_spatial_shapes=original_image_spatial_shapes,
          true_image_shapes=true_image_shapes)

      if class_agnostic:
        category_index = label_map_util.create_class_agnostic_category_index()
      else:
        category_index = label_map_util.create_category_index_from_labelmap(
            eval_input_config.label_map_path)
      vis_metric_ops = None
      if not use_tpu and use_original_images:
        eval_metric_op_vis = vis_utils.VisualizeSingleFrameDetections(
            category_index,
            max_examples_to_draw=eval_config.num_visualizations,
            max_boxes_to_draw=eval_config.max_num_boxes_to_visualize,
            min_score_thresh=eval_config.min_score_threshold,
            use_normalized_coordinates=False)
        vis_metric_ops = eval_metric_op_vis.get_estimator_eval_metric_ops(
            eval_dict)

      # Eval metrics on a single example.
      eval_metric_ops = eval_util.get_eval_metric_ops_for_evaluators(
          eval_config, list(category_index.values()), eval_dict)
      for loss_key, loss_tensor in iter(losses_dict.items()):
        eval_metric_ops[loss_key] = tf.metrics.mean(loss_tensor)
      for var in optimizer_summary_vars:
        eval_metric_ops[var.op.name] = (var, tf.no_op())
      if vis_metric_ops is not None:
        eval_metric_ops.update(vis_metric_ops)
      eval_metric_ops = {str(k): v for k, v in eval_metric_ops.items()}

      # During eval, optionally restore exponential moving averages of the
      # model variables via a custom Saver in the scaffold.
      if eval_config.use_moving_averages:
        variable_averages = tf.train.ExponentialMovingAverage(0.0)
        variables_to_restore = variable_averages.variables_to_restore()
        keep_checkpoint_every_n_hours = (
            train_config.keep_checkpoint_every_n_hours)
        saver = tf.train.Saver(
            variables_to_restore,
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)
        scaffold = tf.train.Scaffold(saver=saver)

    # EVAL executes on CPU, so use regular non-TPU EstimatorSpec.
    if use_tpu and mode != tf.estimator.ModeKeys.EVAL:
      return tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode,
          scaffold_fn=scaffold_fn,
          predictions=detections,
          loss=total_loss,
          train_op=train_op,
          eval_metrics=eval_metric_ops,
          export_outputs=export_outputs)
    else:
      if scaffold is None:
        keep_checkpoint_every_n_hours = (
            train_config.keep_checkpoint_every_n_hours)
        saver = tf.train.Saver(
            sharded=True,
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
            save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        scaffold = tf.train.Scaffold(saver=saver)
      return tf.estimator.EstimatorSpec(
          mode=mode,
          predictions=detections,
          loss=total_loss,
          train_op=train_op,
          eval_metric_ops=eval_metric_ops,
          export_outputs=export_outputs,
          scaffold=scaffold)

  return model_fn
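
# A minimal usage sketch (illustrative only; `configs`, `hparams`, and
# `run_config` are assumed to exist, e.g. from get_configs_from_pipeline_file
# and tf.estimator.RunConfig):
#
#   model_fn = create_model_fn(detection_model_fn, configs, hparams)
#   estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)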


def create_estimator_and_inputs(run_config,
                                hparams,
                                pipeline_config_path,
                                config_override=None,
                                train_steps=None,
                                sample_1_of_n_eval_examples=1,
                                sample_1_of_n_eval_on_train_examples=1,
                                model_fn_creator=create_model_fn,
                                use_tpu_estimator=False,
                                use_tpu=False,
                                num_shards=1,
                                params=None,
                                override_eval_num_epochs=True,
                                save_final_config=False,
                                postprocess_on_cpu=False,
                                export_to_tpu=None,
                                **kwargs):
  """Creates `Estimator`, input functions, and steps.

  Args:
    run_config: A `RunConfig`.
    hparams: A `HParams`.
    pipeline_config_path: A path to a pipeline config file.
    config_override: A pipeline_pb2.TrainEvalPipelineConfig text proto to
      override the config from `pipeline_config_path`.
    train_steps: Number of training steps. If None, the number of training steps
      is set from the `TrainConfig` proto.
    sample_1_of_n_eval_examples: Integer representing how often an eval example
      should be sampled. If 1, will sample all examples.
    sample_1_of_n_eval_on_train_examples: Similar to
      `sample_1_of_n_eval_examples`, except controls the sampling of training
      data for evaluation.
    model_fn_creator: A function that creates a `model_fn` for `Estimator`.
      Follows the signature:

      * Args:
        * `detection_model_fn`: Function that returns `DetectionModel` instance.
        * `configs`: Dictionary of pipeline config objects.
        * `hparams`: `HParams` object.
      * Returns:
        `model_fn` for `Estimator`.

    use_tpu_estimator: Whether a `TPUEstimator` should be returned. If False,
      an `Estimator` will be returned.
    use_tpu: Boolean, whether training and evaluation should run on TPU. Only
      used if `use_tpu_estimator` is True.
    num_shards: Number of shards (TPU cores). Only used if `use_tpu_estimator`
      is True.
    params: Parameter dictionary passed from the estimator. Only used if
      `use_tpu_estimator` is True.
    override_eval_num_epochs: Whether to overwrite the number of epochs to 1 for
      eval_input.
    save_final_config: Whether to save final config (obtained after applying
      overrides) to `estimator.model_dir`.
    postprocess_on_cpu: When use_tpu and postprocess_on_cpu are true,
      postprocessing is scheduled on the host CPU.
    export_to_tpu: When use_tpu and export_to_tpu are true,
      `export_savedmodel()` exports a metagraph for serving on TPU besides the
      one on CPU.
    **kwargs: Additional keyword arguments for configuration override.

  Returns:
    A dictionary with the following fields:
    'estimator': An `Estimator` or `TPUEstimator`.
    'train_input_fn': A training input function.
    'eval_input_fns': A list of all evaluation input functions.
    'eval_input_names': A list of names for each evaluation input.
    'eval_on_train_input_fn': An evaluation-on-train input function.
    'predict_input_fn': A prediction input function.
    'train_steps': Number of training steps. Either directly from input or from
      configuration.
  """
  get_configs_from_pipeline_file = MODEL_BUILD_UTIL_MAP[
      'get_configs_from_pipeline_file']
  merge_external_params_with_configs = MODEL_BUILD_UTIL_MAP[
      'merge_external_params_with_configs']
  create_pipeline_proto_from_configs = MODEL_BUILD_UTIL_MAP[
      'create_pipeline_proto_from_configs']
  create_train_input_fn = MODEL_BUILD_UTIL_MAP['create_train_input_fn']
  create_eval_input_fn = MODEL_BUILD_UTIL_MAP['create_eval_input_fn']
  create_predict_input_fn = MODEL_BUILD_UTIL_MAP['create_predict_input_fn']
  detection_model_fn_base = MODEL_BUILD_UTIL_MAP['detection_model_fn_base']

  configs = get_configs_from_pipeline_file(
      pipeline_config_path, config_override=config_override)
  kwargs.update({
      'train_steps': train_steps,
      'sample_1_of_n_eval_examples': sample_1_of_n_eval_examples,
      'use_bfloat16': configs['train_config'].use_bfloat16 and use_tpu
  })
  if override_eval_num_epochs:
    kwargs.update({'eval_num_epochs': 1})
    tf.logging.warning(
        'Forced number of epochs for all eval validations to be 1.')
  configs = merge_external_params_with_configs(
      configs, hparams, kwargs_dict=kwargs)
  model_config = configs['model']
  train_config = configs['train_config']
  train_input_config = configs['train_input_config']
  eval_config = configs['eval_config']
  eval_input_configs = configs['eval_input_configs']
  eval_on_train_input_config = copy.deepcopy(train_input_config)
  eval_on_train_input_config.sample_1_of_n_examples = (
      sample_1_of_n_eval_on_train_examples)
  if override_eval_num_epochs and eval_on_train_input_config.num_epochs != 1:
    tf.logging.warning('Expected number of evaluation epochs is 1, but '
                       'instead encountered `eval_on_train_input_config'
                       '.num_epochs` = '
                       '{}. Overwriting `num_epochs` to 1.'.format(
                           eval_on_train_input_config.num_epochs))
    eval_on_train_input_config.num_epochs = 1

  # Update train_steps from the config, but only when a non-zero value is
  # provided.
  if train_steps is None and train_config.num_steps != 0:
    train_steps = train_config.num_steps

  detection_model_fn = functools.partial(
      detection_model_fn_base, model_config=model_config)

  # Create the input functions for TRAIN/EVAL/PREDICT.
  train_input_fn = create_train_input_fn(
      train_config=train_config,
      train_input_config=train_input_config,
      model_config=model_config)
  eval_input_fns = [
      create_eval_input_fn(
          eval_config=eval_config,
          eval_input_config=eval_input_config,
          model_config=model_config) for eval_input_config in eval_input_configs
  ]
  eval_input_names = [
      eval_input_config.name for eval_input_config in eval_input_configs
  ]
  eval_on_train_input_fn = create_eval_input_fn(
      eval_config=eval_config,
      eval_input_config=eval_on_train_input_config,
      model_config=model_config)
  predict_input_fn = create_predict_input_fn(
      model_config=model_config, predict_input_config=eval_input_configs[0])

  # Read export_to_tpu from hparams if not passed.
  if export_to_tpu is None:
    export_to_tpu = hparams.get('export_to_tpu', False)
  tf.logging.info('create_estimator_and_inputs: use_tpu %s, export_to_tpu %s',
                  use_tpu, export_to_tpu)
  model_fn = model_fn_creator(detection_model_fn, configs, hparams, use_tpu,
                              postprocess_on_cpu)
  if use_tpu_estimator:
    # Multicore inference disabled due to b/129367127
    tpu_estimator_args = function_utils.fn_args(tf.contrib.tpu.TPUEstimator)
    kwargs = {}
    if 'experimental_export_device_assignment' in tpu_estimator_args:
      kwargs['experimental_export_device_assignment'] = True
    estimator = tf.contrib.tpu.TPUEstimator(
        model_fn=model_fn,
        train_batch_size=train_config.batch_size,
        # For each core, only batch size 1 is supported for eval.
        eval_batch_size=num_shards * 1 if use_tpu else 1,
        use_tpu=use_tpu,
        config=run_config,
        export_to_tpu=export_to_tpu,
        eval_on_tpu=False,  # Eval runs on CPU, so disable eval on TPU
        params=params if params else {},
        **kwargs)
  else:
    estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)

  # Write the as-run pipeline config to disk.
  if run_config.is_chief and save_final_config:
    pipeline_config_final = create_pipeline_proto_from_configs(configs)
    config_util.save_pipeline_config(pipeline_config_final, estimator.model_dir)

  return dict(
      estimator=estimator,
      train_input_fn=train_input_fn,
      eval_input_fns=eval_input_fns,
      eval_input_names=eval_input_names,
      eval_on_train_input_fn=eval_on_train_input_fn,
      predict_input_fn=predict_input_fn,
      train_steps=train_steps)
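
# A minimal usage sketch (hypothetical paths; `hparams` is assumed to come
# from object_detection.model_hparams.create_hparams, as in model_main.py):
#
#   train_and_eval_dict = create_estimator_and_inputs(
#       run_config=tf.estimator.RunConfig(model_dir='/tmp/model_dir'),
#       hparams=hparams,
#       pipeline_config_path='/tmp/pipeline.config')
#   estimator = train_and_eval_dict['estimator']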


def create_train_and_eval_specs(train_input_fn,
                                eval_input_fns,
                                eval_on_train_input_fn,
                                predict_input_fn,
                                train_steps,
                                eval_on_train_data=False,
                                final_exporter_name='Servo',
                                eval_spec_names=None):
  """Creates a `TrainSpec` and `EvalSpec`s.

  Args:
    train_input_fn: Function that produces features and labels on train data.
    eval_input_fns: A list of functions that produce features and labels on eval
      data.
    eval_on_train_input_fn: Function that produces features and labels for
      evaluation on train data.
    predict_input_fn: Function that produces features for inference.
    train_steps: Number of training steps.
    eval_on_train_data: Whether to evaluate model on training data. Default is
      False.
    final_exporter_name: String name given to `FinalExporter`.
    eval_spec_names: A list of string names for each `EvalSpec`.

  Returns:
    Tuple of `TrainSpec` and list of `EvalSpec`s. If `eval_on_train_data` is
    True, the last `EvalSpec` in the list will correspond to training data.
    The remaining `EvalSpec`s in the list correspond to the evaluation
    datasets.
  """
  train_spec = tf.estimator.TrainSpec(
      input_fn=train_input_fn, max_steps=train_steps)

  if eval_spec_names is None:
    eval_spec_names = [str(i) for i in range(len(eval_input_fns))]

  eval_specs = []
  for index, (eval_spec_name, eval_input_fn) in enumerate(
      zip(eval_spec_names, eval_input_fns)):
    # Uses final_exporter_name as exporter_name for the first eval spec for
    # backward compatibility.
    if index == 0:
      exporter_name = final_exporter_name
    else:
      exporter_name = '{}_{}'.format(final_exporter_name, eval_spec_name)
    exporter = tf.estimator.FinalExporter(
        name=exporter_name, serving_input_receiver_fn=predict_input_fn)
    eval_specs.append(
        tf.estimator.EvalSpec(
            name=eval_spec_name,
            input_fn=eval_input_fn,
            steps=None,
            exporters=exporter))

  if eval_on_train_data:
    eval_specs.append(
        tf.estimator.EvalSpec(
            name='eval_on_train', input_fn=eval_on_train_input_fn, steps=None))

  return train_spec, eval_specs
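
# A minimal usage sketch: the resulting specs are typically handed to
# tf.estimator.train_and_evaluate, as in model_main.py (first eval spec only):
#
#   train_spec, eval_specs = create_train_and_eval_specs(
#       train_input_fn, eval_input_fns, eval_on_train_input_fn,
#       predict_input_fn, train_steps)
#   tf.estimator.train_and_evaluate(estimator, train_spec, eval_specs[0])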


def continuous_eval(estimator, model_dir, input_fn, train_steps, name):
  """Perform continuous evaluation on checkpoints written to a model directory.

  Args:
    estimator: Estimator object to use for evaluation.
    model_dir: Model directory to read checkpoints for continuous evaluation.
    input_fn: Input function to use for evaluation.
    train_steps: Number of training steps. This is used to infer the last
      checkpoint and stop the evaluation loop.
    name: Name scope for eval summaries.
  """

  def terminate_eval():
    tf.logging.info('Terminating eval after 180 seconds of no checkpoints')
    return True

  for ckpt in tf.contrib.training.checkpoints_iterator(
      model_dir, min_interval_secs=180, timeout=None,
      timeout_fn=terminate_eval):

    tf.logging.info('Starting Evaluation.')
    try:
      eval_results = estimator.evaluate(
          input_fn=input_fn, steps=None, checkpoint_path=ckpt, name=name)
      tf.logging.info('Eval results: %s' % eval_results)

      # Terminate eval job when final checkpoint is reached
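      # Checkpoint paths end in 'model.ckpt-<global_step>', so the step can
      # be parsed from the basename.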
      current_step = int(os.path.basename(ckpt).split('-')[1])
      if current_step >= train_steps:
        tf.logging.info(
            'Evaluation finished after training step %d' % current_step)
        break

    except tf.errors.NotFoundError:
      tf.logging.info(
          'Checkpoint %s no longer exists, skipping checkpoint' % ckpt)
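
# A minimal usage sketch (hypothetical values; `estimator` and the input
# functions as returned by create_estimator_and_inputs):
#
#   continuous_eval(estimator, model_dir='/tmp/model_dir',
#                   input_fn=eval_input_fns[0], train_steps=100000,
#                   name='validation_data')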


def populate_experiment(run_config,
                        hparams,
                        pipeline_config_path,
                        train_steps=None,
                        eval_steps=None,
                        model_fn_creator=create_model_fn,
                        **kwargs):
  """Populates an `Experiment` object.

  EXPERIMENT CLASS IS DEPRECATED. Please switch to
  tf.estimator.train_and_evaluate. As an example, see model_main.py.

  Args:
    run_config: A `RunConfig`.
    hparams: A `HParams`.
    pipeline_config_path: A path to a pipeline config file.
    train_steps: Number of training steps. If None, the number of training steps
      is set from the `TrainConfig` proto.
    eval_steps: Number of evaluation steps per evaluation cycle. If None, the
      number of evaluation steps is set from the `EvalConfig` proto.
    model_fn_creator: A function that creates a `model_fn` for `Estimator`.
      Follows the signature:

      * Args:
        * `detection_model_fn`: Function that returns `DetectionModel` instance.
        * `configs`: Dictionary of pipeline config objects.
        * `hparams`: `HParams` object.
      * Returns:
        `model_fn` for `Estimator`.

    **kwargs: Additional keyword arguments for configuration override.

  Returns:
    An `Experiment` that defines all aspects of training, evaluation, and
    export.
  """
  tf.logging.warning('Experiment is being deprecated. Please use '
                     'tf.estimator.train_and_evaluate(). See model_main.py for '
                     'an example.')
  train_and_eval_dict = create_estimator_and_inputs(
      run_config,
      hparams,
      pipeline_config_path,
      train_steps=train_steps,
      eval_steps=eval_steps,
      model_fn_creator=model_fn_creator,
      save_final_config=True,
      **kwargs)
  estimator = train_and_eval_dict['estimator']
  train_input_fn = train_and_eval_dict['train_input_fn']
  eval_input_fns = train_and_eval_dict['eval_input_fns']
  predict_input_fn = train_and_eval_dict['predict_input_fn']
  train_steps = train_and_eval_dict['train_steps']

  export_strategies = [
      tf.contrib.learn.utils.saved_model_export_utils.make_export_strategy(
          serving_input_fn=predict_input_fn)
  ]

  return tf.contrib.learn.Experiment(
      estimator=estimator,
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fns[0],
      train_steps=train_steps,
      eval_steps=None,
      export_strategies=export_strategies,
      eval_delay_secs=120,
  )