# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Exports trained model to TensorFlow frozen graph."""

import os

import tensorflow as tf
from google3.learning.brain.contrib import quantize as contrib_quantize
from google3.learning.brain.contrib import slim as contrib_slim
from tensorflow.python.tools import freeze_graph

from deeplab import common
from deeplab import input_preprocess
from deeplab import model

slim = contrib_slim

flags = tf.app.flags

FLAGS = flags.FLAGS

# NOTE(review): this file also reads FLAGS.min_resize_value,
# FLAGS.max_resize_value, FLAGS.resize_factor, FLAGS.model_variant and
# FLAGS.image_pyramid, which are not defined in this module; presumably they
# are registered when `deeplab.common` is imported -- confirm.

flags.DEFINE_string('checkpoint_path', None, 'Checkpoint path')

flags.DEFINE_string('export_path', None,
                    'Path to output Tensorflow frozen graph.')

flags.DEFINE_integer('num_classes', 21, 'Number of classes.')

flags.DEFINE_multi_integer('crop_size', [513, 513],
                           'Crop size [height, width].')

# For `xception_65`, use atrous_rates = [12, 24, 36] if output_stride = 8, or
# atrous_rates = [6, 12, 18] if output_stride = 16. For `mobilenet_v2`, use
# None. Note that different atrous_rates/output_stride values may be used for
# training and for evaluation.
flags.DEFINE_multi_integer('atrous_rates', None,
                           'Atrous rates for atrous spatial pyramid pooling.')

flags.DEFINE_integer('output_stride', 8,
                     'The ratio of input to output spatial resolution.')

# Change to [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] for multi-scale inference.
flags.DEFINE_multi_float('inference_scales', [1.0],
                         'The scales to resize images for inference.')

flags.DEFINE_bool('add_flipped_images', False,
                  'Add flipped images during inference or not.')

# When >= 0, `main` calls contrib_quantize.create_eval_graph() on the exported
# graph; this is rejected when multi-scale inference is requested.
flags.DEFINE_integer(
    'quantize_delay_step', -1,
    'Steps to start quantized training. If < 0, will not quantize model.')

flags.DEFINE_bool('save_inference_graph', False,
                  'Save inference graph in text proto.')

# Input name of the exported model.
_INPUT_NAME = 'ImageTensor'

# Output name of the exported predictions.
_OUTPUT_NAME = 'SemanticPredictions'
# Name of the predictions before cropping to the valid region and resizing.
_RAW_OUTPUT_NAME = 'RawSemanticPredictions'

# Output name of the exported probabilities.
_OUTPUT_PROB_NAME = 'SemanticProbabilities'
# Name of the probabilities before cropping to the valid region and resizing.
_RAW_OUTPUT_PROB_NAME = 'RawSemanticProbabilities'


def _create_input_tensors():
  """Builds the exported graph's input placeholder and preprocesses it.

  A uint8 placeholder named 'ImageTensor' with static shape [1, None, None, 3]
  is created. Feed it through the tensor name 'ImageTensor:0' at inference.

  Returns:
    A 3-tuple (image, original_image_size, resized_image_size):
      image: Preprocessed 4-D float32 tensor with shape [1, crop_height,
        crop_width, 3].
      original_image_size: Shape tensor [height, width] of the fed image.
      resized_image_size: Shape tensor [height, width] of the first image
        returned by `preprocess_image_and_label`.
  """
  # The batch dimension is fixed to 1 in the exported signature.
  image_placeholder = tf.placeholder(
      tf.uint8, [1, None, None, 3], name=_INPUT_NAME)
  original_image_size = tf.shape(image_placeholder)[1:3]

  # `preprocess_image_and_label` expects a single 3-D image, so drop axis 0.
  single_image = tf.squeeze(image_placeholder, axis=0)
  preprocess_kwargs = dict(
      label=None,
      crop_height=FLAGS.crop_size[0],
      crop_width=FLAGS.crop_size[1],
      min_resize_value=FLAGS.min_resize_value,
      max_resize_value=FLAGS.max_resize_value,
      resize_factor=FLAGS.resize_factor,
      is_training=False,
      model_variant=FLAGS.model_variant)
  resized_image, processed_image, _ = (
      input_preprocess.preprocess_image_and_label(
          single_image, **preprocess_kwargs))
  resized_image_size = tf.shape(resized_image)[:2]

  # The model code downstream operates on 4-D (batched) tensors again.
  batched_image = tf.expand_dims(processed_image, 0)
  return batched_image, original_image_size, resized_image_size


