# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Functions to export object detection inference graph."""
import os
import tempfile
19
import tensorflow as tf
20
from tensorflow.contrib.quantize.python import graph_matcher
21
from tensorflow.core.protobuf import saver_pb2
22
23
from tensorflow.python.client import session
from tensorflow.python.platform import gfile
24
from tensorflow.python.saved_model import signature_constants
25
from tensorflow.python.tools import freeze_graph
26
from tensorflow.python.training import saver as saver_lib
27
from object_detection.builders import graph_rewriter_builder
28
29
30
from object_detection.builders import model_builder
from object_detection.core import standard_fields as fields
from object_detection.data_decoders import tf_example_decoder
31
from object_detection.utils import config_util
32
from object_detection.utils import shape_utils
33
34
35

slim = tf.contrib.slim

# Module-level alias of freeze_graph's function. The visible code of this
# module calls freeze_graph.freeze_graph_with_def_protos directly, so this name
# is presumably kept for external callers -- confirm before removing.
freeze_graph_with_def_protos = freeze_graph.freeze_graph_with_def_protos


def rewrite_nn_resize_op(is_quantized=False):
  """Swaps a custom nearest-neighbor resize subgraph for the Tensorflow op.

  Some graphs express nearest-neighbor upsampling as
  Reshape -> Mul -> (optional FakeQuant) -> Reshape feeding an Add, for
  TPU-compatibility. For every such match in the default graph, a
  `tf.image.resize_nearest_neighbor` op is created and the Add is rewired to
  consume it in place of the second Reshape.

  Args:
    is_quantized: True if the default graph is quantized; the resize input is
      then required to be a FakeQuantWithMinMaxVars op.
  """
  source_type = 'FakeQuantWithMinMaxVars' if is_quantized else '*'
  source = graph_matcher.OpTypePattern(source_type)
  first_reshape = graph_matcher.OpTypePattern(
      'Reshape', inputs=[source, 'Const'], ordered_inputs=False)
  scale = graph_matcher.OpTypePattern(
      'Mul', inputs=[first_reshape, 'Const'], ordered_inputs=False)
  # Quantization may or may not place a FakeQuant after the Mul; either form
  # is matched, and its min/max inputs are not needed after the rewrite.
  optional_fake_quant = graph_matcher.OpTypePattern(
      'FakeQuantWithMinMaxVars',
      inputs=[scale, 'Identity', 'Identity'],
      ordered_inputs=False)
  second_reshape = graph_matcher.OpTypePattern(
      'Reshape',
      inputs=[graph_matcher.OneofPattern([optional_fake_quant, scale]),
              'Const'],
      ordered_inputs=False)
  consumer = graph_matcher.OpTypePattern(
      'Add', inputs=[second_reshape, '*'], ordered_inputs=False)

  for match in graph_matcher.GraphMatcher(consumer).match_graph(
      tf.get_default_graph()):
    source_op = match.get_op(source)
    reshape_op = match.get_op(second_reshape)
    add_op = match.get_op(consumer)
    # Resize to the spatial dims (1 and 2) of the Add's static output shape.
    target_size = add_op.outputs[0].shape.dims[1:3]
    resized = tf.image.resize_nearest_neighbor(
        source_op.outputs[0], target_size, align_corners=False)

    # Rewire only the first Add input fed by the matched Reshape.
    reshape_output = reshape_op.outputs[0]
    position = next(
        (i for i, t in enumerate(add_op.inputs) if t == reshape_output), None)
    if position is not None:
      add_op._update_input(position, resized)  # pylint: disable=protected-access


