exporter.py 27 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Functions to export object detection inference graph."""
import os
Vivek Rathod's avatar
Vivek Rathod committed
18
import tempfile
19
20
import tensorflow.compat.v1 as tf
import tf_slim as slim
21
from tensorflow.core.protobuf import saver_pb2
22
from tensorflow.python.tools import freeze_graph  # pylint: disable=g-direct-tensorflow-import
23
from object_detection.builders import graph_rewriter_builder
24
25
26
from object_detection.builders import model_builder
from object_detection.core import standard_fields as fields
from object_detection.data_decoders import tf_example_decoder
27
from object_detection.utils import config_util
28
from object_detection.utils import shape_utils
29

30
31
32
33
34
35
36
37
# pylint: disable=g-import-not-at-top
try:
  from tensorflow.contrib import tfprof as contrib_tfprof
  from tensorflow.contrib.quantize.python import graph_matcher
except ImportError:
  # TF 2.0 doesn't ship with contrib.
  pass
# pylint: enable=g-import-not-at-top
38

39
# Module-level alias of freeze_graph.freeze_graph_with_def_protos. This file
# itself calls the fully qualified freeze_graph.freeze_graph_with_def_protos,
# so the alias is presumably kept for external users of this module.
freeze_graph_with_def_protos = freeze_graph.freeze_graph_with_def_protos
40
41


42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
def parse_side_inputs(side_input_shapes_string, side_input_names_string,
                      side_input_types_string):
  """Parses side input flags.

  Args:
    side_input_shapes_string: The shape of the side input tensors, provided as a
      comma-separated list of integers. A value of -1 is used for unknown
      dimensions. A `/` denotes a break, starting the shape of the next side
      input tensor. Whitespace around each dimension is ignored.
    side_input_names_string: The names of the side input tensors, provided as a
      comma-separated list of strings. Whitespace around each name is ignored.
    side_input_types_string: The type of the side input tensors, provided as a
      comma-separated list of types, each of `float`, `int` (or its alias
      `integer`), or `string`. Whitespace around each type is ignored.

  Returns:
    side_input_shapes: A list of shapes.
    side_input_names: A list of strings.
    side_input_types: A list of tensorflow dtypes.

  Raises:
    ValueError: if any of the three flag strings is empty, if a dimension is
      not an integer, or if a type name is not one of the supported names.
  """
  if not side_input_shapes_string:
    raise ValueError('When using side_inputs, side_input_shapes must be '
                     'specified in the input flags.')
  side_input_shapes = []
  for shape_string in side_input_shapes_string.split('/'):
    dims = [dim.strip() for dim in shape_string.split(',')]
    side_input_shapes.append(
        [None if dim == '-1' else int(dim) for dim in dims])

  if not side_input_names_string:
    raise ValueError('When using side_inputs, side_input_names must be '
                     'specified in the input flags.')
  side_input_names = [
      name.strip() for name in side_input_names_string.split(',')
  ]

  if not side_input_types_string:
    raise ValueError('When using side_inputs, side_input_types must be '
                     'specified in the input flags.')
  # Maps each accepted type name to the name of the tf dtype it denotes.
  # The flag documentation says `integer` while `int` was the accepted key,
  # so both spellings are supported.
  dtype_attr_by_type_name = {
      'float': 'float32',
      'int': 'int32',
      'integer': 'int32',
      'string': 'string',
  }
  type_names = [name.strip() for name in side_input_types_string.split(',')]
  unknown_type_names = [
      name for name in type_names if name not in dtype_attr_by_type_name
  ]
  if unknown_type_names:
    raise ValueError(
        'Unsupported side_input_types {}; each must be one of {}.'.format(
            unknown_type_names, sorted(dtype_attr_by_type_name)))
  side_input_types = [
      getattr(tf, dtype_attr_by_type_name[name]) for name in type_names
  ]
  return side_input_shapes, side_input_names, side_input_types


90
91
92
93
94
95
96
97
def rewrite_nn_resize_op(is_quantized=False):
  """Swaps a hand-built nearest-neighbor upsampling subgraph for the TF op.

  Some graphs express nearest-neighbor upsampling as
  Pack(Pack(x, x), Pack(x, x)) -> Reshape for TPU-compatibility. Every
  Add/AddV2/Max/Mul or StridedSlice consumer of such a Reshape in the default
  graph is rewired to read from a `tf.image.resize_nearest_neighbor` op
  instead. Sweeps are repeated until one finds nothing to replace.

  Args:
    is_quantized: True if the default graph is quantized; the pattern's source
      op must then be a FakeQuantWithMinMaxVars op.

  Raises:
    ValueError: if five consecutive sweeps all find matches.
  """

  def _rewrite_pass():
    """Runs one search-and-replace sweep; returns how many matches it fixed."""
    source_pattern = graph_matcher.OpTypePattern(
        'FakeQuantWithMinMaxVars' if is_quantized else '*')
    inner_stack = graph_matcher.OpTypePattern(
        'Pack', inputs=[source_pattern, source_pattern], ordered_inputs=False)
    outer_stack = graph_matcher.OpTypePattern(
        'Pack', inputs=[inner_stack, inner_stack], ordered_inputs=False)
    reshape = graph_matcher.OpTypePattern(
        'Reshape', inputs=[outer_stack, 'Const'], ordered_inputs=False)
    consumers = (
        graph_matcher.OpTypePattern(
            'Add|AddV2|Max|Mul', inputs=[reshape, '*'],
            ordered_inputs=False),
        graph_matcher.OpTypePattern(
            'StridedSlice', inputs=[reshape, '*', '*', '*'],
            ordered_inputs=False),
    )

    def _replace(consumer):
      """Rewires every match of `consumer`; returns the match count."""
      replaced = 0
      for match in graph_matcher.GraphMatcher(consumer).match_graph(
          tf.get_default_graph()):
        replaced += 1
        source_op = match.get_op(source_pattern)
        reshape_op = match.get_op(reshape)
        consumer_op = match.get_op(consumer)
        resized = tf.image.resize_nearest_neighbor(
            source_op.outputs[0],
            reshape_op.outputs[0].shape.dims[1:3],
            align_corners=False,
            name=os.path.split(reshape_op.name)[0] + '/resize_nearest_neighbor')
        stale_output = reshape_op.outputs[0]
        # Only the first consumer input fed by the Reshape is rewired here.
        for position, tensor in enumerate(consumer_op.inputs):
          if tensor == stale_output:
            consumer_op._update_input(position, resized)  # pylint: disable=protected-access
            break
      return replaced

    found = sum(_replace(consumer) for consumer in consumers)
    tf.logging.info('Found and fixed {} matches'.format(found))
    return found

  # One sweep may not catch everything (e.g. both inputs of an Add can come
  # from the pattern), so keep sweeping until a sweep finds nothing.
  successful_passes = 0
  while _rewrite_pass():
    successful_passes += 1
    # This number is chosen based on the nas-fpn architecture.
    if successful_passes > 4:
      raise ValueError('Graph removal encountered a infinite loop.')
150
151


Vivek Rathod's avatar
Vivek Rathod committed
152
153
def replace_variable_values_with_moving_averages(graph,
                                                 current_checkpoint_file,
                                                 new_checkpoint_file,
                                                 no_ema_collection=None):
  """Writes a checkpoint whose variables hold their moving-average values.

  The shadow (moving average) values stored in `current_checkpoint_file` are
  restored into the corresponding variables of `graph`. All of the graph's
  variables are then saved to `new_checkpoint_file`.

  Args:
    graph: a tf.Graph object whose variables receive the averaged values.
    current_checkpoint_file: a checkpoint containing both original variables and
      their moving averages.
    new_checkpoint_file: file path to write a new checkpoint.
    no_ema_collection: A list of namescope substrings to match the variables
      to eliminate EMA.
  """
  with graph.as_default():
    # Only the name mapping from variables_to_restore() is used, so the decay
    # passed to the EMA object has no effect here.
    restore_map = tf.train.ExponentialMovingAverage(0.0).variables_to_restore()
    restore_map = config_util.remove_unecessary_ema(restore_map,
                                                    no_ema_collection)
    with tf.Session() as session:
      tf.train.Saver(restore_map).restore(session, current_checkpoint_file)
      tf.train.Saver().save(session, new_checkpoint_file)


def _image_tensor_input_placeholder(input_shape=None):
  """Creates the uint8 `image_tensor` placeholder.

  Args:
    input_shape: optional static shape of the placeholder; when None it is
      (None, None, None, 3).

  Returns:
    A (placeholder, input_tensor) pair. Both elements are the same uint8
    placeholder, since raw image input needs no decoding.
  """
  shape = (None, None, None, 3) if input_shape is None else input_shape
  placeholder = tf.placeholder(
      dtype=tf.uint8, shape=shape, name='image_tensor')
  return placeholder, placeholder
189

190

191
192
193
194
195
196
197
198
def _side_input_tensor_placeholder(side_input_shape, side_input_name,
                                   side_input_type):
  """Creates the placeholder for a single side input.

  Args:
    side_input_shape: static shape of the placeholder.
    side_input_name: name of the placeholder op.
    side_input_type: dtype of the placeholder.

  Returns:
    A (placeholder, side_input_tensor) pair; both elements are the same
    placeholder.
  """
  placeholder = tf.placeholder(
      side_input_type, side_input_shape, side_input_name)
  return placeholder, placeholder


199
def _tf_example_input_placeholder(input_shape=None):
  """Returns input that accepts a batch of strings with tf examples.

  Args:
    input_shape: the shape to resize the output decoded images to (optional).
      Only elements [1:3] are used as the resize target (presumably height
      and width of an NHWC shape -- TODO confirm against callers).

  Returns:
    a tuple of input placeholder and the output decoded images (uint8).
  """
  batch_tf_example_placeholder = tf.placeholder(
      tf.string, shape=[None], name='tf_example')

  def decode(tf_example_string_tensor):
    tensor_dict = tf_example_decoder.TfExampleDecoder().decode(
        tf_example_string_tensor)
    image_tensor = tensor_dict[fields.InputDataFields.image]
    if input_shape is not None:
      image_tensor = tf.image.resize(image_tensor, input_shape[1:3])
      # With the default (bilinear) method resize returns float32, but the
      # map_fn below declares uint8 outputs; cast back so the dtypes agree.
      image_tensor = tf.cast(image_tensor, tf.uint8)
    return image_tensor

  return (batch_tf_example_placeholder,
          shape_utils.static_or_dynamic_map_fn(
              decode,
              elems=batch_tf_example_placeholder,
              dtype=tf.uint8,
              parallel_iterations=32,
              back_prop=False))
224
225


226
def _encoded_image_string_tensor_input_placeholder(input_shape=None):
  """Returns input that accepts a batch of PNG or JPEG strings.

  Args:
    input_shape: the shape to resize the output decoded images to (optional).
      Only elements [1:3] are used as the resize target (presumably height
      and width of an NHWC shape -- TODO confirm against callers).

  Returns:
    a tuple of input placeholder and the output decoded images (uint8).
  """
  batch_image_str_placeholder = tf.placeholder(
      dtype=tf.string,
      shape=[None],
      name='encoded_image_string_tensor')

  def decode(encoded_image_string_tensor):
    image_tensor = tf.image.decode_image(encoded_image_string_tensor,
                                         channels=3)
    # Pin a rank-3, 3-channel static shape on the decoded image.
    image_tensor.set_shape((None, None, 3))
    if input_shape is not None:
      image_tensor = tf.image.resize(image_tensor, input_shape[1:3])
      # With the default (bilinear) method resize returns float32, but the
      # map_fn below declares uint8 outputs; cast back so the dtypes agree.
      image_tensor = tf.cast(image_tensor, tf.uint8)
    return image_tensor

  return (batch_image_str_placeholder,
          tf.map_fn(
              decode,
              elems=batch_image_str_placeholder,
              dtype=tf.uint8,
              parallel_iterations=32,
              back_prop=False))
253
254


255
# Maps each supported `input_type` to the function that builds its input
# placeholder. Each function returns (placeholder, decoded_input_tensor).
input_placeholder_fn_map = dict(
    image_tensor=_image_tensor_input_placeholder,
    encoded_image_string_tensor=(
        _encoded_image_string_tensor_input_placeholder),
    tf_example=_tf_example_input_placeholder,
)


263
264
def add_output_tensor_nodes(postprocessed_tensors,
                            output_collection_name='inference_op'):
  """Adds output nodes for detection boxes and scores.

  Adds the following nodes for output tensors -
    * num_detections: float32 tensor of shape [batch_size].
    * detection_boxes: float32 tensor of shape [batch_size, num_boxes, 4]
      containing detected boxes.
    * detection_scores: float32 tensor of shape [batch_size, num_boxes]
      containing scores for the detected boxes.
    * detection_multiclass_scores: (Optional) float32 tensor of shape
      [batch_size, num_boxes, num_classes_with_background] for containing class
      score distribution for detected boxes including background if any.
    * detection_features: (Optional) float32 tensor of shape
      [batch, num_boxes, roi_height, roi_width, depth]
      containing classifier features
      for each detected box
    * detection_classes: float32 tensor of shape [batch_size, num_boxes]
      containing class predictions for the detected boxes. These are the
      postprocessed classes plus a label id offset of 1.
    * detection_keypoints: (Optional) float32 tensor of shape
      [batch_size, num_boxes, num_keypoints, 2] containing keypoints for each
      detection box.
    * detection_masks: (Optional) float32 tensor of shape
      [batch_size, num_boxes, mask_height, mask_width] containing masks for each
      detection box.
    * raw_detection_boxes, raw_detection_scores: (Optional) passed through
      unchanged when present in `postprocessed_tensors`.

  Each node is a tf.identity op named after its field.

  Args:
    postprocessed_tensors: a dictionary containing the following fields
      'detection_boxes': [batch, max_detections, 4]
      'detection_scores': [batch, max_detections]
      'detection_multiclass_scores': [batch, max_detections,
        num_classes_with_background]
      'detection_features': [batch, num_boxes, roi_height, roi_width, depth]
      'detection_classes': [batch, max_detections]
      'detection_masks': [batch, max_detections, mask_height, mask_width]
        (optional).
      'detection_keypoints': [batch, max_detections, num_keypoints, 2]
        (optional).
      'num_detections': [batch]
    output_collection_name: Name of collection to add output tensors to. If
      None, the output tensors are not added to any collection.

  Returns:
    A tensor dict containing the added output tensor nodes, keyed by field
    name. Keys are inserted in this order: boxes, scores, multiclass scores,
    features, classes, num_detections, raw boxes, raw scores, keypoints,
    masks. Optional entries are omitted when absent.

  Raises:
    TypeError: if 'detection_classes' is missing (None + offset).
    Errors from tf.identity are propagated if another required tensor
    (boxes, scores, num_detections) is missing.
  """
  detection_fields = fields.DetectionResultFields
  label_id_offset = 1
  classes = postprocessed_tensors.get(
      detection_fields.detection_classes) + label_id_offset

  # Shallow copy so the caller's dict is not modified by the class override.
  tensors_to_export = dict(postprocessed_tensors)
  tensors_to_export[detection_fields.detection_classes] = classes

  # (field name, required) in the order the identity nodes are created. The
  # order also fixes the key order of the returned dict, which callers use to
  # build output node names and signatures.
  ordered_fields = (
      (detection_fields.detection_boxes, True),
      (detection_fields.detection_scores, True),
      (detection_fields.detection_multiclass_scores, False),
      (detection_fields.detection_features, False),
      (detection_fields.detection_classes, True),
      (detection_fields.num_detections, True),
      (detection_fields.raw_detection_boxes, False),
      (detection_fields.raw_detection_scores, False),
      (detection_fields.detection_keypoints, False),
      (detection_fields.detection_masks, False),
  )

  outputs = {}
  for field_name, required in ordered_fields:
    tensor = tensors_to_export.get(field_name)
    if tensor is None and not required:
      continue
    # Required fields are always wrapped, so a missing one surfaces as an
    # error from tf.identity rather than being silently dropped.
    outputs[field_name] = tf.identity(tensor, name=field_name)

  if output_collection_name is not None:
    for output_tensor in outputs.values():
      tf.add_to_collection(output_collection_name, output_tensor)

  return outputs
354
355


356
357
358
359
def write_saved_model(saved_model_path,
                      frozen_graph_def,
                      inputs,
                      outputs):
  """Writes SavedModel to disk.

  Imports `frozen_graph_def` into a fresh graph and session and writes it as
  a SavedModel. The model has a single meta graph tagged SERVING, with one
  PREDICT signature under the default serving signature key. No checkpoint
  is read here. `frozen_graph_def` is expected to already have its variables
  converted to constants (e.g. by freeze_graph).

  Args:
    saved_model_path: Path to write SavedModel.
    frozen_graph_def: tf.GraphDef holding frozen graph.
    inputs: Either a dict mapping signature input names to tensors, or a
      single tensor, which is exposed under the signature input name 'inputs'.
    outputs: A dict mapping signature output names to tensors.
  """
  with tf.Graph().as_default():
    with tf.Session() as sess:
      # Import with name='' so node names are preserved. The signature below
      # is built from tensors of the exporting graph; build_tensor_info only
      # records their names/dtypes/shapes, which resolve against this graph.
      tf.import_graph_def(frozen_graph_def, name='')

      builder = tf.saved_model.builder.SavedModelBuilder(saved_model_path)

      if isinstance(inputs, dict):
        tensor_info_inputs = {
            k: tf.saved_model.utils.build_tensor_info(v)
            for k, v in inputs.items()
        }
      else:
        tensor_info_inputs = {
            'inputs': tf.saved_model.utils.build_tensor_info(inputs)
        }

      tensor_info_outputs = {
          k: tf.saved_model.utils.build_tensor_info(v)
          for k, v in outputs.items()
      }

      detection_signature = (
          tf.saved_model.signature_def_utils.build_signature_def(
              inputs=tensor_info_inputs,
              outputs=tensor_info_outputs,
              method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
          ))

      builder.add_meta_graph_and_variables(
          sess,
          [tf.saved_model.tag_constants.SERVING],
          signature_def_map={
              tf.saved_model.signature_constants
              .DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                  detection_signature,
          },
      )
      builder.save()


411
412
413
414
415
def write_graph_and_checkpoint(inference_graph_def,
                               model_path,
                               input_saver_def,
                               trained_checkpoint_prefix):
  """Writes the graph and the checkpoint into disk.

  Clears the device field of every node in `inference_graph_def` (the proto
  is modified in place) and imports it into a new graph. It then restores the
  variables from `trained_checkpoint_prefix` using a saver built from
  `input_saver_def`, and saves them to `model_path` with relative paths.

  Args:
    inference_graph_def: tf.GraphDef of the inference graph.
    model_path: checkpoint path to write.
    input_saver_def: SaverDef describing the saver to rebuild.
    trained_checkpoint_prefix: checkpoint to restore variables from.
  """
  for graph_node in inference_graph_def.node:
    graph_node.device = ''
  with tf.Graph().as_default():
    tf.import_graph_def(inference_graph_def, name='')
    with tf.Session() as session:
      checkpoint_saver = tf.train.Saver(saver_def=input_saver_def,
                                        save_relative_paths=True)
      checkpoint_saver.restore(session, trained_checkpoint_prefix)
      checkpoint_saver.save(session, model_path)


427
def _get_outputs_from_inputs(input_tensors, detection_model,
                             output_collection_name, **side_inputs):
  """Runs the model's preprocess/predict/postprocess and adds output nodes.

  Args:
    input_tensors: batch of input images; cast to float32 before
      preprocessing.
    detection_model: model providing preprocess, predict and postprocess.
    output_collection_name: forwarded to add_output_tensor_nodes.
    **side_inputs: extra keyword tensors forwarded to detection_model.predict.

  Returns:
    The dict of output tensors returned by add_output_tensor_nodes.
  """
  float_inputs = tf.cast(input_tensors, dtype=tf.float32)
  preprocessed, true_image_shapes = detection_model.preprocess(float_inputs)
  predictions = detection_model.predict(preprocessed, true_image_shapes,
                                        **side_inputs)
  detections = detection_model.postprocess(predictions, true_image_shapes)
  return add_output_tensor_nodes(detections, output_collection_name)
437
438


439
def build_detection_graph(input_type, detection_model, input_shape,
                          output_collection_name, graph_hook_fn,
                          use_side_inputs=False, side_input_shapes=None,
                          side_input_names=None, side_input_types=None):
  """Build the detection graph.

  Ops are created in the current default graph: the input placeholder, any
  side-input placeholders, the model outputs and a global step.

  Args:
    input_type: a key of `input_placeholder_fn_map`.
    detection_model: model whose preprocess/predict/postprocess are called.
    input_shape: optional fixed input shape passed to the placeholder fn.
    output_collection_name: collection the output tensors are added to.
    graph_hook_fn: optional zero-argument callable run after the graph is
      built.
    use_side_inputs: whether to create side-input placeholders.
    side_input_shapes: list of side input shapes.
    side_input_names: list of side input names.
    side_input_types: list of side input dtypes. The three side_input_* lists
      are required, and must have equal lengths, when use_side_inputs is
      True.

  Returns:
    (outputs, placeholder_tensors): the dict of output tensors, and a dict
    mapping 'inputs' plus each side input name to its placeholder.

  Raises:
    ValueError: on an unknown input_type, an input_shape for an unsupported
      input type, or missing/mismatched side input lists.
  """
  if input_type not in input_placeholder_fn_map:
    raise ValueError('Unknown input type: {}'.format(input_type))
  placeholder_args = {}
  if input_shape is not None:
    if input_type not in ('image_tensor', 'encoded_image_string_tensor',
                          'tf_example', 'tf_sequence_example'):
      raise ValueError('Can only specify input shape for `image_tensor`, '
                       '`encoded_image_string_tensor`, `tf_example`, '
                       ' or `tf_sequence_example` inputs.')
    placeholder_args['input_shape'] = input_shape
  placeholder_tensor, input_tensors = input_placeholder_fn_map[input_type](
      **placeholder_args)

  placeholder_tensors = {'inputs': placeholder_tensor}
  side_inputs = {}
  if use_side_inputs:
    if (side_input_shapes is None or side_input_names is None or
        side_input_types is None):
      raise ValueError('side_input_shapes, side_input_names and '
                       'side_input_types are required when use_side_inputs '
                       'is True.')
    if not (len(side_input_shapes) == len(side_input_names) ==
            len(side_input_types)):
      raise ValueError(
          'side_input_shapes, side_input_names and side_input_types must '
          'have the same length; got {}, {} and {}.'.format(
              len(side_input_shapes), len(side_input_names),
              len(side_input_types)))
    for side_input_name, side_input_shape, side_input_type in zip(
        side_input_names, side_input_shapes, side_input_types):
      side_input_placeholder, side_input = _side_input_tensor_placeholder(
          side_input_shape, side_input_name, side_input_type)
      tf.logging.info('Created side input tensor: %s', side_input)
      side_inputs[side_input_name] = side_input
      placeholder_tensors[side_input_name] = side_input_placeholder

  outputs = _get_outputs_from_inputs(
      input_tensors=input_tensors,
      detection_model=detection_model,
      output_collection_name=output_collection_name,
      **side_inputs)

  # Add global step to the graph.
  slim.get_or_create_global_step()

  if graph_hook_fn:
    graph_hook_fn()

  return outputs, placeholder_tensors
479
480


481
482
483
def _export_inference_graph(input_type,
                            detection_model,
                            use_moving_averages,
                            trained_checkpoint_prefix,
                            output_directory,
                            additional_output_tensor_names=None,
                            input_shape=None,
                            output_collection_name='inference_op',
                            graph_hook_fn=None,
                            write_inference_graph=False,
                            temp_checkpoint_prefix='',
                            use_side_inputs=False,
                            side_input_shapes=None,
                            side_input_names=None,
                            side_input_types=None):
  """Export helper.

  Builds the detection graph in the current default graph and writes the
  following under `output_directory` (created if missing):
    * model.ckpt: checkpoint of the (optionally moving-averaged) variables.
    * inference_graph.pbtxt: text GraphDef, only if `write_inference_graph`.
    * frozen_inference_graph.pb: the frozen graph.
    * saved_model/: SavedModel built from the frozen graph.

  Args:
    input_type: a key of `input_placeholder_fn_map`.
    detection_model: model whose graph is built and exported.
    use_moving_averages: if True, variable values are first replaced by their
      moving averages read from `trained_checkpoint_prefix`.
    trained_checkpoint_prefix: checkpoint to export.
    output_directory: directory to write outputs to.
    additional_output_tensor_names: optional iterable of extra node names to
      keep in the frozen graph.
    input_shape: optional fixed input shape (see build_detection_graph).
    output_collection_name: collection the output tensors are added to.
    graph_hook_fn: optional callable run after the graph is built.
    write_inference_graph: whether to also write inference_graph.pbtxt.
    temp_checkpoint_prefix: where to write the moving-average checkpoint. If
      empty, a private temporary directory is created for it and removed
      once the export has finished.
    use_side_inputs: forwarded to build_detection_graph.
    side_input_shapes: forwarded to build_detection_graph.
    side_input_names: forwarded to build_detection_graph.
    side_input_types: forwarded to build_detection_graph.
  """
  import shutil  # pylint: disable=g-import-not-at-top

  tf.gfile.MakeDirs(output_directory)
  frozen_graph_path = os.path.join(output_directory,
                                   'frozen_inference_graph.pb')
  saved_model_path = os.path.join(output_directory, 'saved_model')
  model_path = os.path.join(output_directory, 'model.ckpt')

  outputs, placeholder_tensor_dict = build_detection_graph(
      input_type=input_type,
      detection_model=detection_model,
      input_shape=input_shape,
      output_collection_name=output_collection_name,
      graph_hook_fn=graph_hook_fn,
      use_side_inputs=use_side_inputs,
      side_input_shapes=side_input_shapes,
      side_input_names=side_input_names,
      side_input_types=side_input_types)

  profile_inference_graph(tf.get_default_graph())
  saver_kwargs = {}
  # Temporary directory created (and therefore owned) by this call, if any.
  owned_temp_dir = None
  try:
    if use_moving_averages:
      if not temp_checkpoint_prefix:
        # This check is to be compatible with both version of SaverDef.
        if os.path.isfile(trained_checkpoint_prefix):
          saver_kwargs['write_version'] = saver_pb2.SaverDef.V1
        # Write the temporary checkpoint inside a fresh private directory.
        # Using a temp path directly as the prefix would put the checkpoint
        # files, and the saver's `checkpoint` state file, straight into the
        # shared temp directory, where they were never cleaned up.
        owned_temp_dir = tempfile.mkdtemp()
        temp_checkpoint_prefix = os.path.join(owned_temp_dir, 'model.ckpt')
      replace_variable_values_with_moving_averages(
          tf.get_default_graph(), trained_checkpoint_prefix,
          temp_checkpoint_prefix)
      checkpoint_to_use = temp_checkpoint_prefix
    else:
      checkpoint_to_use = trained_checkpoint_prefix

    saver = tf.train.Saver(**saver_kwargs)
    input_saver_def = saver.as_saver_def()

    write_graph_and_checkpoint(
        inference_graph_def=tf.get_default_graph().as_graph_def(),
        model_path=model_path,
        input_saver_def=input_saver_def,
        trained_checkpoint_prefix=checkpoint_to_use)

    if write_inference_graph:
      inference_graph_def = tf.get_default_graph().as_graph_def()
      inference_graph_path = os.path.join(output_directory,
                                          'inference_graph.pbtxt')
      for node in inference_graph_def.node:
        node.device = ''
      with tf.gfile.GFile(inference_graph_path, 'wb') as f:
        f.write(str(inference_graph_def))

    # Accept any iterable of extra names, not just a list.
    output_node_names = list(outputs.keys())
    if additional_output_tensor_names is not None:
      output_node_names.extend(additional_output_tensor_names)

    frozen_graph_def = freeze_graph.freeze_graph_with_def_protos(
        input_graph_def=tf.get_default_graph().as_graph_def(),
        input_saver_def=input_saver_def,
        input_checkpoint=checkpoint_to_use,
        output_node_names=','.join(output_node_names),
        restore_op_name='save/restore_all',
        filename_tensor_name='save/Const:0',
        output_graph=frozen_graph_path,
        clear_devices=True,
        initializer_nodes='')
  finally:
    # The temporary checkpoint is not needed once the graph is frozen.
    if owned_temp_dir is not None:
      shutil.rmtree(owned_temp_dir, ignore_errors=True)

  write_saved_model(saved_model_path, frozen_graph_def,
                    placeholder_tensor_dict, outputs)
567
568


569
570
571
572
def export_inference_graph(input_type,
                           pipeline_config,
                           trained_checkpoint_prefix,
                           output_directory,
                           input_shape=None,
                           output_collection_name='inference_op',
                           additional_output_tensor_names=None,
                           write_inference_graph=False,
                           use_side_inputs=False,
                           side_input_shapes=None,
                           side_input_names=None,
                           side_input_types=None):
  """Exports inference graph for the model specified in the pipeline config.

  The exported artifacts and a copy of the pipeline config are written to
  `output_directory`. In the saved config, `eval_config.use_moving_averages`
  is set to False. The caller's `pipeline_config` object is not modified.

  Args:
    input_type: Type of input for the graph. Can be one of ['image_tensor',
      'encoded_image_string_tensor', 'tf_example'].
    pipeline_config: pipeline_pb2.TrainAndEvalPipelineConfig proto.
    trained_checkpoint_prefix: Path to the trained checkpoint file.
    output_directory: Path to write outputs.
    input_shape: Sets a fixed shape for the input. For `image_tensor` it is
      the placeholder shape (default [None, None, None, 3]). For the
      string-based input types only its [1:3] elements are used, as the
      size decoded images are resized to.
    output_collection_name: Name of collection to add output tensors to.
      If None, does not add output tensors to a collection.
    additional_output_tensor_names: list of additional output
      tensors to include in the frozen graph.
    write_inference_graph: If true, writes inference graph to disk.
    use_side_inputs: If True, the model requires side_inputs.
    side_input_shapes: List of shapes of the side input tensors,
      required if use_side_inputs is True.
    side_input_names: List of names of the side input tensors,
      required if use_side_inputs is True.
    side_input_types: List of types of the side input tensors,
      required if use_side_inputs is True.
  """
  import copy  # pylint: disable=g-import-not-at-top

  detection_model = model_builder.build(pipeline_config.model,
                                        is_training=False)
  graph_rewriter_fn = None
  if pipeline_config.HasField('graph_rewriter'):
    graph_rewriter_config = pipeline_config.graph_rewriter
    graph_rewriter_fn = graph_rewriter_builder.build(graph_rewriter_config,
                                                     is_training=False)
  _export_inference_graph(
      input_type,
      detection_model,
      pipeline_config.eval_config.use_moving_averages,
      trained_checkpoint_prefix,
      output_directory,
      additional_output_tensor_names,
      input_shape,
      output_collection_name,
      graph_hook_fn=graph_rewriter_fn,
      write_inference_graph=write_inference_graph,
      use_side_inputs=use_side_inputs,
      side_input_shapes=side_input_shapes,
      side_input_names=side_input_names,
      side_input_types=side_input_types)
  # When use_moving_averages is set, _export_inference_graph has already
  # written the averaged values as the variables, so the saved config turns
  # the flag off. Edit a copy so the caller's config is left untouched.
  exported_config = copy.deepcopy(pipeline_config)
  exported_config.eval_config.use_moving_averages = False
  config_util.save_pipeline_config(exported_config, output_directory)
628
629
630
631
632
633
634
635
636
637
638
639
640
641


def profile_inference_graph(graph):
  """Profiles the inference graph.

  Prints model parameters and computation FLOPs given an inference graph.
  BatchNorms are excluded from the parameter count due to the fact that
  BatchNorms are usually folded. BatchNorm, Initializer, Regularizer
  and BiasAdd are not considered in FLOP count.

  Args:
    graph: the inference graph.
  """
  model_analyzer = contrib_tfprof.model_analyzer
  # Work on copies: these option dicts are module-level attributes of
  # model_analyzer. Writing the trim regexes into them directly would change
  # the defaults for every later user in the same process.
  tfprof_vars_option = (
      model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS.copy())
  tfprof_flops_option = model_analyzer.FLOAT_OPS_OPTIONS.copy()

  # Batchnorm is usually folded during inference.
  tfprof_vars_option['trim_name_regexes'] = ['.*BatchNorm.*']
  # Initializer and Regularizer are only used in training.
  tfprof_flops_option['trim_name_regexes'] = [
      '.*BatchNorm.*', '.*Initializer.*', '.*Regularizer.*', '.*BiasAdd.*'
  ]

  model_analyzer.print_model_analysis(
      graph, tfprof_options=tfprof_vars_option)
  model_analyzer.print_model_analysis(
      graph, tfprof_options=tfprof_flops_option)