def main(unused_argv):
  """Builds the DeepLab inference graph and writes it out as a frozen graph.

  The exported graph reads a uint8 image from 'ImageTensor:0' and exposes two
  outputs:
    'SemanticPredictions': int32 labels, nearest-neighbor resized back to the
      fed image size.
    'SemanticProbabilities': probabilities, bilinearly resized back to the fed
      image size.
  Variable values come from FLAGS.checkpoint_path and are folded into the
  graph by `freeze_graph`. The result is written to FLAGS.export_path.

  Args:
    unused_argv: Positional arguments left over after flag parsing; ignored.

  Raises:
    ValueError: If quantization (FLAGS.quantize_delay_step >= 0) is requested
      together with multi-scale inference.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.logging.info('Prepare to export model to: %s', FLAGS.export_path)

  with tf.Graph().as_default():
    image, image_size, resized_image_size = _create_input_tensors()

    model_options = common.ModelOptions(
        outputs_to_num_classes={common.OUTPUT_TYPE: FLAGS.num_classes},
        crop_size=FLAGS.crop_size,
        atrous_rates=FLAGS.atrous_rates,
        output_stride=FLAGS.output_stride)

    if tuple(FLAGS.inference_scales) == (1.0,):
      tf.logging.info('Exported model performs single-scale inference.')
      predictions = model.predict_labels(
          image,
          model_options=model_options,
          image_pyramid=FLAGS.image_pyramid)
    else:
      tf.logging.info('Exported model performs multi-scale inference.')
      if FLAGS.quantize_delay_step >= 0:
        raise ValueError(
            'Quantize mode is not supported with multi-scale test.')
      predictions = model.predict_labels_multi_scale(
          image,
          model_options=model_options,
          eval_scales=FLAGS.inference_scales,
          add_flipped_images=FLAGS.add_flipped_images)

    # Named copies of the model outputs before cropping and resizing.
    raw_predictions = tf.identity(
        tf.cast(predictions[common.OUTPUT_TYPE], tf.float32),
        _RAW_OUTPUT_NAME)
    raw_probabilities = tf.identity(
        predictions[common.OUTPUT_TYPE + model.PROB_SUFFIX],
        _RAW_OUTPUT_PROB_NAME)

    # Crop the valid regions from the predictions.
    semantic_predictions = raw_predictions[
        :, :resized_image_size[0], :resized_image_size[1]]
    semantic_probabilities = raw_probabilities[
        :, :resized_image_size[0], :resized_image_size[1]]

    # Resize back the prediction to the original image size.
    def _resize_label(label, label_size):
      # Expand dimension of label to [1, height, width, 1] for resize operation.
      label = tf.expand_dims(label, 3)
      resized_label = tf.image.resize_images(
          label,
          label_size,
          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
          align_corners=True)
      return tf.cast(tf.squeeze(resized_label, 3), tf.int32)

    semantic_predictions = _resize_label(semantic_predictions, image_size)
    semantic_predictions = tf.identity(semantic_predictions, name=_OUTPUT_NAME)

    semantic_probabilities = tf.image.resize_bilinear(
        semantic_probabilities, image_size, align_corners=True,
        name=_OUTPUT_PROB_NAME)

    if FLAGS.quantize_delay_step >= 0:
      contrib_quantize.create_eval_graph()

    # tf.all_variables() is deprecated; tf.global_variables() is its
    # documented replacement and returns the same collection.
    saver = tf.train.Saver(tf.global_variables())

    # os.path.dirname() returns '' for a bare file name such as 'model.pb'.
    # Fall back to the current directory so MakeDirs/write_graph receive a
    # usable path.
    dirname = os.path.dirname(FLAGS.export_path) or '.'
    tf.gfile.MakeDirs(dirname)
    graph_def = tf.get_default_graph().as_graph_def(add_shapes=True)
    freeze_graph.freeze_graph_with_def_protos(
        graph_def,
        saver.as_saver_def(),
        FLAGS.checkpoint_path,
        _OUTPUT_NAME + ',' + _OUTPUT_PROB_NAME,
        restore_op_name=None,
        filename_tensor_name=None,
        output_graph=FLAGS.export_path,
        clear_devices=True,
        initializer_nodes=None)

    if FLAGS.save_inference_graph:
      # Writes the un-frozen graph definition as a text proto.
      tf.train.write_graph(graph_def, dirname, 'inference_graph.pbtxt')


if __name__ == '__main__':
  # Both paths are mandatory: the checkpoint to freeze and the output file.
  for _required_flag in ('checkpoint_path', 'export_path'):
    flags.mark_flag_as_required(_required_flag)
  tf.app.run()