def replace_variable_values_with_moving_averages(graph,
                                                 current_checkpoint_file,
                                                 new_checkpoint_file):
  """Replaces variable values in the checkpoint with their moving averages.

  If the current checkpoint has shadow variables maintaining moving averages of
  the variables defined in the graph, this function generates a new checkpoint
  where the variables contain the values of their moving averages.

  Args:
    graph: a tf.Graph object.
    current_checkpoint_file: a checkpoint containing both original variables and
      their moving averages.
    new_checkpoint_file: file path to write a new checkpoint.
  """
  with graph.as_default():
    # Only `variables_to_restore()` is used, never `apply()`, so the decay
    # value has no effect; the object just supplies the shadow-name mapping.
    variable_averages = tf.train.ExponentialMovingAverage(0.0)
    ema_variables_to_restore = variable_averages.variables_to_restore()
    with tf.Session() as sess:
      # Restore the averaged (shadow) values into the graph's variables ...
      read_saver = tf.train.Saver(ema_variables_to_restore)
      read_saver.restore(sess, current_checkpoint_file)
      # ... then save them under the variables' own names.
      write_saver = tf.train.Saver()
      write_saver.save(sess, new_checkpoint_file)


def _image_tensor_input_placeholder(input_shape=None):
  """Returns input placeholder and a 4-D uint8 image tensor.

  Args:
    input_shape: Optional static shape for the placeholder. Defaults to
      (None, None, None, 3).

  Returns:
    A (placeholder, image_tensor) tuple. Both elements are the same
    `image_tensor` placeholder, since this input type needs no decoding.
  """
  if input_shape is None:
    input_shape = (None, None, None, 3)
  input_tensor = tf.placeholder(
      dtype=tf.uint8, shape=input_shape, name='image_tensor')
  return input_tensor, input_tensor

117

def _tf_example_input_placeholder():
  """Returns input that accepts a batch of strings with tf examples.

  Returns:
    a tuple of input placeholder and the output decoded images.
  """
  batch_tf_example_placeholder = tf.placeholder(
      tf.string, shape=[None], name='tf_example')

  def decode(tf_example_string_tensor):
    # Decode one serialized tf.Example and keep only its image field.
    tensor_dict = tf_example_decoder.TfExampleDecoder().decode(
        tf_example_string_tensor)
    image_tensor = tensor_dict[fields.InputDataFields.image]
    return image_tensor

  return (batch_tf_example_placeholder,
          shape_utils.static_or_dynamic_map_fn(
              decode,
              elems=batch_tf_example_placeholder,
              dtype=tf.uint8,
              parallel_iterations=32,
              back_prop=False))


def _encoded_image_string_tensor_input_placeholder():
  """Returns input that accepts a batch of PNG or JPEG strings.

  Returns:
    a tuple of input placeholder and the output decoded images.
  """
  batch_image_str_placeholder = tf.placeholder(
      dtype=tf.string,
      shape=[None],
      name='encoded_image_string_tensor')

  def decode(encoded_image_string_tensor):
    image_tensor = tf.image.decode_image(encoded_image_string_tensor,
                                         channels=3)
    # decode_image can also return 4-D tensors (GIFs), so its static shape is
    # not fully known; pin it to a single 3-channel image.
    image_tensor.set_shape((None, None, 3))
    return image_tensor

  return (batch_image_str_placeholder,
          tf.map_fn(
              decode,
              elems=batch_image_str_placeholder,
              dtype=tf.uint8,
              parallel_iterations=32,
              back_prop=False))


# Maps each supported `input_type` string to a function that creates the input
# placeholder and returns a (placeholder, uint8 image batch) tuple.
input_placeholder_fn_map = {
    'image_tensor': _image_tensor_input_placeholder,
    'encoded_image_string_tensor':
    _encoded_image_string_tensor_input_placeholder,
    'tf_example': _tf_example_input_placeholder,
}


