exporter.py 14.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Functions to export object detection inference graph."""
import logging
import os
import tensorflow as tf
20
from tensorflow.core.protobuf import rewriter_config_pb2
21
22
23
24
25
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
from tensorflow.python.platform import gfile
26
from tensorflow.python.saved_model import signature_constants
27
28
29
30
31
32
33
34
from tensorflow.python.training import saver as saver_lib
from object_detection.builders import model_builder
from object_detection.core import standard_fields as fields
from object_detection.data_decoders import tf_example_decoder

# Module-level alias for tf.contrib.slim.
slim = tf.contrib.slim


35
36
# TODO: Replace with freeze_graph.freeze_graph_with_def_protos once a newer
# version of TensorFlow is in common use.
37
38
39
40
41
42
43
44
45
def freeze_graph_with_def_protos(
    input_graph_def,
    input_saver_def,
    input_checkpoint,
    output_node_names,
    restore_op_name,
    filename_tensor_name,
    clear_devices,
    initializer_nodes,
    optimize_graph=False,
    variable_names_blacklist=''):
  """Restores a checkpoint into a graph and turns its variables into constants.

  Args:
    input_graph_def: GraphDef to freeze. The device field of every node is
      cleared in place when `clear_devices` is true.
    input_saver_def: SaverDef used to restore the checkpoint. If falsy, a Saver
      is built from the checkpoint variables that also exist in the graph.
    input_checkpoint: Checkpoint path; may be a prefix for the Saver V2 format.
    output_node_names: Comma-separated string of output node names.
    restore_op_name: Unused.
    filename_tensor_name: Unused.
    clear_devices: Whether to strip the device placement from every node.
    initializer_nodes: Fetches passed to `sess.run` after restoring. Only used
      when `input_saver_def` is falsy.
    optimize_graph: Whether to enable the 'pruning', 'constfold' and 'layout'
      graph rewriter passes in the session used for freezing.
    variable_names_blacklist: Comma-separated string of variable names that is
      forwarded, as a list, to `convert_variables_to_constants`.

  Returns:
    The GraphDef produced by `graph_util.convert_variables_to_constants`.

  Raises:
    ValueError: If the checkpoint does not exist or `output_node_names` is
      empty.
  """
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  # `input_checkpoint` may only be a prefix when the Saver V2 format is used.
  checkpoint_found = saver_lib.checkpoint_exists(input_checkpoint)
  if not checkpoint_found:
    raise ValueError(
        'Input checkpoint "' + input_checkpoint + '" does not exist!')

  if not output_node_names:
    raise ValueError(
        'You must supply the name of a node to --output_node_names.')

  # Dropping explicit device placements makes the frozen graph more portable.
  if clear_devices:
    for graph_node in input_graph_def.node:
      graph_node.device = ''

  with tf.Graph().as_default():
    tf.import_graph_def(input_graph_def, name='')

    if optimize_graph:
      logging.info('Graph Rewriter optimizations enabled')
      rewrite_options = rewriter_config_pb2.RewriterConfig(
          optimize_tensor_layout=True)
      for optimizer_name in ('pruning', 'constfold', 'layout'):
        rewrite_options.optimizers.append(optimizer_name)
      graph_options = tf.GraphOptions(
          rewrite_options=rewrite_options, infer_shapes=True)
    else:
      logging.info('Graph Rewriter optimizations disabled')
      graph_options = tf.GraphOptions()

    session_config = tf.ConfigProto(graph_options=graph_options)
    with session.Session(config=session_config) as sess:
      if input_saver_def:
        restorer = saver_lib.Saver(saver_def=input_saver_def)
        restorer.restore(sess, input_checkpoint)
      else:
        # No SaverDef given: restore every checkpoint variable that has a
        # matching tensor in the imported graph.
        ckpt_reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
        restorable = {}
        for var_name in ckpt_reader.get_variable_to_shape_map():
          try:
            restorable[var_name] = sess.graph.get_tensor_by_name(
                var_name + ':0')
          except KeyError:
            # This tensor doesn't exist in the graph (for example it's
            # 'global_step' or a similar housekeeping element) so skip it.
            continue
        restorer = saver_lib.Saver(var_list=restorable)
        restorer.restore(sess, input_checkpoint)
        if initializer_nodes:
          sess.run(initializer_nodes)

      blacklist = None
      if variable_names_blacklist:
        blacklist = variable_names_blacklist.split(',')
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_graph_def,
          output_node_names.split(','),
          variable_names_blacklist=blacklist)

  return output_graph_def


114
115

def _image_tensor_input_placeholder():
  """Creates the 'image_tensor' uint8 placeholder for a batch of images.

  Returns:
    A uint8 placeholder of shape (None, None, None, 3) named 'image_tensor'.
  """
  image_batch_shape = (None, None, None, 3)
  return tf.placeholder(
      dtype=tf.uint8, shape=image_batch_shape, name='image_tensor')

121

122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
def _tf_example_input_placeholder():
  """Creates an input node fed by a batch of serialized tf.Example strings.

  Returns:
    A uint8 tensor obtained by decoding, element by element, the image field
    of every string fed to the 'tf_example' placeholder.
  """
  def _decode_image(serialized_example):
    """Decodes one serialized example and returns only its image tensor."""
    decoded_fields = tf_example_decoder.TfExampleDecoder().decode(
        serialized_example)
    return decoded_fields[fields.InputDataFields.image]

  serialized_examples = tf.placeholder(
      dtype=tf.string, shape=[None], name='tf_example')
  return tf.map_fn(
      _decode_image,
      elems=serialized_examples,
      dtype=tf.uint8,
      parallel_iterations=32,
      back_prop=False)


138
def _encoded_image_string_tensor_input_placeholder():
  """Creates an input node fed by a batch of encoded (PNG or JPEG) images.

  Returns:
    A uint8 tensor obtained by decoding, element by element, every string fed
    to the 'encoded_image_string_tensor' placeholder into a 3-channel image.
  """
  def _decode_image(encoded_image):
    """Decodes one encoded image string into a (None, None, 3) tensor."""
    decoded = tf.image.decode_image(encoded_image, channels=3)
    decoded.set_shape((None, None, 3))
    return decoded

  encoded_images = tf.placeholder(
      dtype=tf.string, shape=[None], name='encoded_image_string_tensor')
  return tf.map_fn(
      _decode_image,
      elems=encoded_images,
      dtype=tf.uint8,
      parallel_iterations=32,
      back_prop=False)
154
155


156
# Maps every supported `input_type` name to the zero-argument function that
# builds the matching input node.
input_placeholder_fn_map = {
    'image_tensor':
        _image_tensor_input_placeholder,
    'encoded_image_string_tensor':
        _encoded_image_string_tensor_input_placeholder,
    'tf_example':
        _tf_example_input_placeholder,
}


164
165
def _add_output_tensor_nodes(postprocessed_tensors,
                             output_collection_name='inference_op'):
  """Adds output nodes for detection boxes and scores.

  Adds the following nodes for output tensors -
    * num_detections: float32 tensor of shape [batch_size].
    * detection_boxes: float32 tensor of shape [batch_size, num_boxes, 4]
      containing detected boxes.
    * detection_scores: float32 tensor of shape [batch_size, num_boxes]
      containing scores for the detected boxes.
    * detection_classes: float32 tensor of shape [batch_size, num_boxes]
      containing class predictions for the detected boxes.
    * detection_masks: (Optional) float32 tensor of shape
      [batch_size, num_boxes, mask_height, mask_width] containing masks for each
      detection box.

  Args:
    postprocessed_tensors: a dictionary containing the following fields
      'detection_boxes': [batch, max_detections, 4]
      'detection_scores': [batch, max_detections]
      'detection_classes': [batch, max_detections]
      'detection_masks': [batch, max_detections, mask_height, mask_width]
        (optional).
      'num_detections': [batch]
    output_collection_name: Name of collection to add output tensors to.
      If None, the output tensors are not added to any collection.

  Returns:
    A tensor dict containing the added output tensor nodes.
  """
  # Exported class ids are shifted by one relative to the model's output.
  label_id_offset = 1
  boxes = postprocessed_tensors.get('detection_boxes')
  scores = postprocessed_tensors.get('detection_scores')
  classes = postprocessed_tensors.get('detection_classes') + label_id_offset
  masks = postprocessed_tensors.get('detection_masks')
  num_detections = postprocessed_tensors.get('num_detections')

  outputs = {}
  outputs['detection_boxes'] = tf.identity(boxes, name='detection_boxes')
  outputs['detection_scores'] = tf.identity(scores, name='detection_scores')
  outputs['detection_classes'] = tf.identity(classes, name='detection_classes')
  outputs['num_detections'] = tf.identity(num_detections, name='num_detections')
  if masks is not None:
    outputs['detection_masks'] = tf.identity(masks, name='detection_masks')

  # Every output (masks included, when present) is registered exactly once;
  # a separate second registration of the masks would duplicate the entry.
  if output_collection_name is not None:
    for output_key in outputs:
      tf.add_to_collection(output_collection_name, outputs[output_key])
  return outputs
211
212


213
214
def _write_frozen_graph(frozen_graph_path, frozen_graph_def):
  """Serializes a frozen graph to disk and logs its node count.

  Args:
    frozen_graph_path: Destination path of the serialized inference graph.
    frozen_graph_def: tf.GraphDef holding the frozen graph.
  """
  with gfile.GFile(frozen_graph_path, 'wb') as graph_file:
    serialized_graph = frozen_graph_def.SerializeToString()
    graph_file.write(serialized_graph)
  node_count = len(frozen_graph_def.node)
  logging.info('%d ops in the final graph.', node_count)


def _write_saved_model(saved_model_path,
                       frozen_graph_def,
                       inputs,
                       outputs):
  """Writes SavedModel to disk.

  Imports the frozen graph into a fresh graph and session, then saves it with
  a single predict signature registered under the default serving signature
  key. The signature maps 'inputs' to `inputs` and every key of `outputs` to
  its tensor.

  Args:
    saved_model_path: Path to write SavedModel.
    frozen_graph_def: tf.GraphDef holding frozen graph.
    inputs: The input image tensor to use for detection.
    outputs: A tensor dictionary containing the outputs of a DetectionModel.
  """
  with tf.Graph().as_default():
    with session.Session() as sess:

      tf.import_graph_def(frozen_graph_def, name='')

      builder = tf.saved_model.builder.SavedModelBuilder(saved_model_path)

      tensor_info_inputs = {
          'inputs': tf.saved_model.utils.build_tensor_info(inputs)}
      tensor_info_outputs = {}
      for k, v in outputs.items():
        tensor_info_outputs[k] = tf.saved_model.utils.build_tensor_info(v)

      detection_signature = (
          tf.saved_model.signature_def_utils.build_signature_def(
              inputs=tensor_info_inputs,
              outputs=tensor_info_outputs,
              method_name=signature_constants.PREDICT_METHOD_NAME))

      # The key must be the constant's value, not the constant's name as a
      # quoted string, so the signature is found under the default serving key.
      builder.add_meta_graph_and_variables(
          sess, [tf.saved_model.tag_constants.SERVING],
          signature_def_map={
              signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                  detection_signature,
          },
      )
      builder.save()


272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
def _write_graph_and_checkpoint(inference_graph_def,
                                model_path,
                                input_saver_def,
                                trained_checkpoint_prefix):
  """Restores a trained checkpoint into the inference graph and re-saves it.

  Args:
    inference_graph_def: GraphDef of the inference graph. The device field of
      every node is cleared in place.
    model_path: Path prefix passed to `Saver.save` for the new checkpoint.
    input_saver_def: SaverDef used to build the Saver that restores and saves.
    trained_checkpoint_prefix: Checkpoint to restore the variables from.
  """
  for graph_node in inference_graph_def.node:
    graph_node.device = ''

  with tf.Graph().as_default():
    tf.import_graph_def(inference_graph_def, name='')
    with session.Session() as sess:
      checkpoint_saver = saver_lib.Saver(
          saver_def=input_saver_def, save_relative_paths=True)
      checkpoint_saver.restore(sess, trained_checkpoint_prefix)
      checkpoint_saver.save(sess, model_path)


287
288
289
def _export_inference_graph(input_type,
                            detection_model,
                            use_moving_averages,
                            trained_checkpoint_prefix,
                            output_directory,
                            optimize_graph=False,
                            output_collection_name='inference_op'):
  """Builds the inference graph and writes all export artifacts.

  Writes, inside `output_directory`, a checkpoint ('model.ckpt'), a frozen
  graph ('frozen_inference_graph.pb') and a SavedModel ('saved_model').

  Args:
    input_type: Key of `input_placeholder_fn_map` selecting the input node.
    detection_model: Object whose `preprocess`, `predict` and `postprocess`
      methods are chained to build the detection graph.
    use_moving_averages: Whether to restore the moving-average versions of the
      variables instead of the variables themselves.
    trained_checkpoint_prefix: Checkpoint to restore the variables from.
    output_directory: Directory to write outputs to; created if needed.
    optimize_graph: Forwarded to `freeze_graph_with_def_protos`.
    output_collection_name: Name of collection to add output tensors to.

  Raises:
    ValueError: If `input_type` is not a key of `input_placeholder_fn_map`.
  """
  tf.gfile.MakeDirs(output_directory)
  model_path = os.path.join(output_directory, 'model.ckpt')
  frozen_graph_path = os.path.join(
      output_directory, 'frozen_inference_graph.pb')
  saved_model_path = os.path.join(output_directory, 'saved_model')

  if input_type not in input_placeholder_fn_map:
    raise ValueError('Unknown input type: {}'.format(input_type))
  build_input_node = input_placeholder_fn_map[input_type]
  # NOTE(review): `inputs` is the result of tf.to_float(...), not the raw
  # placeholder, and it is this tensor that ends up in the SavedModel
  # signature below -- confirm this is intended.
  inputs = tf.to_float(build_input_node())
  postprocessed_tensors = detection_model.postprocess(
      detection_model.predict(detection_model.preprocess(inputs)))
  outputs = _add_output_tensor_nodes(postprocessed_tensors,
                                     output_collection_name)

  if use_moving_averages:
    moving_averages = tf.train.ExponentialMovingAverage(0.0)
    saver = tf.train.Saver(moving_averages.variables_to_restore())
  else:
    saver = tf.train.Saver()
  input_saver_def = saver.as_saver_def()

  _write_graph_and_checkpoint(
      inference_graph_def=tf.get_default_graph().as_graph_def(),
      model_path=model_path,
      input_saver_def=input_saver_def,
      trained_checkpoint_prefix=trained_checkpoint_prefix)

  output_node_names = ','.join(outputs.keys())
  frozen_graph_def = freeze_graph_with_def_protos(
      input_graph_def=tf.get_default_graph().as_graph_def(),
      input_saver_def=input_saver_def,
      input_checkpoint=trained_checkpoint_prefix,
      output_node_names=output_node_names,
      restore_op_name='save/restore_all',
      filename_tensor_name='save/Const:0',
      clear_devices=True,
      optimize_graph=optimize_graph,
      initializer_nodes='')
  _write_frozen_graph(frozen_graph_path, frozen_graph_def)
  _write_saved_model(saved_model_path, frozen_graph_def, inputs, outputs)
337
338


339
340
341
342
def export_inference_graph(input_type,
                           pipeline_config,
                           trained_checkpoint_prefix,
                           output_directory,
                           optimize_graph=False,
                           output_collection_name='inference_op'):
  """Exports inference graph for the model specified in the pipeline config.

  Builds the detection model in inference mode from `pipeline_config.model`
  and delegates the export to `_export_inference_graph`.

  Args:
    input_type: Type of input for the graph. Can be one of [`image_tensor`,
      `encoded_image_string_tensor`, `tf_example`].
    pipeline_config: pipeline_pb2.TrainAndEvalPipelineConfig proto.
    trained_checkpoint_prefix: Path to the trained checkpoint file.
    output_directory: Path to write outputs.
    optimize_graph: Whether to optimize graph using Grappler.
    output_collection_name: Name of collection to add output tensors to.
      If None, does not add output tensors to a collection.
  """
  detection_model = model_builder.build(
      pipeline_config.model, is_training=False)
  use_moving_averages = pipeline_config.eval_config.use_moving_averages
  _export_inference_graph(
      input_type=input_type,
      detection_model=detection_model,
      use_moving_averages=use_moving_averages,
      trained_checkpoint_prefix=trained_checkpoint_prefix,
      output_directory=output_directory,
      optimize_graph=optimize_graph,
      output_collection_name=output_collection_name)