exporter.py 21.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Functions to export object detection inference graph."""
import copy
import os
import tempfile

import tensorflow as tf
from tensorflow.contrib.quantize.python import graph_matcher
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.tools import freeze_graph  # pylint: disable=g-direct-tensorflow-import

from object_detection.builders import graph_rewriter_builder
from object_detection.builders import model_builder
from object_detection.core import standard_fields as fields
from object_detection.data_decoders import tf_example_decoder
from object_detection.utils import config_util
from object_detection.utils import shape_utils

# TF-Slim handle; build_detection_graph uses it to create the global step.
slim = tf.contrib.slim

# Public alias for freeze_graph's proto-based freezing entry point.
# NOTE(review): not referenced by the code visible in this file; presumably
# kept for external callers — confirm before removing.
freeze_graph_with_def_protos = freeze_graph.freeze_graph_with_def_protos


def rewrite_nn_resize_op(is_quantized=False):
  """Swaps a custom nearest-neighbor upsampling pattern for the TF resize op.

  Some graphs express nearest-neighbor upsampling as a Pack/Pack/Reshape
  structure for TPU-compatibility. This function finds that structure in the
  default graph and reroutes each consumer (Add, AddV2, Max or Mul) to read
  from `tf.image.resize_nearest_neighbor` instead.

  Args:
    is_quantized: True if the default graph is quantized, in which case the
      pattern must start from a `FakeQuantWithMinMaxVars` op.

  Raises:
    ValueError: if five successive rewrite passes all still find matches.
  """
  def _rewrite_one_pass():
    """Rewrites every match found in one scan; returns the match count."""
    source = graph_matcher.OpTypePattern(
        'FakeQuantWithMinMaxVars' if is_quantized else '*')
    inner_pack = graph_matcher.OpTypePattern(
        'Pack', inputs=[source, source], ordered_inputs=False)
    outer_pack = graph_matcher.OpTypePattern(
        'Pack', inputs=[inner_pack, inner_pack], ordered_inputs=False)
    reshape = graph_matcher.OpTypePattern(
        'Reshape', inputs=[outer_pack, 'Const'], ordered_inputs=False)
    consumer = graph_matcher.OpTypePattern(
        'Add|AddV2|Max|Mul', inputs=[reshape, '*'],
        ordered_inputs=False)

    found = 0
    for match in graph_matcher.GraphMatcher(consumer).match_graph(
        tf.get_default_graph()):
      found += 1
      source_op = match.get_op(source)
      reshape_op = match.get_op(reshape)
      consumer_op = match.get_op(consumer)
      upsampled = reshape_op.outputs[0]
      replacement = tf.image.resize_nearest_neighbor(
          source_op.outputs[0],
          upsampled.shape.dims[1:3],
          align_corners=False,
          name=os.path.split(reshape_op.name)[0] + '/resize_nearest_neighbor')

      # Redirect only the first consumer input that is fed by the reshape.
      for position, consumer_input in enumerate(consumer_op.inputs):
        if consumer_input == upsampled:
          consumer_op._update_input(position, replacement)  # pylint: disable=protected-access
          break

    tf.logging.info('Found and fixed {} matches'.format(found))
    return found

  # A pass can leave further matches behind (both inputs to an Add could be
  # the upsampling pattern), so keep rewriting until a pass finds nothing.
  productive_passes = 0
  while True:
    if not _rewrite_one_pass():
      return
    productive_passes += 1
    # This limit was chosen based on the nas-fpn architecture.
    if productive_passes > 4:
      raise ValueError('Graph removal encountered a infinite loop.')


def replace_variable_values_with_moving_averages(graph,
                                                 current_checkpoint_file,
                                                 new_checkpoint_file,
                                                 no_ema_collection=None):
  """Writes a checkpoint whose variables hold their moving-average values.

  Restores the variables of `graph` from the moving-average (shadow) entries
  of `current_checkpoint_file`, then saves every variable to
  `new_checkpoint_file`. The restore mapping is filtered through
  `config_util.remove_unecessary_ema` using `no_ema_collection`.
  NOTE(review): presumably matched variables are restored from their raw,
  non-averaged values — confirm against config_util.

  Args:
    graph: a tf.Graph object.
    current_checkpoint_file: a checkpoint containing both original variables
      and their moving averages.
    new_checkpoint_file: file path to write a new checkpoint.
    no_ema_collection: A list of namescope substrings to match the variables
      to eliminate EMA.
  """
  with graph.as_default():
    ema = tf.train.ExponentialMovingAverage(0.0)
    restore_map = config_util.remove_unecessary_ema(
        ema.variables_to_restore(), no_ema_collection)
    with tf.Session() as sess:
      # The first saver reads the averaged values into the variables. The
      # second saver writes all variables under their ordinary names.
      tf.train.Saver(restore_map).restore(sess, current_checkpoint_file)
      tf.train.Saver().save(sess, new_checkpoint_file)


def _image_tensor_input_placeholder(input_shape=None):
  """Creates a uint8 image placeholder that is also the decoded input.

  Args:
    input_shape: optional static shape for the placeholder; defaults to
      (None, None, None, 3).

  Returns:
    A (placeholder, images) tuple. Both entries are the same 'image_tensor'
    placeholder, because raw image input needs no decoding step.
  """
  shape = (None, None, None, 3) if input_shape is None else input_shape
  image_placeholder = tf.placeholder(
      dtype=tf.uint8, shape=shape, name='image_tensor')
  return image_placeholder, image_placeholder


def _tf_example_input_placeholder():
  """Builds a placeholder for serialized tf.Examples and decodes their images.

  Returns:
    A (placeholder, images) tuple. `placeholder` is a 1-D tf.string tensor
    named 'tf_example' holding a batch of serialized examples. `images` is
    the uint8 batch of images extracted by `TfExampleDecoder`.
  """
  serialized_examples = tf.placeholder(
      dtype=tf.string, shape=[None], name='tf_example')

  def _decode_image(serialized_example):
    decoded = tf_example_decoder.TfExampleDecoder().decode(serialized_example)
    return decoded[fields.InputDataFields.image]

  images = shape_utils.static_or_dynamic_map_fn(
      _decode_image,
      elems=serialized_examples,
      dtype=tf.uint8,
      parallel_iterations=32,
      back_prop=False)
  return serialized_examples, images


def _encoded_image_string_tensor_input_placeholder():
  """Builds a placeholder for encoded (PNG or JPEG) images and decodes them.

  Returns:
    A (placeholder, images) tuple. `placeholder` is a 1-D tf.string tensor
    named 'encoded_image_string_tensor'. `images` is the uint8 batch produced
    by decoding each string to a 3-channel image.
  """
  encoded_images = tf.placeholder(
      dtype=tf.string,
      shape=[None],
      name='encoded_image_string_tensor')

  def _decode_one(encoded_image):
    image = tf.image.decode_image(encoded_image, channels=3)
    # decode_image does not fix the rank statically; pin it to HxWx3.
    image.set_shape((None, None, 3))
    return image

  decoded_images = tf.map_fn(
      _decode_one,
      elems=encoded_images,
      dtype=tf.uint8,
      parallel_iterations=32,
      back_prop=False)
  return encoded_images, decoded_images


# Maps each supported `input_type` string to a function that returns a
# (placeholder, decoded uint8 image batch) pair.
input_placeholder_fn_map = dict(
    image_tensor=_image_tensor_input_placeholder,
    encoded_image_string_tensor=_encoded_image_string_tensor_input_placeholder,
    tf_example=_tf_example_input_placeholder,
)


def add_output_tensor_nodes(postprocessed_tensors,
                            output_collection_name='inference_op'):
  """Adds output nodes for detection boxes and scores.

  Adds the following nodes for output tensors -
    * num_detections: float32 tensor of shape [batch_size].
    * detection_boxes: float32 tensor of shape [batch_size, num_boxes, 4]
      containing detected boxes.
    * detection_scores: float32 tensor of shape [batch_size, num_boxes]
      containing scores for the detected boxes.
    * detection_multiclass_scores: (Optional) float32 tensor of shape
      [batch_size, num_boxes, num_classes_with_background] for containing class
      score distribution for detected boxes including background if any.
    * detection_features: (Optional) float32 tensor of shape
      [batch, num_boxes, roi_height, roi_width, depth]
      containing classifier features
      for each detected box
    * detection_classes: float32 tensor of shape [batch_size, num_boxes]
      containing class predictions for the detected boxes, shifted by
      label_id_offset (1).
    * detection_keypoints: (Optional) float32 tensor of shape
      [batch_size, num_boxes, num_keypoints, 2] containing keypoints for each
      detection box.
    * detection_masks: (Optional) float32 tensor of shape
      [batch_size, num_boxes, mask_height, mask_width] containing masks for each
      detection box.
    * raw_detection_boxes: (Optional) passed through unchanged when present in
      `postprocessed_tensors`.
    * raw_detection_scores: (Optional) passed through unchanged when present in
      `postprocessed_tensors`.

  Args:
    postprocessed_tensors: a dictionary containing the following fields
      'detection_boxes': [batch, max_detections, 4]
      'detection_scores': [batch, max_detections]
      'detection_multiclass_scores': [batch, max_detections,
        num_classes_with_background]
      'detection_features': [batch, num_boxes, roi_height, roi_width, depth]
      'detection_classes': [batch, max_detections]
      'detection_masks': [batch, max_detections, mask_height, mask_width]
        (optional).
      'detection_keypoints': [batch, max_detections, num_keypoints, 2]
        (optional).
      'raw_detection_boxes' and 'raw_detection_scores' (optional).
      'num_detections': [batch]
    output_collection_name: Name of collection to add output tensors to. If
      None, the output tensors are not added to any collection.

  Returns:
    A tensor dict containing the added output tensor nodes.
  """
  detection_fields = fields.DetectionResultFields
  # Class predictions are shifted by one before export.
  # NOTE(review): presumably so ids line up with 1-indexed label maps — confirm.
  label_id_offset = 1
  boxes = postprocessed_tensors.get(detection_fields.detection_boxes)
  scores = postprocessed_tensors.get(detection_fields.detection_scores)
  multiclass_scores = postprocessed_tensors.get(
      detection_fields.detection_multiclass_scores)
  box_classifier_features = postprocessed_tensors.get(
      detection_fields.detection_features)
  raw_boxes = postprocessed_tensors.get(detection_fields.raw_detection_boxes)
  raw_scores = postprocessed_tensors.get(detection_fields.raw_detection_scores)
  classes = postprocessed_tensors.get(
      detection_fields.detection_classes) + label_id_offset
  keypoints = postprocessed_tensors.get(detection_fields.detection_keypoints)
  masks = postprocessed_tensors.get(detection_fields.detection_masks)
  num_detections = postprocessed_tensors.get(detection_fields.num_detections)

  # tf.identity gives each output a stable, well-known node name. Insertion
  # order here fixes the order of the output node names used when freezing.
  outputs = {}
  outputs[detection_fields.detection_boxes] = tf.identity(
      boxes, name=detection_fields.detection_boxes)
  outputs[detection_fields.detection_scores] = tf.identity(
      scores, name=detection_fields.detection_scores)
  if multiclass_scores is not None:
    outputs[detection_fields.detection_multiclass_scores] = tf.identity(
        multiclass_scores, name=detection_fields.detection_multiclass_scores)
  if box_classifier_features is not None:
    outputs[detection_fields.detection_features] = tf.identity(
        box_classifier_features,
        name=detection_fields.detection_features)
  outputs[detection_fields.detection_classes] = tf.identity(
      classes, name=detection_fields.detection_classes)
  outputs[detection_fields.num_detections] = tf.identity(
      num_detections, name=detection_fields.num_detections)
  if raw_boxes is not None:
    outputs[detection_fields.raw_detection_boxes] = tf.identity(
        raw_boxes, name=detection_fields.raw_detection_boxes)
  if raw_scores is not None:
    outputs[detection_fields.raw_detection_scores] = tf.identity(
        raw_scores, name=detection_fields.raw_detection_scores)
  if keypoints is not None:
    outputs[detection_fields.detection_keypoints] = tf.identity(
        keypoints, name=detection_fields.detection_keypoints)
  if masks is not None:
    outputs[detection_fields.detection_masks] = tf.identity(
        masks, name=detection_fields.detection_masks)

  # A None collection name means "do not register the outputs". Without this
  # guard they would land in a None-keyed graph collection.
  if output_collection_name is not None:
    for output_tensor in outputs.values():
      tf.add_to_collection(output_collection_name, output_tensor)

  return outputs


def write_saved_model(saved_model_path,
                      frozen_graph_def,
                      inputs,
                      outputs):
  """Writes SavedModel to disk.

  Imports `frozen_graph_def` into a fresh graph and saves it as a SavedModel
  tagged for serving. The SavedModel has a single default serving signature
  (PREDICT method) that maps the key 'inputs' to `inputs` and every key of
  `outputs` to its tensor. No checkpoint is restored here; the graph is
  expected to already be frozen.

  Args:
    saved_model_path: Path to write SavedModel.
    frozen_graph_def: tf.GraphDef holding frozen graph.
    inputs: The input placeholder tensor.
    outputs: A tensor dictionary containing the outputs of a DetectionModel.
  """
  with tf.Graph().as_default():
    with tf.Session() as sess:

      tf.import_graph_def(frozen_graph_def, name='')

      builder = tf.saved_model.builder.SavedModelBuilder(saved_model_path)

      # `inputs`/`outputs` come from the export graph, not this one.
      # build_tensor_info only records each tensor's name, dtype and shape.
      # The frozen graph was imported with name='', so those node names are
      # unchanged here.
      tensor_info_inputs = {
          'inputs': tf.saved_model.utils.build_tensor_info(inputs)}
      tensor_info_outputs = {}
      for k, v in outputs.items():
        tensor_info_outputs[k] = tf.saved_model.utils.build_tensor_info(v)

      detection_signature = (
          tf.saved_model.signature_def_utils.build_signature_def(
              inputs=tensor_info_inputs,
              outputs=tensor_info_outputs,
              method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
          ))

      builder.add_meta_graph_and_variables(
          sess,
          [tf.saved_model.tag_constants.SERVING],
          signature_def_map={
              tf.saved_model.signature_constants
              .DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                  detection_signature,
          },
      )
      builder.save()


def write_graph_and_checkpoint(inference_graph_def,
                               model_path,
                               input_saver_def,
                               trained_checkpoint_prefix):
  """Writes the graph and the checkpoint into disk.

  Clears device placements on `inference_graph_def` in place, imports it into
  a fresh graph, restores variables from `trained_checkpoint_prefix` with a
  saver built from `input_saver_def`, and saves them to `model_path`.

  Args:
    inference_graph_def: tf.GraphDef to export; its nodes' `device` fields
      are cleared in place.
    model_path: checkpoint path prefix to write.
    input_saver_def: SaverDef describing the saver ops in the graph.
    trained_checkpoint_prefix: checkpoint to restore variable values from.
  """
  # NOTE(review): presumably clears devices so the written graph is not pinned
  # to the placement used while building it.
  for graph_node in inference_graph_def.node:
    graph_node.device = ''
  with tf.Graph().as_default():
    tf.import_graph_def(inference_graph_def, name='')
    with tf.Session() as session:
      checkpoint_saver = tf.train.Saver(
          saver_def=input_saver_def, save_relative_paths=True)
      checkpoint_saver.restore(session, trained_checkpoint_prefix)
      checkpoint_saver.save(session, model_path)


def _get_outputs_from_inputs(input_tensors, detection_model,
                             output_collection_name):
  """Runs the detection model on `input_tensors` and names its outputs.

  Args:
    input_tensors: batch of images; cast to float32 before preprocessing.
    detection_model: model providing preprocess/predict/postprocess.
    output_collection_name: forwarded to `add_output_tensor_nodes`.

  Returns:
    The dict of named output tensors from `add_output_tensor_nodes`.
  """
  float_inputs = tf.cast(input_tensors, dtype=tf.float32)
  preprocessed, true_shapes = detection_model.preprocess(float_inputs)
  predictions = detection_model.predict(preprocessed, true_shapes)
  detections = detection_model.postprocess(predictions, true_shapes)
  return add_output_tensor_nodes(detections, output_collection_name)


def build_detection_graph(input_type, detection_model, input_shape,
                          output_collection_name, graph_hook_fn):
  """Build the detection graph.

  Creates the input placeholder for `input_type`, wires it through the model,
  adds a global step, and runs `graph_hook_fn` (if given) on the result.

  Args:
    input_type: a key of `input_placeholder_fn_map`.
    detection_model: model used to produce the outputs.
    input_shape: optional fixed input shape; only valid for 'image_tensor'.
    output_collection_name: forwarded to `add_output_tensor_nodes`.
    graph_hook_fn: optional zero-argument callable run after building.

  Returns:
    An (outputs, placeholder_tensor) tuple: the dict of named output tensors
    and the input placeholder to feed.

  Raises:
    ValueError: if `input_type` is unknown, or if `input_shape` is given for
      an input type other than 'image_tensor'.
  """
  if input_type not in input_placeholder_fn_map:
    raise ValueError('Unknown input type: {}'.format(input_type))
  make_placeholder = input_placeholder_fn_map[input_type]
  if input_shape is None:
    placeholder_tensor, input_tensors = make_placeholder()
  elif input_type == 'image_tensor':
    placeholder_tensor, input_tensors = make_placeholder(
        input_shape=input_shape)
  else:
    raise ValueError('Can only specify input shape for `image_tensor` '
                     'inputs.')

  outputs = _get_outputs_from_inputs(
      input_tensors=input_tensors,
      detection_model=detection_model,
      output_collection_name=output_collection_name)

  # Add global step to the graph.
  slim.get_or_create_global_step()

  if graph_hook_fn:
    graph_hook_fn()

  return outputs, placeholder_tensor


def _export_inference_graph(input_type,
                            detection_model,
                            use_moving_averages,
380
381
                            trained_checkpoint_prefix,
                            output_directory,
Vivek Rathod's avatar
Vivek Rathod committed
382
383
                            additional_output_tensor_names=None,
                            input_shape=None,
384
                            output_collection_name='inference_op',
385
                            graph_hook_fn=None,
386
387
                            write_inference_graph=False,
                            temp_checkpoint_prefix=''):
388
  """Export helper."""
389
390
391
392
393
394
  tf.gfile.MakeDirs(output_directory)
  frozen_graph_path = os.path.join(output_directory,
                                   'frozen_inference_graph.pb')
  saved_model_path = os.path.join(output_directory, 'saved_model')
  model_path = os.path.join(output_directory, 'model.ckpt')

395
  outputs, placeholder_tensor = build_detection_graph(
396
397
398
399
400
      input_type=input_type,
      detection_model=detection_model,
      input_shape=input_shape,
      output_collection_name=output_collection_name,
      graph_hook_fn=graph_hook_fn)
401

402
  profile_inference_graph(tf.get_default_graph())
403
  saver_kwargs = {}
404
  if use_moving_averages:
405
406
407
408
409
410
411
    if not temp_checkpoint_prefix:
      # This check is to be compatible with both version of SaverDef.
      if os.path.isfile(trained_checkpoint_prefix):
        saver_kwargs['write_version'] = saver_pb2.SaverDef.V1
        temp_checkpoint_prefix = tempfile.NamedTemporaryFile().name
      else:
        temp_checkpoint_prefix = tempfile.mkdtemp()
Vivek Rathod's avatar
Vivek Rathod committed
412
413
    replace_variable_values_with_moving_averages(
        tf.get_default_graph(), trained_checkpoint_prefix,
414
415
        temp_checkpoint_prefix)
    checkpoint_to_use = temp_checkpoint_prefix
416
  else:
Vivek Rathod's avatar
Vivek Rathod committed
417
418
    checkpoint_to_use = trained_checkpoint_prefix

419
  saver = tf.train.Saver(**saver_kwargs)
420
421
  input_saver_def = saver.as_saver_def()

422
  write_graph_and_checkpoint(
423
424
425
      inference_graph_def=tf.get_default_graph().as_graph_def(),
      model_path=model_path,
      input_saver_def=input_saver_def,
Vivek Rathod's avatar
Vivek Rathod committed
426
      trained_checkpoint_prefix=checkpoint_to_use)
427
428
429
430
431
432
  if write_inference_graph:
    inference_graph_def = tf.get_default_graph().as_graph_def()
    inference_graph_path = os.path.join(output_directory,
                                        'inference_graph.pbtxt')
    for node in inference_graph_def.node:
      node.device = ''
433
    with tf.gfile.GFile(inference_graph_path, 'wb') as f:
434
      f.write(str(inference_graph_def))
Vivek Rathod's avatar
Vivek Rathod committed
435
436
437
438
439

  if additional_output_tensor_names is not None:
    output_node_names = ','.join(outputs.keys()+additional_output_tensor_names)
  else:
    output_node_names = ','.join(outputs.keys())
440

441
  frozen_graph_def = freeze_graph.freeze_graph_with_def_protos(
442
443
      input_graph_def=tf.get_default_graph().as_graph_def(),
      input_saver_def=input_saver_def,
Vivek Rathod's avatar
Vivek Rathod committed
444
445
      input_checkpoint=checkpoint_to_use,
      output_node_names=output_node_names,
446
447
      restore_op_name='save/restore_all',
      filename_tensor_name='save/Const:0',
448
      output_graph=frozen_graph_path,
449
450
      clear_devices=True,
      initializer_nodes='')
451

452
453
  write_saved_model(saved_model_path, frozen_graph_def,
                    placeholder_tensor, outputs)
454
455


456
457
458
459
def export_inference_graph(input_type,
                           pipeline_config,
                           trained_checkpoint_prefix,
                           output_directory,
                           input_shape=None,
                           output_collection_name='inference_op',
                           additional_output_tensor_names=None,
                           write_inference_graph=False):
  """Exports inference graph for the model specified in the pipeline config.

  Also writes the pipeline config to `output_directory`, with
  `eval_config.use_moving_averages` set to False. The caller's
  `pipeline_config` object is not modified.

  Args:
    input_type: Type of input for the graph. Can be one of ['image_tensor',
      'encoded_image_string_tensor', 'tf_example'].
    pipeline_config: pipeline_pb2.TrainAndEvalPipelineConfig proto.
    trained_checkpoint_prefix: Path to the trained checkpoint file.
    output_directory: Path to write outputs.
    input_shape: Sets a fixed shape for an `image_tensor` input. If not
      specified, will default to [None, None, None, 3].
    output_collection_name: Name of collection to add output tensors to.
      If None, does not add output tensors to a collection.
    additional_output_tensor_names: list of additional output
      tensors to include in the frozen graph.
    write_inference_graph: If true, writes inference graph to disk.
  """
  detection_model = model_builder.build(pipeline_config.model,
                                        is_training=False)
  graph_rewriter_fn = None
  if pipeline_config.HasField('graph_rewriter'):
    graph_rewriter_config = pipeline_config.graph_rewriter
    graph_rewriter_fn = graph_rewriter_builder.build(graph_rewriter_config,
                                                     is_training=False)
  _export_inference_graph(
      input_type,
      detection_model,
      pipeline_config.eval_config.use_moving_averages,
      trained_checkpoint_prefix,
      output_directory,
      additional_output_tensor_names=additional_output_tensor_names,
      input_shape=input_shape,
      output_collection_name=output_collection_name,
      graph_hook_fn=graph_rewriter_fn,
      write_inference_graph=write_inference_graph)
  # When moving averages were requested, the exported weights already hold
  # them, so the saved config turns the flag off. Edit a copy: mutating the
  # caller's proto would make a later export with it skip moving averages.
  config_to_save = copy.deepcopy(pipeline_config)
  config_to_save.eval_config.use_moving_averages = False
  config_util.save_pipeline_config(config_to_save, output_directory)


def profile_inference_graph(graph):
  """Profiles the inference graph.

  Prints model parameters and computation FLOPs given an inference graph.
  BatchNorms are excluded from the parameter count due to the fact that
  BatchNorms are usually folded. BatchNorm, Initializer, Regularizer
  and BiasAdd are not considered in FLOP count.

  Args:
    graph: the inference graph.
  """
  model_analyzer = tf.contrib.tfprof.model_analyzer
  # Copy the library's default option dicts. They are module-level attributes
  # of model_analyzer, so assigning 'trim_name_regexes' on them directly would
  # change the defaults for every later tfprof call in the process.
  tfprof_vars_option = dict(model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS)
  tfprof_flops_option = dict(model_analyzer.FLOAT_OPS_OPTIONS)

  # Batchnorm is usually folded during inference.
  tfprof_vars_option['trim_name_regexes'] = ['.*BatchNorm.*']
  # Initializer and Regularizer are only used in training.
  tfprof_flops_option['trim_name_regexes'] = [
      '.*BatchNorm.*', '.*Initializer.*', '.*Regularizer.*', '.*BiasAdd.*'
  ]

  model_analyzer.print_model_analysis(
      graph,
      tfprof_options=tfprof_vars_option)

  model_analyzer.print_model_analysis(
      graph,
      tfprof_options=tfprof_flops_option)