def add_output_tensor_nodes(postprocessed_tensors,
                            output_collection_name='inference_op'):
  """Adds output nodes for detection boxes and scores.

  Adds the following nodes for output tensors -
    * num_detections: float32 tensor of shape [batch_size].
    * detection_boxes: float32 tensor of shape [batch_size, num_boxes, 4]
      containing detected boxes.
    * detection_scores: float32 tensor of shape [batch_size, num_boxes]
      containing scores for the detected boxes.
    * detection_classes: float32 tensor of shape [batch_size, num_boxes]
      containing class predictions for the detected boxes.
    * detection_keypoints: (Optional) float32 tensor of shape
      [batch_size, num_boxes, num_keypoints, 2] containing keypoints for each
      detection box.
    * detection_masks: (Optional) float32 tensor of shape
      [batch_size, num_boxes, mask_height, mask_width] containing masks for each
      detection box.

  Args:
    postprocessed_tensors: a dictionary containing the following fields
      'detection_boxes': [batch, max_detections, 4]
      'detection_scores': [batch, max_detections]
      'detection_classes': [batch, max_detections]
      'detection_keypoints': [batch, max_detections, num_keypoints, 2]
        (optional).
      'detection_masks': [batch, max_detections, mask_height, mask_width]
        (optional).
      'num_detections': [batch]
    output_collection_name: Name of collection to add output tensors to.
      If None, the output tensors are not added to any collection.

  Returns:
    A tensor dict containing the added output tensor nodes.
  """
  detection_fields = fields.DetectionResultFields
  # Class ids are shifted by +1; presumably this maps the model's 0-based
  # class indices to 1-based label-map ids -- confirm against the label map.
  label_id_offset = 1
  boxes = postprocessed_tensors.get(detection_fields.detection_boxes)
  scores = postprocessed_tensors.get(detection_fields.detection_scores)
  classes = postprocessed_tensors.get(
      detection_fields.detection_classes) + label_id_offset
  keypoints = postprocessed_tensors.get(detection_fields.detection_keypoints)
  masks = postprocessed_tensors.get(detection_fields.detection_masks)
  num_detections = postprocessed_tensors.get(detection_fields.num_detections)

  # Each output is wrapped in tf.identity so its graph node is named exactly
  # like its dictionary key.
  outputs = {}
  outputs[detection_fields.detection_boxes] = tf.identity(
      boxes, name=detection_fields.detection_boxes)
  outputs[detection_fields.detection_scores] = tf.identity(
      scores, name=detection_fields.detection_scores)
  outputs[detection_fields.detection_classes] = tf.identity(
      classes, name=detection_fields.detection_classes)
  outputs[detection_fields.num_detections] = tf.identity(
      num_detections, name=detection_fields.num_detections)
  if keypoints is not None:
    outputs[detection_fields.detection_keypoints] = tf.identity(
        keypoints, name=detection_fields.detection_keypoints)
  if masks is not None:
    outputs[detection_fields.detection_masks] = tf.identity(
        masks, name=detection_fields.detection_masks)

  # Honour the documented contract: a None collection name means "do not
  # register the outputs" (previously they went into a None-keyed collection).
  if output_collection_name is not None:
    for output_tensor in outputs.values():
      tf.add_to_collection(output_collection_name, output_tensor)

  return outputs


def write_saved_model(saved_model_path,
                      frozen_graph_def,
                      inputs,
                      outputs):
  """Writes SavedModel to disk.

  Imports `frozen_graph_def` into a fresh graph and saves it as a SavedModel
  tagged for serving, with a single predict signature registered under the
  default serving signature key. No checkpoint is read here; the weights are
  expected to be baked into `frozen_graph_def`.

  Args:
    saved_model_path: Path to write SavedModel.
    frozen_graph_def: tf.GraphDef holding frozen graph.
    inputs: The input placeholder tensor.
    outputs: A tensor dictionary containing the outputs of a DetectionModel.
  """
  with tf.Graph().as_default():
    with session.Session() as sess:

      tf.import_graph_def(frozen_graph_def, name='')

      builder = tf.saved_model.builder.SavedModelBuilder(saved_model_path)

      # `inputs`/`outputs` may belong to a different tf.Graph than the one
      # created above. build_tensor_info records tensor names, dtypes and
      # shapes only, and importing with name='' keeps node names unchanged, so
      # the signature resolves against the imported graph.
      tensor_info_inputs = {
          'inputs': tf.saved_model.utils.build_tensor_info(inputs)}
      tensor_info_outputs = {
          key: tf.saved_model.utils.build_tensor_info(tensor)
          for key, tensor in outputs.items()}

      detection_signature = (
          tf.saved_model.signature_def_utils.build_signature_def(
              inputs=tensor_info_inputs,
              outputs=tensor_info_outputs,
              method_name=signature_constants.PREDICT_METHOD_NAME))

      builder.add_meta_graph_and_variables(
          sess, [tf.saved_model.tag_constants.SERVING],
          signature_def_map={
              signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                  detection_signature,
          },
      )
      builder.save()


