# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Constructs model, inputs, and training environment."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import functools
import os

import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
import tensorflow.compat.v2 as tf2

import tf_slim as slim

from object_detection import eval_util
from object_detection import exporter as exporter_lib
from object_detection import inputs
from object_detection.builders import graph_rewriter_builder
from object_detection.builders import model_builder
from object_detection.builders import optimizer_builder
from object_detection.core import standard_fields as fields
from object_detection.utils import config_util
from object_detection.utils import label_map_util
from object_detection.utils import ops
from object_detection.utils import shape_utils
from object_detection.utils import variables_helper
from object_detection.utils import visualization_utils as vis_utils

# pylint: disable=g-import-not-at-top
try:
  from tensorflow.contrib import learn as contrib_learn
except ImportError:
  # TF 2.0 doesn't ship with contrib.
  pass
# pylint: enable=g-import-not-at-top

# A map of names to methods that help build the model.
MODEL_BUILD_UTIL_MAP = {
    'get_configs_from_pipeline_file':
        config_util.get_configs_from_pipeline_file,
    'create_pipeline_proto_from_configs':
        config_util.create_pipeline_proto_from_configs,
    'merge_external_params_with_configs':
        config_util.merge_external_params_with_configs,
    'create_train_input_fn':
        inputs.create_train_input_fn,
    'create_eval_input_fn':
        inputs.create_eval_input_fn,
    'create_predict_input_fn':
        inputs.create_predict_input_fn,
    'detection_model_fn_base':
        model_builder.build,
}
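
# MODEL_BUILD_UTIL_MAP lets callers inject their own builders without patching
# this module. A hedged sketch (hypothetical `my_model_builder` with the same
# signature as model_builder.build):
#
#   MODEL_BUILD_UTIL_MAP['detection_model_fn_base'] = my_model_builder
#   # Subsequent calls to create_estimator_and_inputs will build models with
#   # my_model_builder instead of model_builder.build.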


def _prepare_groundtruth_for_eval(detection_model, class_agnostic,
                                  max_number_of_boxes):
  """Extracts groundtruth data from detection_model and prepares it for eval.

  Args:
    detection_model: A `DetectionModel` object.
    class_agnostic: Whether the detections are class-agnostic.
    max_number_of_boxes: Max number of groundtruth boxes.

  Returns:
    groundtruth: Dictionary with the following fields:
      'groundtruth_boxes': [batch_size, num_boxes, 4] float32 tensor of boxes,
        in normalized coordinates.
      'groundtruth_classes': [batch_size, num_boxes] int64 tensor of 1-indexed
        classes.
      'groundtruth_masks': 4D float32 tensor of instance masks (if provided in
        groundtruth).
      'groundtruth_is_crowd': [batch_size, num_boxes] bool tensor indicating
        is_crowd annotations (if provided in groundtruth).
      'groundtruth_area': [batch_size, num_boxes] float32 tensor indicating
        the area (in the original absolute coordinates) of annotations (if
        provided in groundtruth).
      'num_groundtruth_boxes': [batch_size] tensor containing the maximum
        number of groundtruth boxes per image.
      'groundtruth_keypoints': [batch_size, num_boxes, num_keypoints, 2] float32
        tensor of keypoints (if provided in groundtruth).
      'groundtruth_dp_num_points_list': [batch_size, num_boxes] int32 tensor
        with the number of DensePose points for each instance (if provided in
        groundtruth).
      'groundtruth_dp_part_ids_list': [batch_size, num_boxes,
        max_sampled_points] int32 tensor with the part ids for each DensePose
        sampled point (if provided in groundtruth).
      'groundtruth_dp_surface_coords_list': [batch_size, num_boxes,
        max_sampled_points, 4] containing the DensePose surface coordinates for
        each sampled point (if provided in groundtruth).
      'groundtruth_track_ids_list': [batch_size, num_boxes] int32 tensor
        with track ID for each instance (if provided in groundtruth).
      'groundtruth_group_of': [batch_size, num_boxes] bool tensor indicating
        group_of annotations (if provided in groundtruth).
      'groundtruth_labeled_classes': [batch_size, num_classes] int64
        tensor of 1-indexed classes.
      'groundtruth_verified_neg_classes': [batch_size, num_classes] float32
        K-hot representation of 1-indexed classes which were verified as not
        present in the image.
      'groundtruth_not_exhaustive_classes': [batch_size, num_classes] K-hot
        representation of 1-indexed classes which don't have all of their
        instances marked exhaustively.
      'input_data_fields.groundtruth_image_classes': integer representation of
        the classes that were sent for verification for a given image. Note
        that this field does not support batching as the number of classes can
        be variable.
  """
  input_data_fields = fields.InputDataFields()
  groundtruth_boxes = tf.stack(
      detection_model.groundtruth_lists(fields.BoxListFields.boxes))
  groundtruth_boxes_shape = tf.shape(groundtruth_boxes)
  # For class-agnostic models, groundtruth one-hot encodings collapse to all
  # ones.
  if class_agnostic:
    groundtruth_classes_one_hot = tf.ones(
        [groundtruth_boxes_shape[0], groundtruth_boxes_shape[1], 1])
  else:
    groundtruth_classes_one_hot = tf.stack(
        detection_model.groundtruth_lists(fields.BoxListFields.classes))
  label_id_offset = 1  # Applying label id offset (b/63711816)
  groundtruth_classes = (
      tf.argmax(groundtruth_classes_one_hot, axis=2) + label_id_offset)
  groundtruth = {
      input_data_fields.groundtruth_boxes: groundtruth_boxes,
      input_data_fields.groundtruth_classes: groundtruth_classes
  }

  if detection_model.groundtruth_has_field(
      input_data_fields.groundtruth_image_classes):
    groundtruth_image_classes_k_hot = tf.stack(
        detection_model.groundtruth_lists(
            input_data_fields.groundtruth_image_classes))
    groundtruth_image_classes = tf.expand_dims(
        tf.where(groundtruth_image_classes_k_hot > 0)[:, 1], 0)
    # Adds back label_id_offset as it is subtracted in
    # convert_labeled_classes_to_k_hot.
    groundtruth[
        input_data_fields.
        groundtruth_image_classes] = groundtruth_image_classes + label_id_offset

  if detection_model.groundtruth_has_field(fields.BoxListFields.masks):
    groundtruth[input_data_fields.groundtruth_instance_masks] = tf.stack(
        detection_model.groundtruth_lists(fields.BoxListFields.masks))

  if detection_model.groundtruth_has_field(fields.BoxListFields.is_crowd):
    groundtruth[input_data_fields.groundtruth_is_crowd] = tf.stack(
        detection_model.groundtruth_lists(fields.BoxListFields.is_crowd))

  if detection_model.groundtruth_has_field(input_data_fields.groundtruth_area):
    groundtruth[input_data_fields.groundtruth_area] = tf.stack(
        detection_model.groundtruth_lists(input_data_fields.groundtruth_area))

  if detection_model.groundtruth_has_field(fields.BoxListFields.keypoints):
    groundtruth[input_data_fields.groundtruth_keypoints] = tf.stack(
        detection_model.groundtruth_lists(fields.BoxListFields.keypoints))

  if detection_model.groundtruth_has_field(
      fields.BoxListFields.keypoint_depths):
    groundtruth[input_data_fields.groundtruth_keypoint_depths] = tf.stack(
        detection_model.groundtruth_lists(fields.BoxListFields.keypoint_depths))
    groundtruth[
        input_data_fields.groundtruth_keypoint_depth_weights] = tf.stack(
            detection_model.groundtruth_lists(
                fields.BoxListFields.keypoint_depth_weights))

  if detection_model.groundtruth_has_field(
      fields.BoxListFields.keypoint_visibilities):
    groundtruth[input_data_fields.groundtruth_keypoint_visibilities] = tf.stack(
        detection_model.groundtruth_lists(
            fields.BoxListFields.keypoint_visibilities))

  if detection_model.groundtruth_has_field(fields.BoxListFields.group_of):
    groundtruth[input_data_fields.groundtruth_group_of] = tf.stack(
        detection_model.groundtruth_lists(fields.BoxListFields.group_of))

  label_id_offset_paddings = tf.constant([[0, 0], [1, 0]])
  if detection_model.groundtruth_has_field(
      input_data_fields.groundtruth_verified_neg_classes):
    groundtruth[input_data_fields.groundtruth_verified_neg_classes] = tf.pad(
        tf.stack(
            detection_model.groundtruth_lists(
                input_data_fields.groundtruth_verified_neg_classes)),
        label_id_offset_paddings)

  if detection_model.groundtruth_has_field(
      input_data_fields.groundtruth_not_exhaustive_classes):
    groundtruth[input_data_fields.groundtruth_not_exhaustive_classes] = tf.pad(
        tf.stack(
            detection_model.groundtruth_lists(
                input_data_fields.groundtruth_not_exhaustive_classes)),
        label_id_offset_paddings)

  if detection_model.groundtruth_has_field(
      fields.BoxListFields.densepose_num_points):
    groundtruth[input_data_fields.groundtruth_dp_num_points] = tf.stack(
        detection_model.groundtruth_lists(
            fields.BoxListFields.densepose_num_points))
  if detection_model.groundtruth_has_field(
      fields.BoxListFields.densepose_part_ids):
    groundtruth[input_data_fields.groundtruth_dp_part_ids] = tf.stack(
        detection_model.groundtruth_lists(
            fields.BoxListFields.densepose_part_ids))
  if detection_model.groundtruth_has_field(
      fields.BoxListFields.densepose_surface_coords):
    groundtruth[input_data_fields.groundtruth_dp_surface_coords] = tf.stack(
        detection_model.groundtruth_lists(
            fields.BoxListFields.densepose_surface_coords))

  if detection_model.groundtruth_has_field(fields.BoxListFields.track_ids):
    groundtruth[input_data_fields.groundtruth_track_ids] = tf.stack(
        detection_model.groundtruth_lists(fields.BoxListFields.track_ids))

  if detection_model.groundtruth_has_field(
      input_data_fields.groundtruth_labeled_classes):
    groundtruth[input_data_fields.groundtruth_labeled_classes] = tf.pad(
        tf.stack(
            detection_model.groundtruth_lists(
                input_data_fields.groundtruth_labeled_classes)),
        label_id_offset_paddings)

  groundtruth[input_data_fields.num_groundtruth_boxes] = (
      tf.tile([max_number_of_boxes], multiples=[groundtruth_boxes_shape[0]]))
  return groundtruth
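
# A hedged usage sketch (hypothetical `model` that has already been fed
# groundtruth via provide_groundtruth below; 100 is an arbitrary pad size):
#
#   groundtruth = _prepare_groundtruth_for_eval(
#       model, class_agnostic=False, max_number_of_boxes=100)
#   # groundtruth[fields.InputDataFields.groundtruth_boxes] is a
#   # [batch_size, num_boxes, 4] float32 tensor in normalized coordinates.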


def unstack_batch(tensor_dict, unpad_groundtruth_tensors=True):
  """Unstacks all tensors in `tensor_dict` along 0th dimension.

  Unstacks tensors from the tensor dict along the 0th dimension and returns a
  tensor_dict containing values that are lists of unstacked, unpadded tensors.

  Tensors in the `tensor_dict` are expected to be of one of the three shapes:
  1. [batch_size]
  2. [batch_size, height, width, channels]
  3. [batch_size, num_boxes, d1, d2, ... dn]

  When unpad_groundtruth_tensors is set to true, unstacked tensors of form 3
  above are sliced along the `num_boxes` dimension using the value in tensor
  fields.InputDataFields.num_groundtruth_boxes.

  Note that this function has a static list of input data fields and has to be
  kept in sync with the InputDataFields defined in core/standard_fields.py.

  Args:
    tensor_dict: A dictionary of batched groundtruth tensors.
    unpad_groundtruth_tensors: Whether to remove padding along `num_boxes`
      dimension of the groundtruth tensors.

  Returns:
    A dictionary where the keys are from fields.InputDataFields and values are
    a list of unstacked (optionally unpadded) tensors.

  Raises:
    ValueError: If unpad_tensors is True and `tensor_dict` does not contain
      `num_groundtruth_boxes` tensor.
  """
  unbatched_tensor_dict = {
      key: tf.unstack(tensor) for key, tensor in tensor_dict.items()
  }
  if unpad_groundtruth_tensors:
    if (fields.InputDataFields.num_groundtruth_boxes
        not in unbatched_tensor_dict):
      raise ValueError('`num_groundtruth_boxes` not found in tensor_dict. '
                       'Keys available: {}'.format(
                           unbatched_tensor_dict.keys()))
    unbatched_unpadded_tensor_dict = {}
    unpad_keys = set([
        # List of input data fields that are padded along the num_boxes
        # dimension. This list has to be kept in sync with InputDataFields in
        # standard_fields.py.
        fields.InputDataFields.groundtruth_instance_masks,
        fields.InputDataFields.groundtruth_instance_mask_weights,
        fields.InputDataFields.groundtruth_classes,
        fields.InputDataFields.groundtruth_boxes,
        fields.InputDataFields.groundtruth_keypoints,
        fields.InputDataFields.groundtruth_keypoint_depths,
        fields.InputDataFields.groundtruth_keypoint_depth_weights,
        fields.InputDataFields.groundtruth_keypoint_visibilities,
        fields.InputDataFields.groundtruth_dp_num_points,
        fields.InputDataFields.groundtruth_dp_part_ids,
        fields.InputDataFields.groundtruth_dp_surface_coords,
        fields.InputDataFields.groundtruth_track_ids,
        fields.InputDataFields.groundtruth_group_of,
        fields.InputDataFields.groundtruth_difficult,
        fields.InputDataFields.groundtruth_is_crowd,
        fields.InputDataFields.groundtruth_area,
        fields.InputDataFields.groundtruth_weights
    ]).intersection(set(unbatched_tensor_dict.keys()))

    for key in unpad_keys:
      unpadded_tensor_list = []
      for num_gt, padded_tensor in zip(
          unbatched_tensor_dict[fields.InputDataFields.num_groundtruth_boxes],
          unbatched_tensor_dict[key]):
        tensor_shape = shape_utils.combined_static_and_dynamic_shape(
            padded_tensor)
        slice_begin = tf.zeros([len(tensor_shape)], dtype=tf.int32)
        slice_size = tf.stack(
            [num_gt] + [-1 if dim is None else dim for dim in tensor_shape[1:]])
        unpadded_tensor = tf.slice(padded_tensor, slice_begin, slice_size)
        unpadded_tensor_list.append(unpadded_tensor)
      unbatched_unpadded_tensor_dict[key] = unpadded_tensor_list

    unbatched_tensor_dict.update(unbatched_unpadded_tensor_dict)

  return unbatched_tensor_dict
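
# A minimal sketch (hypothetical batched `labels` dict produced by one of the
# input functions in inputs.py):
#
#   labels = unstack_batch(labels, unpad_groundtruth_tensors=True)
#   # labels[fields.InputDataFields.groundtruth_boxes] is now a list of
#   # per-image [num_gt_i, 4] tensors with the padding sliced off.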


def provide_groundtruth(model, labels, training_step=None):
  """Provides the labels to a model as groundtruth.

  This helper function extracts the corresponding boxes, classes,
  keypoints, weights, masks, etc. from the labels and provides them
  as groundtruth to the model.

  Args:
    model: The detection model to provide groundtruth to.
    labels: The labels for the training or evaluation inputs.
    training_step: int, optional. The training step for the model. Useful for
      models which want to anneal loss weights.
  """
  gt_boxes_list = labels[fields.InputDataFields.groundtruth_boxes]
  gt_classes_list = labels[fields.InputDataFields.groundtruth_classes]
  gt_masks_list = None
  if fields.InputDataFields.groundtruth_instance_masks in labels:
    gt_masks_list = labels[fields.InputDataFields.groundtruth_instance_masks]
  gt_mask_weights_list = None
  if fields.InputDataFields.groundtruth_instance_mask_weights in labels:
    gt_mask_weights_list = labels[
        fields.InputDataFields.groundtruth_instance_mask_weights]
  gt_keypoints_list = None
  if fields.InputDataFields.groundtruth_keypoints in labels:
    gt_keypoints_list = labels[fields.InputDataFields.groundtruth_keypoints]
  gt_keypoint_depths_list = None
  gt_keypoint_depth_weights_list = None
  if fields.InputDataFields.groundtruth_keypoint_depths in labels:
    gt_keypoint_depths_list = (
        labels[fields.InputDataFields.groundtruth_keypoint_depths])
    gt_keypoint_depth_weights_list = (
        labels[fields.InputDataFields.groundtruth_keypoint_depth_weights])
  gt_keypoint_visibilities_list = None
  if fields.InputDataFields.groundtruth_keypoint_visibilities in labels:
    gt_keypoint_visibilities_list = labels[
        fields.InputDataFields.groundtruth_keypoint_visibilities]
  gt_dp_num_points_list = None
  if fields.InputDataFields.groundtruth_dp_num_points in labels:
    gt_dp_num_points_list = labels[
        fields.InputDataFields.groundtruth_dp_num_points]
  gt_dp_part_ids_list = None
  if fields.InputDataFields.groundtruth_dp_part_ids in labels:
    gt_dp_part_ids_list = labels[fields.InputDataFields.groundtruth_dp_part_ids]
  gt_dp_surface_coords_list = None
  if fields.InputDataFields.groundtruth_dp_surface_coords in labels:
    gt_dp_surface_coords_list = labels[
        fields.InputDataFields.groundtruth_dp_surface_coords]
  gt_track_ids_list = None
  if fields.InputDataFields.groundtruth_track_ids in labels:
    gt_track_ids_list = labels[fields.InputDataFields.groundtruth_track_ids]
  gt_weights_list = None
  if fields.InputDataFields.groundtruth_weights in labels:
    gt_weights_list = labels[fields.InputDataFields.groundtruth_weights]
  gt_confidences_list = None
  if fields.InputDataFields.groundtruth_confidences in labels:
    gt_confidences_list = labels[fields.InputDataFields.groundtruth_confidences]
  gt_is_crowd_list = None
  if fields.InputDataFields.groundtruth_is_crowd in labels:
    gt_is_crowd_list = labels[fields.InputDataFields.groundtruth_is_crowd]
  gt_group_of_list = None
  if fields.InputDataFields.groundtruth_group_of in labels:
    gt_group_of_list = labels[fields.InputDataFields.groundtruth_group_of]
  gt_area_list = None
  if fields.InputDataFields.groundtruth_area in labels:
    gt_area_list = labels[fields.InputDataFields.groundtruth_area]
  gt_labeled_classes = None
  if fields.InputDataFields.groundtruth_labeled_classes in labels:
    gt_labeled_classes = labels[
        fields.InputDataFields.groundtruth_labeled_classes]
  gt_verified_neg_classes = None
  if fields.InputDataFields.groundtruth_verified_neg_classes in labels:
    gt_verified_neg_classes = labels[
        fields.InputDataFields.groundtruth_verified_neg_classes]
  gt_not_exhaustive_classes = None
  if fields.InputDataFields.groundtruth_not_exhaustive_classes in labels:
    gt_not_exhaustive_classes = labels[
        fields.InputDataFields.groundtruth_not_exhaustive_classes]
  groundtruth_image_classes = None
  if fields.InputDataFields.groundtruth_image_classes in labels:
    groundtruth_image_classes = labels[
        fields.InputDataFields.groundtruth_image_classes]
  model.provide_groundtruth(
      groundtruth_boxes_list=gt_boxes_list,
      groundtruth_classes_list=gt_classes_list,
      groundtruth_confidences_list=gt_confidences_list,
      groundtruth_labeled_classes=gt_labeled_classes,
      groundtruth_masks_list=gt_masks_list,
      groundtruth_mask_weights_list=gt_mask_weights_list,
      groundtruth_keypoints_list=gt_keypoints_list,
      groundtruth_keypoint_visibilities_list=gt_keypoint_visibilities_list,
      groundtruth_dp_num_points_list=gt_dp_num_points_list,
      groundtruth_dp_part_ids_list=gt_dp_part_ids_list,
      groundtruth_dp_surface_coords_list=gt_dp_surface_coords_list,
      groundtruth_weights_list=gt_weights_list,
      groundtruth_is_crowd_list=gt_is_crowd_list,
      groundtruth_group_of_list=gt_group_of_list,
      groundtruth_area_list=gt_area_list,
      groundtruth_track_ids_list=gt_track_ids_list,
      groundtruth_verified_neg_classes=gt_verified_neg_classes,
      groundtruth_not_exhaustive_classes=gt_not_exhaustive_classes,
      groundtruth_keypoint_depths_list=gt_keypoint_depths_list,
      groundtruth_keypoint_depth_weights_list=gt_keypoint_depth_weights_list,
      groundtruth_image_classes=groundtruth_image_classes,
      training_step=training_step)
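
# A hedged sketch tying the helpers above together, mirroring what model_fn
# below does (hypothetical `detection_model`, batched `labels`, `features`):
#
#   labels = unstack_batch(labels, unpad_groundtruth_tensors=True)
#   provide_groundtruth(detection_model, labels)
#   prediction_dict = detection_model.predict(
#       features[fields.InputDataFields.image],
#       features[fields.InputDataFields.true_image_shape])
#   losses_dict = detection_model.loss(
#       prediction_dict, features[fields.InputDataFields.true_image_shape])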


def create_model_fn(detection_model_fn,
                    configs,
                    hparams=None,
                    use_tpu=False,
                    postprocess_on_cpu=False):
  """Creates a model function for `Estimator`.

  Args:
    detection_model_fn: Function that returns a `DetectionModel` instance.
    configs: Dictionary of pipeline config objects.
    hparams: `HParams` object.
    use_tpu: Boolean indicating whether model should be constructed for use on
      TPU.
    postprocess_on_cpu: When use_tpu and postprocess_on_cpu is true, postprocess
      is scheduled on the host CPU.

  Returns:
    `model_fn` for `Estimator`.
  """
  train_config = configs['train_config']
  eval_input_config = configs['eval_input_config']
  eval_config = configs['eval_config']

  def model_fn(features, labels, mode, params=None):
    """Constructs the object detection model.

    Args:
      features: Dictionary of feature tensors, returned from `input_fn`.
      labels: Dictionary of groundtruth tensors if mode is TRAIN or EVAL,
        otherwise None.
      mode: Mode key from tf.estimator.ModeKeys.
      params: Parameter dictionary passed from the estimator.

    Returns:
      An `EstimatorSpec` that encapsulates the model and its serving
        configurations.
    """
    params = params or {}
    total_loss, train_op, detections, export_outputs = None, None, None, None
    is_training = mode == tf_estimator.ModeKeys.TRAIN

    # Make sure to set the Keras learning phase. True during training,
    # False for inference.
    tf.keras.backend.set_learning_phase(is_training)
    # Set policy for mixed-precision training with Keras-based models.
    if use_tpu and train_config.use_bfloat16:
      # Enable v2 behavior, as `mixed_bfloat16` is only supported in TF 2.0.
      tf.keras.layers.enable_v2_dtype_behavior()
      tf2.keras.mixed_precision.set_global_policy('mixed_bfloat16')
    detection_model = detection_model_fn(
        is_training=is_training, add_summaries=(not use_tpu))
    scaffold_fn = None

    if mode == tf_estimator.ModeKeys.TRAIN:
      labels = unstack_batch(
          labels,
          unpad_groundtruth_tensors=train_config.unpad_groundtruth_tensors)
    elif mode == tf_estimator.ModeKeys.EVAL:
      # For evaluating on train data, it is necessary to check whether
      # groundtruth must be unpadded.
      boxes_shape = (
          labels[
              fields.InputDataFields.groundtruth_boxes].get_shape().as_list())
      unpad_groundtruth_tensors = boxes_shape[1] is not None and not use_tpu
      labels = unstack_batch(
          labels, unpad_groundtruth_tensors=unpad_groundtruth_tensors)

    if mode in (tf_estimator.ModeKeys.TRAIN, tf_estimator.ModeKeys.EVAL):
      provide_groundtruth(detection_model, labels)

    preprocessed_images = features[fields.InputDataFields.image]

    side_inputs = detection_model.get_side_inputs(features)

    if use_tpu and train_config.use_bfloat16:
      with tf.tpu.bfloat16_scope():
        prediction_dict = detection_model.predict(
            preprocessed_images,
            features[fields.InputDataFields.true_image_shape], **side_inputs)
        prediction_dict = ops.bfloat16_to_float32_nested(prediction_dict)
    else:
      prediction_dict = detection_model.predict(
          preprocessed_images,
          features[fields.InputDataFields.true_image_shape], **side_inputs)

    def postprocess_wrapper(args):
      return detection_model.postprocess(args[0], args[1])

    if mode in (tf_estimator.ModeKeys.EVAL, tf_estimator.ModeKeys.PREDICT):
      if use_tpu and postprocess_on_cpu:
        detections = tf.tpu.outside_compilation(
            postprocess_wrapper,
            (prediction_dict,
             features[fields.InputDataFields.true_image_shape]))
      else:
        detections = postprocess_wrapper(
            (prediction_dict,
             features[fields.InputDataFields.true_image_shape]))

    if mode == tf_estimator.ModeKeys.TRAIN:
      load_pretrained = hparams.load_pretrained if hparams else False
      if train_config.fine_tune_checkpoint and load_pretrained:
        if not train_config.fine_tune_checkpoint_type:
          # train_config.from_detection_checkpoint field is deprecated. For
          # backward compatibility, set train_config.fine_tune_checkpoint_type
          # based on train_config.from_detection_checkpoint.
          if train_config.from_detection_checkpoint:
            train_config.fine_tune_checkpoint_type = 'detection'
          else:
            train_config.fine_tune_checkpoint_type = 'classification'
        asg_map = detection_model.restore_map(
            fine_tune_checkpoint_type=train_config.fine_tune_checkpoint_type,
            load_all_detection_checkpoint_vars=(
                train_config.load_all_detection_checkpoint_vars))
        available_var_map = (
            variables_helper.get_variables_available_in_checkpoint(
                asg_map,
                train_config.fine_tune_checkpoint,
                include_global_step=False))
        if use_tpu:

          def tpu_scaffold():
            tf.train.init_from_checkpoint(train_config.fine_tune_checkpoint,
                                          available_var_map)
            return tf.train.Scaffold()

          scaffold_fn = tpu_scaffold
        else:
          tf.train.init_from_checkpoint(train_config.fine_tune_checkpoint,
                                        available_var_map)

    if mode in (tf_estimator.ModeKeys.TRAIN, tf_estimator.ModeKeys.EVAL):
      if (mode == tf_estimator.ModeKeys.EVAL and
          eval_config.use_dummy_loss_in_eval):
        total_loss = tf.constant(1.0)
        losses_dict = {'Loss/total_loss': total_loss}
      else:
        losses_dict = detection_model.loss(
            prediction_dict, features[fields.InputDataFields.true_image_shape])
        losses = [loss_tensor for loss_tensor in losses_dict.values()]
        if train_config.add_regularization_loss:
          regularization_losses = detection_model.regularization_losses()
          if use_tpu and train_config.use_bfloat16:
            regularization_losses = ops.bfloat16_to_float32_nested(
                regularization_losses)
          if regularization_losses:
            regularization_loss = tf.add_n(
                regularization_losses, name='regularization_loss')
            losses.append(regularization_loss)
            losses_dict['Loss/regularization_loss'] = regularization_loss
        total_loss = tf.add_n(losses, name='total_loss')
        losses_dict['Loss/total_loss'] = total_loss

      if 'graph_rewriter_config' in configs:
        graph_rewriter_fn = graph_rewriter_builder.build(
            configs['graph_rewriter_config'], is_training=is_training)
        graph_rewriter_fn()

      # TODO(rathodv): Stop creating optimizer summary vars in EVAL mode once we
      # can write learning rate summaries on TPU without host calls.
      global_step = tf.train.get_or_create_global_step()
      training_optimizer, optimizer_summary_vars = optimizer_builder.build(
          train_config.optimizer)

    if mode == tf_estimator.ModeKeys.TRAIN:
      if use_tpu:
        training_optimizer = tf.tpu.CrossShardOptimizer(training_optimizer)

      # Optionally freeze some layers by setting their gradients to be zero.
      trainable_variables = None
      include_variables = (
          train_config.update_trainable_variables
          if train_config.update_trainable_variables else None)
      exclude_variables = (
          train_config.freeze_variables
          if train_config.freeze_variables else None)
      trainable_variables = slim.filter_variables(
          tf.trainable_variables(),
          include_patterns=include_variables,
          exclude_patterns=exclude_variables)

      clip_gradients_value = None
      if train_config.gradient_clipping_by_norm > 0:
        clip_gradients_value = train_config.gradient_clipping_by_norm

      if not use_tpu:
        for var in optimizer_summary_vars:
          tf.summary.scalar(var.op.name, var)
      summaries = [] if use_tpu else None
      if train_config.summarize_gradients:
        summaries = ['gradients', 'gradient_norm', 'global_gradient_norm']
      train_op = slim.optimizers.optimize_loss(
          loss=total_loss,
          global_step=global_step,
          learning_rate=None,
          clip_gradients=clip_gradients_value,
          optimizer=training_optimizer,
          update_ops=detection_model.updates(),
          variables=trainable_variables,
          summaries=summaries,
          name='')  # Preventing scope prefix on all variables.

    if mode == tf_estimator.ModeKeys.PREDICT:
      exported_output = exporter_lib.add_output_tensor_nodes(detections)
      export_outputs = {
          tf.saved_model.signature_constants.PREDICT_METHOD_NAME:
              tf_estimator.export.PredictOutput(exported_output)
      }

    eval_metric_ops = None
    scaffold = None
    if mode == tf_estimator.ModeKeys.EVAL:
      class_agnostic = (
          fields.DetectionResultFields.detection_classes not in detections)
      groundtruth = _prepare_groundtruth_for_eval(
          detection_model, class_agnostic,
          eval_input_config.max_number_of_boxes)
      use_original_images = fields.InputDataFields.original_image in features
      if use_original_images:
        eval_images = features[fields.InputDataFields.original_image]
        true_image_shapes = tf.slice(
            features[fields.InputDataFields.true_image_shape], [0, 0], [-1, 3])
        original_image_spatial_shapes = features[
            fields.InputDataFields.original_image_spatial_shape]
      else:
        eval_images = features[fields.InputDataFields.image]
        true_image_shapes = None
        original_image_spatial_shapes = None

      eval_dict = eval_util.result_dict_for_batched_example(
          eval_images,
          features[inputs.HASH_KEY],
          detections,
          groundtruth,
          class_agnostic=class_agnostic,
          scale_to_absolute=True,
          original_image_spatial_shapes=original_image_spatial_shapes,
          true_image_shapes=true_image_shapes)

      if fields.InputDataFields.image_additional_channels in features:
        eval_dict[fields.InputDataFields.image_additional_channels] = features[
            fields.InputDataFields.image_additional_channels]

      if class_agnostic:
        category_index = label_map_util.create_class_agnostic_category_index()
      else:
        category_index = label_map_util.create_category_index_from_labelmap(
            eval_input_config.label_map_path)
      vis_metric_ops = None
      if not use_tpu and use_original_images:
        keypoint_edges = [
            (kp.start, kp.end) for kp in eval_config.keypoint_edge]

        eval_metric_op_vis = vis_utils.VisualizeSingleFrameDetections(
            category_index,
            max_examples_to_draw=eval_config.num_visualizations,
            max_boxes_to_draw=eval_config.max_num_boxes_to_visualize,
            min_score_thresh=eval_config.min_score_threshold,
            use_normalized_coordinates=False,
            keypoint_edges=keypoint_edges or None)
        vis_metric_ops = eval_metric_op_vis.get_estimator_eval_metric_ops(
            eval_dict)

      # Eval metrics on a single example.
      eval_metric_ops = eval_util.get_eval_metric_ops_for_evaluators(
          eval_config, list(category_index.values()), eval_dict)
      for loss_key, loss_tensor in iter(losses_dict.items()):
        eval_metric_ops[loss_key] = tf.metrics.mean(loss_tensor)
      for var in optimizer_summary_vars:
        eval_metric_ops[var.op.name] = (var, tf.no_op())
      if vis_metric_ops is not None:
        eval_metric_ops.update(vis_metric_ops)
      eval_metric_ops = {str(k): v for k, v in eval_metric_ops.items()}

      if eval_config.use_moving_averages:
        variable_averages = tf.train.ExponentialMovingAverage(0.0)
        variables_to_restore = variable_averages.variables_to_restore()
        keep_checkpoint_every_n_hours = (
            train_config.keep_checkpoint_every_n_hours)
        saver = tf.train.Saver(
            variables_to_restore,
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)
        scaffold = tf.train.Scaffold(saver=saver)

    # EVAL executes on CPU, so use regular non-TPU EstimatorSpec.
    if use_tpu and mode != tf_estimator.ModeKeys.EVAL:
      return tf_estimator.tpu.TPUEstimatorSpec(
          mode=mode,
          scaffold_fn=scaffold_fn,
          predictions=detections,
          loss=total_loss,
          train_op=train_op,
          eval_metrics=eval_metric_ops,
          export_outputs=export_outputs)
    else:
      if scaffold is None:
        keep_checkpoint_every_n_hours = (
            train_config.keep_checkpoint_every_n_hours)
        saver = tf.train.Saver(
            sharded=True,
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
            save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        scaffold = tf.train.Scaffold(saver=saver)
      return tf_estimator.EstimatorSpec(
          mode=mode,
          predictions=detections,
          loss=total_loss,
          train_op=train_op,
          eval_metric_ops=eval_metric_ops,
          export_outputs=export_outputs,
          scaffold=scaffold)

  return model_fn
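
# A hedged sketch of wiring the returned model_fn into an Estimator by hand
# (create_estimator_and_inputs below normally does this; `configs` comes from
# config_util.get_configs_from_pipeline_file and `run_config` is a RunConfig):
#
#   detection_model_fn = functools.partial(
#       model_builder.build, model_config=configs['model'])
#   model_fn = create_model_fn(detection_model_fn, configs)
#   estimator = tf_estimator.Estimator(model_fn=model_fn, config=run_config)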


def create_estimator_and_inputs(run_config,
                                hparams=None,
                                pipeline_config_path=None,
                                config_override=None,
                                train_steps=None,
                                sample_1_of_n_eval_examples=1,
                                sample_1_of_n_eval_on_train_examples=1,
                                model_fn_creator=create_model_fn,
                                use_tpu_estimator=False,
                                use_tpu=False,
                                num_shards=1,
                                params=None,
                                override_eval_num_epochs=True,
                                save_final_config=False,
                                postprocess_on_cpu=False,
                                export_to_tpu=None,
                                **kwargs):
  """Creates `Estimator`, input functions, and steps.

  Args:
    run_config: A `RunConfig`.
    hparams: (optional) A `HParams`.
    pipeline_config_path: A path to a pipeline config file.
    config_override: A pipeline_pb2.TrainEvalPipelineConfig text proto to
      override the config from `pipeline_config_path`.
    train_steps: Number of training steps. If None, the number of training steps
      is set from the `TrainConfig` proto.
    sample_1_of_n_eval_examples: Integer representing how often an eval example
      should be sampled. If 1, will sample all examples.
    sample_1_of_n_eval_on_train_examples: Similar to
      `sample_1_of_n_eval_examples`, except controls the sampling of training
      data for evaluation.
    model_fn_creator: A function that creates a `model_fn` for `Estimator`.
      Follows the signature:
      * Args:
        * `detection_model_fn`: Function that returns `DetectionModel` instance.
        * `configs`: Dictionary of pipeline config objects.
        * `hparams`: `HParams` object.
      * Returns: `model_fn` for `Estimator`.
    use_tpu_estimator: Whether a `TPUEstimator` should be returned. If False, an
      `Estimator` will be returned.
    use_tpu: Boolean, whether training and evaluation should run on TPU. Only
      used if `use_tpu_estimator` is True.
    num_shards: Number of shards (TPU cores). Only used if `use_tpu_estimator`
      is True.
    params: Parameter dictionary passed from the estimator. Only used if
      `use_tpu_estimator` is True.
    override_eval_num_epochs: Whether to overwrite the number of epochs to 1 for
      eval_input.
    save_final_config: Whether to save the final config (obtained after applying
      overrides) to `estimator.model_dir`.
    postprocess_on_cpu: When use_tpu and postprocess_on_cpu are true,
      postprocess is scheduled on the host CPU.
    export_to_tpu: When use_tpu and export_to_tpu are true,
      `export_savedmodel()` exports a metagraph for serving on TPU besides the
      one on CPU.
    **kwargs: Additional keyword arguments for configuration override.

  Returns:
    A dictionary with the following fields:
    'estimator': An `Estimator` or `TPUEstimator`.
    'train_input_fn': A training input function.
    'eval_input_fns': A list of all evaluation input functions.
    'eval_input_names': A list of names for each evaluation input.
    'eval_on_train_input_fn': An evaluation-on-train input function.
    'predict_input_fn': A prediction input function.
    'train_steps': Number of training steps. Either directly from input or from
      configuration.
  """
  get_configs_from_pipeline_file = MODEL_BUILD_UTIL_MAP[
      'get_configs_from_pipeline_file']
  merge_external_params_with_configs = MODEL_BUILD_UTIL_MAP[
      'merge_external_params_with_configs']
  create_pipeline_proto_from_configs = MODEL_BUILD_UTIL_MAP[
      'create_pipeline_proto_from_configs']
  create_train_input_fn = MODEL_BUILD_UTIL_MAP['create_train_input_fn']
  create_eval_input_fn = MODEL_BUILD_UTIL_MAP['create_eval_input_fn']
  create_predict_input_fn = MODEL_BUILD_UTIL_MAP['create_predict_input_fn']
  detection_model_fn_base = MODEL_BUILD_UTIL_MAP['detection_model_fn_base']

  configs = get_configs_from_pipeline_file(
      pipeline_config_path, config_override=config_override)
  kwargs.update({
      'train_steps': train_steps,
      'use_bfloat16': configs['train_config'].use_bfloat16 and use_tpu
  })
  if sample_1_of_n_eval_examples >= 1:
    kwargs.update({'sample_1_of_n_eval_examples': sample_1_of_n_eval_examples})
  if override_eval_num_epochs:
    kwargs.update({'eval_num_epochs': 1})
    tf.logging.warning(
        'Forced number of epochs for all eval validations to be 1.')
  configs = merge_external_params_with_configs(
      configs, hparams, kwargs_dict=kwargs)
  model_config = configs['model']
  train_config = configs['train_config']
  train_input_config = configs['train_input_config']
  eval_config = configs['eval_config']
  eval_input_configs = configs['eval_input_configs']
  eval_on_train_input_config = copy.deepcopy(train_input_config)
  eval_on_train_input_config.sample_1_of_n_examples = (
      sample_1_of_n_eval_on_train_examples)
  if override_eval_num_epochs and eval_on_train_input_config.num_epochs != 1:
    tf.logging.warning('Expected number of evaluation epochs is 1, but '
                       'instead encountered `eval_on_train_input_config'
                       '.num_epochs` = '
                       '{}. Overwriting `num_epochs` to 1.'.format(
                           eval_on_train_input_config.num_epochs))
    eval_on_train_input_config.num_epochs = 1

  # Update train_steps from config, but only when a non-zero value is provided.
  if train_steps is None and train_config.num_steps != 0:
    train_steps = train_config.num_steps

  detection_model_fn = functools.partial(
      detection_model_fn_base, model_config=model_config)

  # Create the input functions for TRAIN/EVAL/PREDICT.
  train_input_fn = create_train_input_fn(
      train_config=train_config,
      train_input_config=train_input_config,
      model_config=model_config)
  eval_input_fns = []
  for eval_input_config in eval_input_configs:
    eval_input_fns.append(
        create_eval_input_fn(
            eval_config=eval_config,
            eval_input_config=eval_input_config,
            model_config=model_config))

  eval_input_names = [
      eval_input_config.name for eval_input_config in eval_input_configs
  ]
  eval_on_train_input_fn = create_eval_input_fn(
      eval_config=eval_config,
      eval_input_config=eval_on_train_input_config,
      model_config=model_config)
  predict_input_fn = create_predict_input_fn(
      model_config=model_config, predict_input_config=eval_input_configs[0])

  # Read export_to_tpu from hparams if not passed.
  if export_to_tpu is None and hparams is not None:
    export_to_tpu = hparams.get('export_to_tpu', False)
  tf.logging.info('create_estimator_and_inputs: use_tpu %s, export_to_tpu %s',
                  use_tpu, export_to_tpu)
  model_fn = model_fn_creator(detection_model_fn, configs, hparams, use_tpu,
                              postprocess_on_cpu)
  if use_tpu_estimator:
    estimator = tf_estimator.tpu.TPUEstimator(
        model_fn=model_fn,
        train_batch_size=train_config.batch_size,
        # For each core, only batch size 1 is supported for eval.
        eval_batch_size=num_shards * 1 if use_tpu else 1,
        use_tpu=use_tpu,
        config=run_config,
        export_to_tpu=export_to_tpu,
        eval_on_tpu=False,  # Eval runs on CPU, so disable eval on TPU.
        params=params if params else {})
  else:
    estimator = tf_estimator.Estimator(model_fn=model_fn, config=run_config)

  # Write the as-run pipeline config to disk.
  if run_config.is_chief and save_final_config:
    pipeline_config_final = create_pipeline_proto_from_configs(configs)
    config_util.save_pipeline_config(pipeline_config_final, estimator.model_dir)

  return dict(
      estimator=estimator,
      train_input_fn=train_input_fn,
      eval_input_fns=eval_input_fns,
      eval_input_names=eval_input_names,
      eval_on_train_input_fn=eval_on_train_input_fn,
      predict_input_fn=predict_input_fn,
      train_steps=train_steps)
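
# A minimal end-to-end sketch (hypothetical model_dir and pipeline config
# paths):
#
#   run_config = tf_estimator.RunConfig(model_dir='/tmp/model_dir')
#   train_and_eval_dict = create_estimator_and_inputs(
#       run_config, pipeline_config_path='/tmp/pipeline.config')
#   estimator = train_and_eval_dict['estimator']
#   train_steps = train_and_eval_dict['train_steps']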


def create_train_and_eval_specs(train_input_fn,
                                eval_input_fns,
                                eval_on_train_input_fn,
                                predict_input_fn,
                                train_steps,
                                eval_on_train_data=False,
                                final_exporter_name='Servo',
                                eval_spec_names=None):
  """Creates a `TrainSpec` and `EvalSpec`s.

  Args:
    train_input_fn: Function that produces features and labels on train data.
    eval_input_fns: A list of functions that produce features and labels on eval
      data.
    eval_on_train_input_fn: Function that produces features and labels for
      evaluation on train data.
    predict_input_fn: Function that produces features for inference.
    train_steps: Number of training steps.
    eval_on_train_data: Whether to evaluate model on training data. Default is
      False.
    final_exporter_name: String name given to `FinalExporter`.
    eval_spec_names: A list of string names for each `EvalSpec`.

  Returns:
    Tuple of `TrainSpec` and a list of `EvalSpec`s. If `eval_on_train_data` is
    True, the last `EvalSpec` in the list corresponds to training data; the
    rest correspond to the evaluation datasets.
  """
  train_spec = tf_estimator.TrainSpec(
      input_fn=train_input_fn, max_steps=train_steps)

  if eval_spec_names is None:
    eval_spec_names = [str(i) for i in range(len(eval_input_fns))]

  eval_specs = []
  for index, (eval_spec_name,
              eval_input_fn) in enumerate(zip(eval_spec_names, eval_input_fns)):
    # Uses final_exporter_name as exporter_name for the first eval spec for
    # backward compatibility.
    if index == 0:
      exporter_name = final_exporter_name
    else:
      exporter_name = '{}_{}'.format(final_exporter_name, eval_spec_name)
    exporter = tf_estimator.FinalExporter(
        name=exporter_name, serving_input_receiver_fn=predict_input_fn)
    eval_specs.append(
        tf_estimator.EvalSpec(
            name=eval_spec_name,
            input_fn=eval_input_fn,
            steps=None,
            exporters=exporter))

  if eval_on_train_data:
    eval_specs.append(
        tf_estimator.EvalSpec(
            name='eval_on_train', input_fn=eval_on_train_input_fn, steps=None))

  return train_spec, eval_specs
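
# A hedged sketch of driving training with the specs, mirroring model_main.py
# (continues the create_estimator_and_inputs example above):
#
#   train_spec, eval_specs = create_train_and_eval_specs(
#       train_and_eval_dict['train_input_fn'],
#       train_and_eval_dict['eval_input_fns'],
#       train_and_eval_dict['eval_on_train_input_fn'],
#       train_and_eval_dict['predict_input_fn'],
#       train_steps)
#   tf_estimator.train_and_evaluate(estimator, train_spec, eval_specs[0])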


def _evaluate_checkpoint(estimator,
                         input_fn,
                         checkpoint_path,
                         name,
                         max_retries=0):
  """Evaluates a checkpoint.

  Args:
    estimator: Estimator object to use for evaluation.
    input_fn: Input function to use for evaluation.
    checkpoint_path: Path of the checkpoint to evaluate.
    name: Namescope for eval summary.
    max_retries: Maximum number of times to retry the evaluation on encountering
      a tf.errors.InvalidArgumentError. If negative, will always retry the
      evaluation.

  Returns:
    Estimator evaluation results.
  """
  always_retry = max_retries < 0
  retries = 0
  while always_retry or retries <= max_retries:
    try:
      return estimator.evaluate(
          input_fn=input_fn,
          steps=None,
          checkpoint_path=checkpoint_path,
          name=name)
    except tf.errors.InvalidArgumentError as e:
      if always_retry or retries < max_retries:
        tf.logging.info('Retrying checkpoint evaluation after exception: %s', e)
        retries += 1
      else:
        raise e
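
# A hedged usage sketch (hypothetical checkpoint path; `estimator` and
# `input_fn` as returned by create_estimator_and_inputs): retry a flaky eval
# up to 3 times on tf.errors.InvalidArgumentError.
#
#   results = _evaluate_checkpoint(
#       estimator, input_fn, '/tmp/model_dir/model.ckpt-10000',
#       name='validation_data', max_retries=3)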


def continuous_eval_generator(estimator,
                              model_dir,
                              input_fn,
                              train_steps,
                              name,
                              max_retries=0):
  """Performs continuous evaluation on checkpoints written to a model directory.

  Args:
    estimator: Estimator object to use for evaluation.
    model_dir: Model directory to read checkpoints for continuous evaluation.
    input_fn: Input function to use for evaluation.
    train_steps: Number of training steps. This is used to infer the last
      checkpoint and stop evaluation loop.
    name: Namescope for eval summary.
    max_retries: Maximum number of times to retry the evaluation on encountering
      a tf.errors.InvalidArgumentError. If negative, will always retry the
      evaluation.

  Yields:
    Pair of current step and eval_results.
  """

  def terminate_eval():
    tf.logging.info('Terminating eval after 180 seconds of no checkpoints')
    return True

  for ckpt in tf.train.checkpoints_iterator(
      model_dir, min_interval_secs=180, timeout=None,
      timeout_fn=terminate_eval):

    tf.logging.info('Starting Evaluation.')
    try:
      eval_results = _evaluate_checkpoint(
          estimator=estimator,
          input_fn=input_fn,
          checkpoint_path=ckpt,
          name=name,
          max_retries=max_retries)
      tf.logging.info('Eval results: %s' % eval_results)

      # Terminate the eval job when the final checkpoint is reached.
      current_step = int(os.path.basename(ckpt).split('-')[1])
      yield (current_step, eval_results)
      if current_step >= train_steps:
        tf.logging.info(
            'Evaluation finished after training step %d' % current_step)
        break

    except tf.errors.NotFoundError:
      tf.logging.info(
          'Checkpoint %s no longer exists, skipping checkpoint' % ckpt)
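
# A minimal sketch of consuming the generator (hypothetical model_dir; yields
# once per new checkpoint until train_steps is reached):
#
#   for step, results in continuous_eval_generator(
#       estimator, '/tmp/model_dir', eval_input_fns[0], train_steps,
#       name='validation_data'):
#     tf.logging.info('step %d: %s', step, results)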


def continuous_eval(estimator,
                    model_dir,
                    input_fn,
                    train_steps,
                    name,
                    max_retries=0):
  """Performs continuous evaluation on checkpoints written to a model directory.

  Args:
    estimator: Estimator object to use for evaluation.
    model_dir: Model directory to read checkpoints for continuous evaluation.
    input_fn: Input function to use for evaluation.
    train_steps: Number of training steps. This is used to infer the last
      checkpoint and stop evaluation loop.
    name: Namescope for eval summary.
    max_retries: Maximum number of times to retry the evaluation on encountering
      a tf.errors.InvalidArgumentError. If negative, will always retry the
      evaluation.
  """
  for current_step, eval_results in continuous_eval_generator(
      estimator, model_dir, input_fn, train_steps, name, max_retries):
    tf.logging.info('Step %s, Eval results: %s', current_step, eval_results)


def populate_experiment(run_config,
                        hparams,
                        pipeline_config_path,
                        train_steps=None,
                        eval_steps=None,
                        model_fn_creator=create_model_fn,
                        **kwargs):
  """Populates an `Experiment` object.

  EXPERIMENT CLASS IS DEPRECATED. Please switch to
  tf.estimator.train_and_evaluate. As an example, see model_main.py.

  Args:
    run_config: A `RunConfig`.
    hparams: A `HParams`.
    pipeline_config_path: A path to a pipeline config file.
    train_steps: Number of training steps. If None, the number of training steps
      is set from the `TrainConfig` proto.
    eval_steps: Number of evaluation steps per evaluation cycle. If None, the
      number of evaluation steps is set from the `EvalConfig` proto.
    model_fn_creator: A function that creates a `model_fn` for `Estimator`.
      Follows the signature:
      * Args:
        * `detection_model_fn`: Function that returns `DetectionModel` instance.
        * `configs`: Dictionary of pipeline config objects.
        * `hparams`: `HParams` object.
      * Returns: `model_fn` for `Estimator`.
    **kwargs: Additional keyword arguments for configuration override.

  Returns:
    An `Experiment` that defines all aspects of training, evaluation, and
    export.
  """
  tf.logging.warning('Experiment is being deprecated. Please use '
                     'tf.estimator.train_and_evaluate(). See model_main.py for '
                     'an example.')
  train_and_eval_dict = create_estimator_and_inputs(
      run_config,
      hparams,
      pipeline_config_path,
      train_steps=train_steps,
      eval_steps=eval_steps,
      model_fn_creator=model_fn_creator,
      save_final_config=True,
      **kwargs)
  estimator = train_and_eval_dict['estimator']
  train_input_fn = train_and_eval_dict['train_input_fn']
  eval_input_fns = train_and_eval_dict['eval_input_fns']
  predict_input_fn = train_and_eval_dict['predict_input_fn']
  train_steps = train_and_eval_dict['train_steps']

  export_strategies = [
      contrib_learn.utils.saved_model_export_utils.make_export_strategy(
          serving_input_fn=predict_input_fn)
  ]

  return contrib_learn.Experiment(
      estimator=estimator,
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fns[0],
      train_steps=train_steps,
      eval_steps=None,
      export_strategies=export_strategies,
      eval_delay_secs=120,
  )