exporter.py 14.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Functions to export object detection inference graph."""
import logging
import os
import tensorflow as tf
20
from tensorflow.core.protobuf import rewriter_config_pb2
21
22
23
24
25
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
from tensorflow.python.platform import gfile
26
from tensorflow.python.saved_model import signature_constants
27
28
29
30
31
32
33
34
from tensorflow.python.training import saver as saver_lib
from object_detection.builders import model_builder
from object_detection.core import standard_fields as fields
from object_detection.data_decoders import tf_example_decoder

# Module-level alias for tf.contrib.slim.
slim = tf.contrib.slim


35
36
# TODO: Replace with freeze_graph.freeze_graph_with_def_protos when a
# newer version of TensorFlow becomes more common.
37
38
39
40
41
42
43
44
45
def freeze_graph_with_def_protos(
    input_graph_def,
    input_saver_def,
    input_checkpoint,
    output_node_names,
    restore_op_name,
    filename_tensor_name,
    clear_devices,
    initializer_nodes,
    optimize_graph=False,
    variable_names_blacklist=''):
  """Freezes a graph by turning its checkpointed variables into constants.

  Args:
    input_graph_def: GraphDef to freeze. When `clear_devices` is set, the
      device field of each of its nodes is cleared in place.
    input_saver_def: SaverDef used to restore the checkpoint. If falsy, a
      saver is built from the checkpoint variables that exist in the graph.
    input_checkpoint: Checkpoint path; may be a prefix for the Saver V2 format.
    output_node_names: Comma-separated string of output node names.
    restore_op_name: Unused; kept for signature compatibility.
    filename_tensor_name: Unused; kept for signature compatibility.
    clear_devices: Whether to strip explicit device assignments from the nodes.
    initializer_nodes: Ops to run after restoring. Only run when no
      `input_saver_def` is supplied.
    optimize_graph: Whether to enable the 'pruning', 'constfold' and 'layout'
      graph rewriter optimizers for the session.
    variable_names_blacklist: Comma-separated variable names, split and
      forwarded as `variable_names_blacklist` to
      graph_util.convert_variables_to_constants (None when empty).

  Returns:
    The GraphDef produced by graph_util.convert_variables_to_constants.

  Raises:
    ValueError: If the checkpoint does not exist or no output node names are
      given.
  """
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if not saver_lib.checkpoint_exists(input_checkpoint):
    missing_checkpoint_msg = (
        'Input checkpoint "' + input_checkpoint + '" does not exist!')
    raise ValueError(missing_checkpoint_msg)

  if not output_node_names:
    raise ValueError(
        'You must supply the name of a node to --output_node_names.')

  # Dropping explicit device placements makes the frozen graph more portable.
  if clear_devices:
    for graph_node in input_graph_def.node:
      graph_node.device = ''

  with tf.Graph().as_default():
    tf.import_graph_def(input_graph_def, name='')

    if not optimize_graph:
      logging.info('Graph Rewriter optimizations disabled')
      graph_options = tf.GraphOptions()
    else:
      logging.info('Graph Rewriter optimizations enabled')
      rewrite_options = rewriter_config_pb2.RewriterConfig(
          optimize_tensor_layout=True)
      rewrite_options.optimizers.extend(['pruning', 'constfold', 'layout'])
      graph_options = tf.GraphOptions(
          rewrite_options=rewrite_options, infer_shapes=True)

    session_config = tf.ConfigProto(graph_options=graph_options)
    with session.Session(config=session_config) as sess:
      if input_saver_def:
        restorer = saver_lib.Saver(saver_def=input_saver_def)
        restorer.restore(sess, input_checkpoint)
      else:
        # No SaverDef: restore every checkpoint variable that has a matching
        # tensor in the imported graph.
        restorable = {}
        checkpoint_reader = pywrap_tensorflow.NewCheckpointReader(
            input_checkpoint)
        for var_name in checkpoint_reader.get_variable_to_shape_map():
          try:
            graph_tensor = sess.graph.get_tensor_by_name(var_name + ':0')
          except KeyError:
            # This tensor doesn't exist in the graph (for example it's
            # 'global_step' or a similar housekeeping element) so skip it.
            continue
          else:
            restorable[var_name] = graph_tensor
        restorer = saver_lib.Saver(var_list=restorable)
        restorer.restore(sess, input_checkpoint)
        if initializer_nodes:
          sess.run(initializer_nodes)

      if variable_names_blacklist:
        blacklisted_names = variable_names_blacklist.split(',')
      else:
        blacklisted_names = None
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_graph_def,
          output_node_names.split(','),
          variable_names_blacklist=blacklisted_names)

  return output_graph_def


114
115

def _image_tensor_input_placeholder():
  """Creates the `image_tensor` placeholder for a batch of uint8 images.

  Returns:
    a tuple (placeholder, input tensor); both entries are the same uint8
    placeholder of shape [None, None, None, 3].
  """
  image_placeholder = tf.placeholder(
      dtype=tf.uint8, shape=(None, None, None, 3), name='image_tensor')
  return image_placeholder, image_placeholder
121

122

123
def _tf_example_input_placeholder():
  """Creates the `tf_example` string placeholder and its decoded images.

  Returns:
    a tuple of the string placeholder (shape [None]) and the uint8 tensor
    obtained by decoding the image of every element with TfExampleDecoder.
  """
  def _decode_image(serialized_example):
    decoded_fields = tf_example_decoder.TfExampleDecoder().decode(
        serialized_example)
    return decoded_fields[fields.InputDataFields.image]

  serialized_examples = tf.placeholder(
      tf.string, shape=[None], name='tf_example')
  # Decode element by element; no gradients are needed for export.
  decoded_images = tf.map_fn(
      _decode_image,
      elems=serialized_examples,
      dtype=tf.uint8,
      parallel_iterations=32,
      back_prop=False)
  return serialized_examples, decoded_images
142
143


144
def _encoded_image_string_tensor_input_placeholder():
  """Creates a string placeholder for encoded images and decodes them.

  The placeholder accepts a batch of encoded image strings (e.g. PNG or JPEG).

  Returns:
    a tuple of the string placeholder (shape [None]) and the uint8 tensor
    obtained by decoding every element into a 3-channel image.
  """
  def _decode_image(encoded_image):
    decoded = tf.image.decode_image(encoded_image, channels=3)
    # Record the static shape: height and width unknown, 3 channels.
    decoded.set_shape((None, None, 3))
    return decoded

  encoded_images = tf.placeholder(
      dtype=tf.string, shape=[None], name='encoded_image_string_tensor')
  decoded_images = tf.map_fn(
      _decode_image,
      elems=encoded_images,
      dtype=tf.uint8,
      parallel_iterations=32,
      back_prop=False)
  return encoded_images, decoded_images
166
167


168
# Maps each supported `input_type` name to the function that builds its
# (placeholder, input tensor) pair.
input_placeholder_fn_map = {
    'image_tensor': _image_tensor_input_placeholder,
    'encoded_image_string_tensor': (
        _encoded_image_string_tensor_input_placeholder),
    'tf_example': _tf_example_input_placeholder,
}


176
177
def _add_output_tensor_nodes(postprocessed_tensors,
                             output_collection_name='inference_op'):
  """Adds output nodes for detection boxes and scores.

  Adds the following nodes for output tensors -
    * num_detections: float32 tensor of shape [batch_size].
    * detection_boxes: float32 tensor of shape [batch_size, num_boxes, 4]
      containing detected boxes.
    * detection_scores: float32 tensor of shape [batch_size, num_boxes]
      containing scores for the detected boxes.
    * detection_classes: float32 tensor of shape [batch_size, num_boxes]
      containing class predictions for the detected boxes.
    * detection_masks: (Optional) float32 tensor of shape
      [batch_size, num_boxes, mask_height, mask_width] containing masks for each
      detection box.

  Args:
    postprocessed_tensors: a dictionary containing the following fields
      'detection_boxes': [batch, max_detections, 4]
      'detection_scores': [batch, max_detections]
      'detection_classes': [batch, max_detections]
      'detection_masks': [batch, max_detections, mask_height, mask_width]
        (optional).
      'num_detections': [batch]
    output_collection_name: Name of collection to add output tensors to.
      If None, the output tensors are not added to any collection.

  Returns:
    A tensor dict containing the added output tensor nodes.
  """
  # Exported class ids are the postprocessed class values shifted by this
  # offset.
  label_id_offset = 1
  boxes = postprocessed_tensors.get('detection_boxes')
  scores = postprocessed_tensors.get('detection_scores')
  classes = postprocessed_tensors.get('detection_classes') + label_id_offset
  masks = postprocessed_tensors.get('detection_masks')
  num_detections = postprocessed_tensors.get('num_detections')

  outputs = {}
  outputs['detection_boxes'] = tf.identity(boxes, name='detection_boxes')
  outputs['detection_scores'] = tf.identity(scores, name='detection_scores')
  outputs['detection_classes'] = tf.identity(classes, name='detection_classes')
  outputs['num_detections'] = tf.identity(num_detections, name='num_detections')
  if masks is not None:
    outputs['detection_masks'] = tf.identity(masks, name='detection_masks')

  # Each output tensor, including the optional masks, is registered exactly
  # once. A None collection name skips registration.
  if output_collection_name is not None:
    for output_tensor in outputs.values():
      tf.add_to_collection(output_collection_name, output_tensor)
  return outputs
223
224


225
226
def _write_frozen_graph(frozen_graph_path, frozen_graph_def):
  """Serializes a frozen graph to disk and logs its node count.

  Args:
    frozen_graph_path: Path to write inference graph.
    frozen_graph_def: tf.GraphDef holding frozen graph.
  """
  with gfile.GFile(frozen_graph_path, 'wb') as graph_file:
    graph_file.write(frozen_graph_def.SerializeToString())
  num_ops = len(frozen_graph_def.node)
  logging.info('%d ops in the final graph.', num_ops)


def _write_saved_model(saved_model_path,
                       frozen_graph_def,
                       inputs,
                       outputs):
  """Writes SavedModel to disk.

  Imports the frozen graph into a fresh graph and session, builds a predict
  signature with a single 'inputs' entry and one output entry per key of
  `outputs`, and saves it under the default serving signature key with the
  SERVING tag.

  Args:
    saved_model_path: Path to write SavedModel.
    frozen_graph_def: tf.GraphDef holding frozen graph.
    inputs: The input image tensor to use for detection.
    outputs: A tensor dictionary containing the outputs of a DetectionModel.
  """
  with tf.Graph().as_default(), session.Session() as sess:
    tf.import_graph_def(frozen_graph_def, name='')

    builder = tf.saved_model.builder.SavedModelBuilder(saved_model_path)

    build_info = tf.saved_model.utils.build_tensor_info
    input_infos = {'inputs': build_info(inputs)}
    output_infos = {
        output_name: build_info(output_tensor)
        for output_name, output_tensor in outputs.items()
    }

    detection_signature = (
        tf.saved_model.signature_def_utils.build_signature_def(
            inputs=input_infos,
            outputs=output_infos,
            method_name=signature_constants.PREDICT_METHOD_NAME))

    serving_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    builder.add_meta_graph_and_variables(
        sess,
        [tf.saved_model.tag_constants.SERVING],
        signature_def_map={serving_key: detection_signature})
    builder.save()


284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
def _write_graph_and_checkpoint(inference_graph_def,
                                model_path,
                                input_saver_def,
                                trained_checkpoint_prefix):
  """Restores a checkpoint into the inference graph and re-saves it.

  Args:
    inference_graph_def: tf.GraphDef of the inference graph. The device field
      of each of its nodes is cleared in place.
    model_path: Path passed to the saver for writing the new checkpoint.
    input_saver_def: SaverDef used to build the restoring/saving Saver.
    trained_checkpoint_prefix: Checkpoint to restore the variables from.
  """
  for graph_node in inference_graph_def.node:
    graph_node.device = ''

  with tf.Graph().as_default():
    tf.import_graph_def(inference_graph_def, name='')
    with session.Session() as sess:
      checkpoint_saver = saver_lib.Saver(
          saver_def=input_saver_def, save_relative_paths=True)
      checkpoint_saver.restore(sess, trained_checkpoint_prefix)
      checkpoint_saver.save(sess, model_path)


299
300
301
def _export_inference_graph(input_type,
                            detection_model,
                            use_moving_averages,
                            trained_checkpoint_prefix,
                            output_directory,
                            optimize_graph=False,
                            output_collection_name='inference_op'):
  """Builds the inference graph and writes the exported artifacts.

  Writes a 'model.ckpt' checkpoint, a 'frozen_inference_graph.pb' frozen graph
  and a 'saved_model' directory under `output_directory`.

  Args:
    input_type: Key of `input_placeholder_fn_map` selecting the graph input.
    detection_model: Model providing preprocess, predict and postprocess.
    use_moving_averages: Whether to restore the moving-average variables
      instead of the original ones.
    trained_checkpoint_prefix: Checkpoint to restore the variables from.
    output_directory: Directory to write outputs to; created if needed.
    optimize_graph: Forwarded to freeze_graph_with_def_protos.
    output_collection_name: Forwarded to _add_output_tensor_nodes.

  Raises:
    ValueError: If `input_type` is not a key of `input_placeholder_fn_map`.
  """
  tf.gfile.MakeDirs(output_directory)
  model_path = os.path.join(output_directory, 'model.ckpt')
  saved_model_path = os.path.join(output_directory, 'saved_model')
  frozen_graph_path = os.path.join(output_directory,
                                   'frozen_inference_graph.pb')

  placeholder_fn = input_placeholder_fn_map.get(input_type)
  if placeholder_fn is None:
    raise ValueError('Unknown input type: {}'.format(input_type))
  placeholder_tensor, input_tensors = placeholder_fn()

  # Run the full detection pipeline on the float-converted inputs.
  float_inputs = tf.to_float(input_tensors)
  preprocessed_inputs = detection_model.preprocess(float_inputs)
  prediction_dict = detection_model.predict(preprocessed_inputs)
  postprocessed_tensors = detection_model.postprocess(prediction_dict)
  outputs = _add_output_tensor_nodes(postprocessed_tensors,
                                     output_collection_name)

  if use_moving_averages:
    averaged_variables = tf.train.ExponentialMovingAverage(
        0.0).variables_to_restore()
    saver = tf.train.Saver(averaged_variables)
  else:
    saver = tf.train.Saver()
  input_saver_def = saver.as_saver_def()

  # Each consumer gets its own GraphDef copy because both clear node devices
  # in place.
  _write_graph_and_checkpoint(
      inference_graph_def=tf.get_default_graph().as_graph_def(),
      model_path=model_path,
      input_saver_def=input_saver_def,
      trained_checkpoint_prefix=trained_checkpoint_prefix)

  frozen_graph_def = freeze_graph_with_def_protos(
      input_graph_def=tf.get_default_graph().as_graph_def(),
      input_saver_def=input_saver_def,
      input_checkpoint=trained_checkpoint_prefix,
      output_node_names=','.join(outputs),
      restore_op_name='save/restore_all',
      filename_tensor_name='save/Const:0',
      clear_devices=True,
      optimize_graph=optimize_graph,
      initializer_nodes='')
  _write_frozen_graph(frozen_graph_path, frozen_graph_def)
  _write_saved_model(saved_model_path, frozen_graph_def, placeholder_tensor,
                     outputs)
351
352


353
354
355
356
def export_inference_graph(input_type,
                           pipeline_config,
                           trained_checkpoint_prefix,
                           output_directory,
                           optimize_graph=False,
                           output_collection_name='inference_op'):
  """Exports inference graph for the model specified in the pipeline config.

  Builds the detection model in inference mode from `pipeline_config.model`
  and delegates the export to `_export_inference_graph`.

  Args:
    input_type: Type of input for the graph. Can be one of [`image_tensor`,
      `encoded_image_string_tensor`, `tf_example`].
    pipeline_config: pipeline_pb2.TrainAndEvalPipelineConfig proto.
    trained_checkpoint_prefix: Path to the trained checkpoint file.
    output_directory: Path to write outputs.
    optimize_graph: Whether to optimize graph using Grappler.
    output_collection_name: Name of collection to add output tensors to.
  """
  detection_model = model_builder.build(pipeline_config.model,
                                        is_training=False)
  use_moving_averages = pipeline_config.eval_config.use_moving_averages
  _export_inference_graph(
      input_type,
      detection_model,
      use_moving_averages,
      trained_checkpoint_prefix,
      output_directory,
      optimize_graph,
      output_collection_name)