def write_graph_and_checkpoint(inference_graph_def,
                               model_path,
                               input_saver_def,
                               trained_checkpoint_prefix):
  """Writes the graph and the checkpoint into disk.

  Imports `inference_graph_def` into a fresh graph, restores its variables
  from `trained_checkpoint_prefix` and saves them again to `model_path`.

  Args:
    inference_graph_def: tf.GraphDef of the inference graph. NOTE: the
      `device` field of each of its nodes is cleared in place.
    model_path: Path prefix of the checkpoint to write.
    input_saver_def: SaverDef describing the saver ops in the graph.
    trained_checkpoint_prefix: Checkpoint to restore variable values from.
  """
  for node in inference_graph_def.node:
    node.device = ''
  with tf.Graph().as_default():
    tf.import_graph_def(inference_graph_def, name='')
    with session.Session() as sess:
      # Reuse the saver ops already in the imported graph. Relative paths in
      # the checkpoint state file keep the output directory relocatable.
      saver = saver_lib.Saver(saver_def=input_saver_def,
                              save_relative_paths=True)
      saver.restore(sess, trained_checkpoint_prefix)
      saver.save(sess, model_path)


def _get_outputs_from_inputs(input_tensors, detection_model,
                             output_collection_name):
  """Runs the detection model on `input_tensors` and adds named output nodes.

  Args:
    input_tensors: Batch of input images; cast to float32 before
      preprocessing.
    detection_model: A DetectionModel providing preprocess/predict/postprocess.
    output_collection_name: Name of collection to add output tensors to.

  Returns:
    The dict of output tensors returned by `add_output_tensor_nodes`.
  """
  inputs = tf.to_float(input_tensors)
  preprocessed_inputs, true_image_shapes = detection_model.preprocess(inputs)
  output_tensors = detection_model.predict(
      preprocessed_inputs, true_image_shapes)
  postprocessed_tensors = detection_model.postprocess(
      output_tensors, true_image_shapes)
  return add_output_tensor_nodes(postprocessed_tensors,
                                 output_collection_name)


def _build_detection_graph(input_type, detection_model, input_shape,
                           output_collection_name, graph_hook_fn):
  """Build the detection graph.

  Everything is added to the current default graph, in this order:
  1. the input placeholder selected by `input_type`;
  2. the detection model and its named output nodes;
  3. a global step;
  4. whatever `graph_hook_fn` adds.

  Args:
    input_type: Key into `input_placeholder_fn_map`.
    detection_model: The DetectionModel to run on the decoded inputs.
    input_shape: Optional fixed input shape; only allowed for 'image_tensor'.
    output_collection_name: Collection name forwarded to the output nodes.
    graph_hook_fn: Optional zero-argument callable invoked last.

  Returns:
    A tuple (outputs, placeholder_tensor).

  Raises:
    ValueError: if `input_type` is unknown, or `input_shape` is given for an
      input type other than 'image_tensor'.
  """
  if input_type not in input_placeholder_fn_map:
    raise ValueError('Unknown input type: {}'.format(input_type))
  if input_shape is not None and input_type != 'image_tensor':
    raise ValueError('Can only specify input shape for `image_tensor` '
                     'inputs.')

  make_placeholder = input_placeholder_fn_map[input_type]
  if input_shape is None:
    placeholder_tensor, input_tensors = make_placeholder()
  else:
    placeholder_tensor, input_tensors = make_placeholder(
        input_shape=input_shape)

  outputs = _get_outputs_from_inputs(
      input_tensors=input_tensors,
      detection_model=detection_model,
      output_collection_name=output_collection_name)

  # Add global step to the graph.
  slim.get_or_create_global_step()

  if graph_hook_fn:
    graph_hook_fn()

  return outputs, placeholder_tensor


def _export_inference_graph(input_type,
                            detection_model,
                            use_moving_averages,
                            trained_checkpoint_prefix,
                            output_directory,
                            additional_output_tensor_names=None,
                            input_shape=None,
                            output_collection_name='inference_op',
                            graph_hook_fn=None,
                            write_inference_graph=False):
  """Export helper.

  Builds the detection graph in the current default graph and writes, under
  `output_directory`:
    * model.ckpt: checkpoint of the variables (their moving-average values when
      `use_moving_averages` is True).
    * inference_graph.pbtxt: text inference graph, only if
      `write_inference_graph` is True.
    * frozen_inference_graph.pb: the frozen graph.
    * saved_model/: a SavedModel built from the frozen graph.

  Args:
    input_type: Key into `input_placeholder_fn_map`.
    detection_model: The DetectionModel to export.
    use_moving_averages: If True, export the moving-average values of the
      variables instead of their raw trained values.
    trained_checkpoint_prefix: Path to the trained checkpoint.
    output_directory: Directory to write outputs to.
    additional_output_tensor_names: Optional list of extra node names to keep
      in the frozen graph.
    input_shape: Optional fixed shape for an `image_tensor` input.
    output_collection_name: Name of collection to add output tensors to.
    graph_hook_fn: Optional callable invoked after the graph is built.
    write_inference_graph: If True, also writes the inference graph as text.
  """
  tf.gfile.MakeDirs(output_directory)
  frozen_graph_path = os.path.join(output_directory,
                                   'frozen_inference_graph.pb')
  saved_model_path = os.path.join(output_directory, 'saved_model')
  model_path = os.path.join(output_directory, 'model.ckpt')

  outputs, placeholder_tensor = _build_detection_graph(
      input_type=input_type,
      detection_model=detection_model,
      input_shape=input_shape,
      output_collection_name=output_collection_name,
      graph_hook_fn=graph_hook_fn)

  profile_inference_graph(tf.get_default_graph())
  saver_kwargs = {}
  if use_moving_averages:
    # This check is to be compatible with both version of SaverDef: a V1
    # checkpoint is a single file at the prefix path, while a V2 prefix is
    # not itself a file.
    if os.path.isfile(trained_checkpoint_prefix):
      saver_kwargs['write_version'] = saver_pb2.SaverDef.V1
      # Only a unique path is needed; the temporary file object is discarded
      # immediately.
      temp_checkpoint_prefix = tempfile.NamedTemporaryFile().name
    else:
      temp_checkpoint_prefix = tempfile.mkdtemp()
    # NOTE(review): the temporary checkpoint is never cleaned up here.
    replace_variable_values_with_moving_averages(
        tf.get_default_graph(), trained_checkpoint_prefix,
        temp_checkpoint_prefix)
    checkpoint_to_use = temp_checkpoint_prefix
  else:
    checkpoint_to_use = trained_checkpoint_prefix

  saver = tf.train.Saver(**saver_kwargs)
  input_saver_def = saver.as_saver_def()

  write_graph_and_checkpoint(
      inference_graph_def=tf.get_default_graph().as_graph_def(),
      model_path=model_path,
      input_saver_def=input_saver_def,
      trained_checkpoint_prefix=checkpoint_to_use)

  if write_inference_graph:
    inference_graph_def = tf.get_default_graph().as_graph_def()
    inference_graph_path = os.path.join(output_directory,
                                        'inference_graph.pbtxt')
    for node in inference_graph_def.node:
      node.device = ''
    with gfile.GFile(inference_graph_path, 'wb') as f:
      f.write(str(inference_graph_def))

  # Materialize a real list: on Python 3 `dict.keys()` is a view, and
  # `outputs.keys() + [...]` raises TypeError.
  output_node_names = list(outputs.keys())
  if additional_output_tensor_names is not None:
    output_node_names.extend(additional_output_tensor_names)

  frozen_graph_def = freeze_graph.freeze_graph_with_def_protos(
      input_graph_def=tf.get_default_graph().as_graph_def(),
      input_saver_def=input_saver_def,
      input_checkpoint=checkpoint_to_use,
      output_node_names=','.join(output_node_names),
      restore_op_name='save/restore_all',
      filename_tensor_name='save/Const:0',
      output_graph=frozen_graph_path,
      clear_devices=True,
      initializer_nodes='')

  write_saved_model(saved_model_path, frozen_graph_def,
                    placeholder_tensor, outputs)


def export_inference_graph(input_type,
                           pipeline_config,
                           trained_checkpoint_prefix,
                           output_directory,
                           input_shape=None,
                           output_collection_name='inference_op',
                           additional_output_tensor_names=None,
                           write_inference_graph=False):
  """Exports inference graph for the model specified in the pipeline config.

  Args:
    input_type: Type of input for the graph. Can be one of ['image_tensor',
      'encoded_image_string_tensor', 'tf_example'].
    pipeline_config: pipeline_pb2.TrainAndEvalPipelineConfig proto.
      NOTE: `eval_config.use_moving_averages` is set to False on this proto
      (the caller's object) before it is saved to `output_directory`.
    trained_checkpoint_prefix: Path to the trained checkpoint file.
    output_directory: Path to write outputs.
    input_shape: Sets a fixed shape for an `image_tensor` input. If not
      specified, will default to [None, None, None, 3].
    output_collection_name: Name of collection to add output tensors to.
      If None, does not add output tensors to a collection.
    additional_output_tensor_names: list of additional output
      tensors to include in the frozen graph.
    write_inference_graph: If true, writes inference graph to disk.
  """
  detection_model = model_builder.build(pipeline_config.model,
                                        is_training=False)
  graph_rewriter_fn = None
  if pipeline_config.HasField('graph_rewriter'):
    graph_rewriter_config = pipeline_config.graph_rewriter
    graph_rewriter_fn = graph_rewriter_builder.build(graph_rewriter_config,
                                                     is_training=False)
  # Keyword arguments guard against silent mis-binding: this function and
  # _export_inference_graph order their optional parameters differently.
  _export_inference_graph(
      input_type,
      detection_model,
      pipeline_config.eval_config.use_moving_averages,
      trained_checkpoint_prefix,
      output_directory,
      additional_output_tensor_names=additional_output_tensor_names,
      input_shape=input_shape,
      output_collection_name=output_collection_name,
      graph_hook_fn=graph_rewriter_fn,
      write_inference_graph=write_inference_graph)
  # The exported checkpoint already holds the moving-average values when they
  # were requested, so the saved config must not ask to average again.
  pipeline_config.eval_config.use_moving_averages = False
  config_util.save_pipeline_config(pipeline_config, output_directory)


def profile_inference_graph(graph):
  """Profiles the inference graph.

  Prints model parameters and computation FLOPs given an inference graph.
  BatchNorms are excluded from the parameter count due to the fact that
  BatchNorms are usually folded. BatchNorm, Initializer, Regularizer
  and BiasAdd are not considered in FLOP count.

  Args:
    graph: the inference graph.
  """
  # Copy tfprof's module-level option dicts before customizing them. Assigning
  # into them directly would leak these trim settings into every later tfprof
  # call in the process. A shallow copy suffices because the keys below are
  # reassigned, not mutated.
  tfprof_vars_option = dict(
      tf.contrib.tfprof.model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS)
  tfprof_flops_option = dict(
      tf.contrib.tfprof.model_analyzer.FLOAT_OPS_OPTIONS)

  # Batchnorm is usually folded during inference.
  tfprof_vars_option['trim_name_regexes'] = ['.*BatchNorm.*']
  # Initializer and Regularizer are only used in training.
  tfprof_flops_option['trim_name_regexes'] = [
      '.*BatchNorm.*', '.*Initializer.*', '.*Regularizer.*', '.*BiasAdd.*'
  ]

  tf.contrib.tfprof.model_analyzer.print_model_analysis(
      graph,
      tfprof_options=tfprof_vars_option)

  tf.contrib.tfprof.model_analyzer.print_model_analysis(
      graph,
      tfprof_options=tfprof_flops_option)