exporter.py 22.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Functions to export object detection inference graph."""
import os
Vivek Rathod's avatar
Vivek Rathod committed
18
import tempfile
19
import tensorflow as tf
20
from tensorflow.contrib.quantize.python import graph_matcher
21
from tensorflow.core.protobuf import saver_pb2
22
from tensorflow.python.tools import freeze_graph  # pylint: disable=g-direct-tensorflow-import
23
from object_detection.builders import graph_rewriter_builder
24
25
26
from object_detection.builders import model_builder
from object_detection.core import standard_fields as fields
from object_detection.data_decoders import tf_example_decoder
27
from object_detection.utils import config_util
28
from object_detection.utils import shape_utils
29
30
31

slim = tf.contrib.slim

32
freeze_graph_with_def_protos = freeze_graph.freeze_graph_with_def_protos
33
34


35
36
37
38
39
40
41
42
def rewrite_nn_resize_op(is_quantized=False):
  """Swaps hand-built nearest-neighbor upsampling for the native TF resize op.

  Some graphs use this custom version for TPU-compatibility. The rewrite is
  applied in place to the current default graph.

  Args:
    is_quantized: True if the default graph is quantized. In that case the
      input of the upsampling structure must be a `FakeQuantWithMinMaxVars`
      op; otherwise any op type is accepted.

  Raises:
    ValueError: if matches are still being found after more than four
      rewrite passes.
  """

  def _rewrite_matches_once():
    """Rewrites every current match; returns how many matches were found."""
    # The structure searched for is:
    #   source -> Pack(source, source) -> Pack(pack, pack) -> Reshape(., Const)
    # whose result feeds an Add / AddV2 / Max / Mul op.
    source_op_type = 'FakeQuantWithMinMaxVars' if is_quantized else '*'
    source = graph_matcher.OpTypePattern(source_op_type)
    inner_stack = graph_matcher.OpTypePattern(
        'Pack', inputs=[source, source], ordered_inputs=False)
    outer_stack = graph_matcher.OpTypePattern(
        'Pack', inputs=[inner_stack, inner_stack], ordered_inputs=False)
    reshape = graph_matcher.OpTypePattern(
        'Reshape', inputs=[outer_stack, 'Const'], ordered_inputs=False)
    consumer = graph_matcher.OpTypePattern(
        'Add|AddV2|Max|Mul', inputs=[reshape, '*'], ordered_inputs=False)

    num_matches = 0
    graph_matches = graph_matcher.GraphMatcher(consumer).match_graph(
        tf.get_default_graph())
    for num_matches, found in enumerate(graph_matches, 1):
      source_op = found.get_op(source)
      matched_reshape = found.get_op(reshape)
      matched_consumer = found.get_op(consumer)
      upsampled = matched_reshape.outputs[0]

      # The replacement op is named inside the reshape op's parent name scope
      # and resizes to the height/width of the reshaped tensor.
      resize_name = (
          os.path.split(matched_reshape.name)[0] + '/resize_nearest_neighbor')
      replacement = tf.image.resize_nearest_neighbor(
          source_op.outputs[0],
          upsampled.shape.dims[1:3],
          align_corners=False,
          name=resize_name)

      # Rewire only the first consumer input that reads the reshape output.
      for position, consumer_input in enumerate(matched_consumer.inputs):
        if consumer_input == upsampled:
          matched_consumer._update_input(position, replacement)  # pylint: disable=protected-access
          break

    tf.logging.info('Found and fixed {} matches'.format(num_matches))
    return num_matches

  # Repeat until a pass finds nothing: both inputs of a consumer may be
  # upsampling structures and a single pass rewires one input per match.
  completed_passes = 0
  while _rewrite_matches_once() > 0:
    completed_passes += 1
    # This number is chosen based on the nas-fpn architecture.
    if completed_passes > 4:
      raise ValueError('Graph removal encountered a infinite loop.')
85
86


Vivek Rathod's avatar
Vivek Rathod committed
87
88
def replace_variable_values_with_moving_averages(graph,
                                                 current_checkpoint_file,
                                                 new_checkpoint_file,
                                                 no_ema_collection=None):
  """Writes a checkpoint whose variables hold their moving-average values.

  If the current checkpoint has shadow variables maintaining moving averages
  of the variables defined in the graph, this function generates a new
  checkpoint where the variables contain the values of their moving averages.

  Args:
    graph: a tf.Graph object.
    current_checkpoint_file: a checkpoint containing both original variables
      and their moving averages.
    new_checkpoint_file: file path to write a new checkpoint.
    no_ema_collection: A list of namescope substrings to match the variables
      to eliminate EMA.
  """
  with graph.as_default():
    # The EMA object is only used to derive the restore mapping; no averaging
    # ops are applied in this function.
    moving_average = tf.train.ExponentialMovingAverage(0.0)
    restore_map = config_util.remove_unecessary_ema(
        moving_average.variables_to_restore(), no_ema_collection)
    with tf.Session() as sess:
      # Load through the filtered mapping, then save every variable with a
      # default Saver.
      tf.train.Saver(restore_map).restore(sess, current_checkpoint_file)
      tf.train.Saver().save(sess, new_checkpoint_file)


def _image_tensor_input_placeholder(input_shape=None):
  """Returns input placeholder and a 4-D uint8 image tensor.

  Args:
    input_shape: shape of the placeholder (optional). When None, the shape
      (None, None, None, 3) is used.

  Returns:
    a tuple holding the same uint8 placeholder, named 'image_tensor', twice:
    once as the graph input and once as the image tensor.
  """
  placeholder_shape = (
      (None, None, None, 3) if input_shape is None else input_shape)
  image_placeholder = tf.placeholder(
      dtype=tf.uint8, shape=placeholder_shape, name='image_tensor')
  return image_placeholder, image_placeholder
124

125

126
def _tf_example_input_placeholder(input_shape=None):
  """Returns input that accepts a batch of strings with tf examples.

  Args:
    input_shape: the shape to resize the output decoded images to (optional).
      Only `input_shape[1:3]` (height, width) is used.

  Returns:
    a tuple of input placeholder and the output decoded images.
  """
  batch_tf_example_placeholder = tf.placeholder(
      tf.string, shape=[None], name='tf_example')

  def decode(tf_example_string_tensor):
    """Decodes one serialized tf.Example into a uint8 image tensor."""
    tensor_dict = tf_example_decoder.TfExampleDecoder().decode(
        tf_example_string_tensor)
    image_tensor = tensor_dict[fields.InputDataFields.image]
    if input_shape is not None:
      # The resize op generally produces float32, while the surrounding map
      # declares uint8 outputs; cast back so the dtypes agree.
      image_tensor = tf.cast(
          tf.image.resize(image_tensor, input_shape[1:3]), dtype=tf.uint8)
    return image_tensor

  return (batch_tf_example_placeholder,
          shape_utils.static_or_dynamic_map_fn(
              decode,
              elems=batch_tf_example_placeholder,
              dtype=tf.uint8,
              parallel_iterations=32,
              back_prop=False))
151
152


153
def _encoded_image_string_tensor_input_placeholder(input_shape=None):
  """Returns input that accepts a batch of PNG or JPEG strings.

  Args:
    input_shape: the shape to resize the output decoded images to (optional).
      Only `input_shape[1:3]` (height, width) is used.

  Returns:
    a tuple of input placeholder and the output decoded images.
  """
  batch_image_str_placeholder = tf.placeholder(
      dtype=tf.string,
      shape=[None],
      name='encoded_image_string_tensor')

  def decode(encoded_image_string_tensor):
    """Decodes one encoded image string into a 3-channel uint8 tensor."""
    image_tensor = tf.image.decode_image(encoded_image_string_tensor,
                                         channels=3)
    image_tensor.set_shape((None, None, 3))
    if input_shape is not None:
      # The resize op generally produces float32, while `tf.map_fn` below
      # declares uint8 outputs; cast back so the dtypes agree.
      image_tensor = tf.cast(
          tf.image.resize(image_tensor, input_shape[1:3]), dtype=tf.uint8)
    return image_tensor

  return (batch_image_str_placeholder,
          tf.map_fn(
              decode,
              elems=batch_image_str_placeholder,
              dtype=tf.uint8,
              parallel_iterations=32,
              back_prop=False))
180
181


182
# Maps each supported `input_type` string to the function that builds the
# corresponding input placeholder. Each function accepts an optional
# `input_shape` and returns a (placeholder, image tensor) tuple.
input_placeholder_fn_map = {
    'image_tensor':
        _image_tensor_input_placeholder,
    'encoded_image_string_tensor':
        _encoded_image_string_tensor_input_placeholder,
    'tf_example':
        _tf_example_input_placeholder,
}


190
191
def add_output_tensor_nodes(postprocessed_tensors,
                            output_collection_name='inference_op'):
  """Adds output nodes for detection boxes and scores.

  Adds the following nodes for output tensors -
    * num_detections: float32 tensor of shape [batch_size].
    * detection_boxes: float32 tensor of shape [batch_size, num_boxes, 4]
      containing detected boxes.
    * detection_scores: float32 tensor of shape [batch_size, num_boxes]
      containing scores for the detected boxes.
    * detection_multiclass_scores: (Optional) float32 tensor of shape
      [batch_size, num_boxes, num_classes_with_background] for containing class
      score distribution for detected boxes including background if any.
    * detection_features: (Optional) float32 tensor of shape
      [batch, num_boxes, roi_height, roi_width, depth]
      containing classifier features
      for each detected box
    * detection_classes: float32 tensor of shape [batch_size, num_boxes]
      containing class predictions for the detected boxes.
    * detection_keypoints: (Optional) float32 tensor of shape
      [batch_size, num_boxes, num_keypoints, 2] containing keypoints for each
      detection box.
    * detection_masks: (Optional) float32 tensor of shape
      [batch_size, num_boxes, mask_height, mask_width] containing masks for each
      detection box.
    * raw_detection_boxes / raw_detection_scores: (Optional) passed through
      unchanged when present in `postprocessed_tensors`.

  Args:
    postprocessed_tensors: a dictionary containing the following fields
      'detection_boxes': [batch, max_detections, 4]
      'detection_scores': [batch, max_detections]
      'detection_multiclass_scores': [batch, max_detections,
        num_classes_with_background] (optional).
      'detection_features': [batch, num_boxes, roi_height, roi_width, depth]
        (optional).
      'detection_classes': [batch, max_detections]
      'detection_masks': [batch, max_detections, mask_height, mask_width]
        (optional).
      'detection_keypoints': [batch, max_detections, num_keypoints, 2]
        (optional).
      'raw_detection_boxes', 'raw_detection_scores': (optional).
      'num_detections': [batch]
    output_collection_name: Name of collection to add output tensors to.
      If None, the output tensors are not added to any collection.

  Returns:
    A tensor dict containing the added output tensor nodes.
  """
  detection_fields = fields.DetectionResultFields
  # Exported class ids are shifted by one relative to the model's
  # postprocessed class indices.
  label_id_offset = 1
  boxes = postprocessed_tensors.get(detection_fields.detection_boxes)
  scores = postprocessed_tensors.get(detection_fields.detection_scores)
  multiclass_scores = postprocessed_tensors.get(
      detection_fields.detection_multiclass_scores)
  box_classifier_features = postprocessed_tensors.get(
      detection_fields.detection_features)
  raw_boxes = postprocessed_tensors.get(detection_fields.raw_detection_boxes)
  raw_scores = postprocessed_tensors.get(detection_fields.raw_detection_scores)
  classes = postprocessed_tensors.get(
      detection_fields.detection_classes) + label_id_offset
  keypoints = postprocessed_tensors.get(detection_fields.detection_keypoints)
  masks = postprocessed_tensors.get(detection_fields.detection_masks)
  num_detections = postprocessed_tensors.get(detection_fields.num_detections)

  # Each output is wrapped in an identity op so that the exported node carries
  # the canonical field name.
  outputs = {}
  outputs[detection_fields.detection_boxes] = tf.identity(
      boxes, name=detection_fields.detection_boxes)
  outputs[detection_fields.detection_scores] = tf.identity(
      scores, name=detection_fields.detection_scores)
  if multiclass_scores is not None:
    outputs[detection_fields.detection_multiclass_scores] = tf.identity(
        multiclass_scores, name=detection_fields.detection_multiclass_scores)
  if box_classifier_features is not None:
    outputs[detection_fields.detection_features] = tf.identity(
        box_classifier_features,
        name=detection_fields.detection_features)
  outputs[detection_fields.detection_classes] = tf.identity(
      classes, name=detection_fields.detection_classes)
  outputs[detection_fields.num_detections] = tf.identity(
      num_detections, name=detection_fields.num_detections)
  if raw_boxes is not None:
    outputs[detection_fields.raw_detection_boxes] = tf.identity(
        raw_boxes, name=detection_fields.raw_detection_boxes)
  if raw_scores is not None:
    outputs[detection_fields.raw_detection_scores] = tf.identity(
        raw_scores, name=detection_fields.raw_detection_scores)
  if keypoints is not None:
    outputs[detection_fields.detection_keypoints] = tf.identity(
        keypoints, name=detection_fields.detection_keypoints)
  if masks is not None:
    outputs[detection_fields.detection_masks] = tf.identity(
        masks, name=detection_fields.detection_masks)

  # Honor the documented contract: a None collection name means the outputs
  # are not registered in any graph collection.
  if output_collection_name is not None:
    for output_tensor in outputs.values():
      tf.add_to_collection(output_collection_name, output_tensor)

  return outputs
281
282


283
284
285
286
def write_saved_model(saved_model_path,
                      frozen_graph_def,
                      inputs,
                      outputs):
  """Writes SavedModel to disk.

  The frozen graph is imported into a fresh graph and exported with a single
  predict signature, registered under the default serving signature key and
  tagged for serving.

  Args:
    saved_model_path: Path to write SavedModel.
    frozen_graph_def: tf.GraphDef holding frozen graph.
    inputs: The input placeholder tensor.
    outputs: A tensor dictionary containing the outputs of a DetectionModel.
  """
  saved_model = tf.saved_model
  with tf.Graph().as_default():
    with tf.Session() as sess:
      tf.import_graph_def(frozen_graph_def, name='')

      builder = saved_model.builder.SavedModelBuilder(saved_model_path)

      # The signature exposes one input, keyed 'inputs', and one output per
      # entry of `outputs`, keyed by the same dictionary key.
      build_tensor_info = saved_model.utils.build_tensor_info
      signature_inputs = {'inputs': build_tensor_info(inputs)}
      signature_outputs = {
          key: build_tensor_info(tensor) for key, tensor in outputs.items()
      }

      predict_signature = saved_model.signature_def_utils.build_signature_def(
          inputs=signature_inputs,
          outputs=signature_outputs,
          method_name=saved_model.signature_constants.PREDICT_METHOD_NAME)

      default_key = (
          saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
      builder.add_meta_graph_and_variables(
          sess,
          [saved_model.tag_constants.SERVING],
          signature_def_map={default_key: predict_signature},
      )
      builder.save()


333
334
335
336
337
def write_graph_and_checkpoint(inference_graph_def,
                               model_path,
                               input_saver_def,
                               trained_checkpoint_prefix):
  """Writes the graph and the checkpoint into disk.

  Args:
    inference_graph_def: tf.GraphDef of the inference graph. The device field
      of every node is cleared in place.
    model_path: path prefix the checkpoint is saved under.
    input_saver_def: SaverDef used to construct the saver that restores and
      re-saves the variables.
    trained_checkpoint_prefix: path prefix of the checkpoint to restore from.
  """
  # Strip device placements; this mutates the GraphDef that was passed in.
  for graph_node in inference_graph_def.node:
    graph_node.device = ''

  with tf.Graph().as_default():
    tf.import_graph_def(inference_graph_def, name='')
    with tf.Session() as sess:
      checkpoint_saver = tf.train.Saver(
          saver_def=input_saver_def, save_relative_paths=True)
      checkpoint_saver.restore(sess, trained_checkpoint_prefix)
      checkpoint_saver.save(sess, model_path)


349
350
def _get_outputs_from_inputs(input_tensors, detection_model,
                             output_collection_name):
  """Runs the detection model on the inputs and adds named output nodes.

  Args:
    input_tensors: image tensor(s) fed to the model; cast to float32 first.
    detection_model: model object providing `preprocess`, `predict` and
      `postprocess`.
    output_collection_name: Name of collection to add output tensors to.

  Returns:
    The tensor dict returned by `add_output_tensor_nodes`.
  """
  float_images = tf.cast(input_tensors, dtype=tf.float32)
  preprocessed_images, true_image_shapes = detection_model.preprocess(
      float_images)
  prediction_dict = detection_model.predict(
      preprocessed_images, true_image_shapes)
  detections = detection_model.postprocess(prediction_dict, true_image_shapes)
  return add_output_tensor_nodes(detections, output_collection_name)
359
360


361
362
def build_detection_graph(input_type, detection_model, input_shape,
                          output_collection_name, graph_hook_fn):
  """Build the detection graph.

  Args:
    input_type: key of `input_placeholder_fn_map` selecting the kind of input
      placeholder to create.
    detection_model: model object used to produce the detection outputs.
    input_shape: optional shape forwarded to the placeholder function as
      `input_shape`; when None the function is called without arguments.
    output_collection_name: Name of collection to add output tensors to.
    graph_hook_fn: optional zero-argument callable invoked after the graph
      and the global step have been created.

  Returns:
    A tuple of (output tensor dict, input placeholder tensor).

  Raises:
    ValueError: if `input_type` is unknown, or if `input_shape` is given for
      an input type that does not accept one.
  """
  if input_type not in input_placeholder_fn_map:
    raise ValueError('Unknown input type: {}'.format(input_type))

  placeholder_kwargs = {}
  if input_shape is not None:
    shape_aware_types = (
        'image_tensor', 'encoded_image_string_tensor', 'tf_example')
    if input_type not in shape_aware_types:
      raise ValueError('Can only specify input shape for `image_tensor`, '
                       '`encoded_image_string_tensor`, or `tf_example` '
                       'inputs.')
    placeholder_kwargs['input_shape'] = input_shape

  make_placeholder = input_placeholder_fn_map[input_type]
  placeholder_tensor, input_tensors = make_placeholder(**placeholder_kwargs)
  outputs = _get_outputs_from_inputs(
      input_tensors=input_tensors,
      detection_model=detection_model,
      output_collection_name=output_collection_name)

  # Add global step to the graph.
  slim.get_or_create_global_step()

  if graph_hook_fn:
    graph_hook_fn()

  return outputs, placeholder_tensor


390
391
392
def _export_inference_graph(input_type,
                            detection_model,
                            use_moving_averages,
                            trained_checkpoint_prefix,
                            output_directory,
                            additional_output_tensor_names=None,
                            input_shape=None,
                            output_collection_name='inference_op',
                            graph_hook_fn=None,
                            write_inference_graph=False,
                            temp_checkpoint_prefix=''):
  """Export helper.

  Builds the detection graph in the current default graph, then writes into
  `output_directory`: a re-saved checkpoint (`model.ckpt`), a frozen graph
  (`frozen_inference_graph.pb`), a SavedModel (`saved_model/`) and, when
  requested, the text inference graph (`inference_graph.pbtxt`).

  Args:
    input_type: Type of input for the graph; a key of
      `input_placeholder_fn_map`.
    detection_model: model object used to build the detection graph.
    use_moving_averages: if true, variable values are first replaced by their
      moving averages in a temporary checkpoint, which is then exported.
    trained_checkpoint_prefix: Path to the trained checkpoint file.
    output_directory: Path to write outputs; created if needed.
    additional_output_tensor_names: optional iterable of additional output
      names to include in the frozen graph.
    input_shape: optional shape forwarded to the input placeholder function.
    output_collection_name: Name of collection to add output tensors to.
    graph_hook_fn: optional zero-argument callable run after graph creation.
    write_inference_graph: If true, writes inference graph to disk.
    temp_checkpoint_prefix: path prefix for the temporary moving-average
      checkpoint; when empty a temporary location is chosen.
  """
  tf.gfile.MakeDirs(output_directory)
  frozen_graph_path = os.path.join(output_directory,
                                   'frozen_inference_graph.pb')
  saved_model_path = os.path.join(output_directory, 'saved_model')
  model_path = os.path.join(output_directory, 'model.ckpt')

  outputs, placeholder_tensor = build_detection_graph(
      input_type=input_type,
      detection_model=detection_model,
      input_shape=input_shape,
      output_collection_name=output_collection_name,
      graph_hook_fn=graph_hook_fn)

  profile_inference_graph(tf.get_default_graph())
  saver_kwargs = {}
  if use_moving_averages:
    if not temp_checkpoint_prefix:
      # This check is to be compatible with both version of SaverDef.
      if os.path.isfile(trained_checkpoint_prefix):
        saver_kwargs['write_version'] = saver_pb2.SaverDef.V1
        temp_checkpoint_prefix = tempfile.NamedTemporaryFile().name
      else:
        temp_checkpoint_prefix = tempfile.mkdtemp()
    replace_variable_values_with_moving_averages(
        tf.get_default_graph(), trained_checkpoint_prefix,
        temp_checkpoint_prefix)
    checkpoint_to_use = temp_checkpoint_prefix
  else:
    checkpoint_to_use = trained_checkpoint_prefix

  saver = tf.train.Saver(**saver_kwargs)
  input_saver_def = saver.as_saver_def()

  write_graph_and_checkpoint(
      inference_graph_def=tf.get_default_graph().as_graph_def(),
      model_path=model_path,
      input_saver_def=input_saver_def,
      trained_checkpoint_prefix=checkpoint_to_use)
  if write_inference_graph:
    inference_graph_def = tf.get_default_graph().as_graph_def()
    inference_graph_path = os.path.join(output_directory,
                                        'inference_graph.pbtxt')
    for node in inference_graph_def.node:
      node.device = ''
    with tf.gfile.GFile(inference_graph_path, 'wb') as f:
      f.write(str(inference_graph_def))

  # Build a real list first: on Python 3 `dict.keys()` is a view and cannot be
  # concatenated to a list with `+`.
  output_node_name_list = list(outputs.keys())
  if additional_output_tensor_names is not None:
    output_node_name_list.extend(additional_output_tensor_names)
  output_node_names = ','.join(output_node_name_list)

  frozen_graph_def = freeze_graph.freeze_graph_with_def_protos(
      input_graph_def=tf.get_default_graph().as_graph_def(),
      input_saver_def=input_saver_def,
      input_checkpoint=checkpoint_to_use,
      output_node_names=output_node_names,
      restore_op_name='save/restore_all',
      filename_tensor_name='save/Const:0',
      output_graph=frozen_graph_path,
      clear_devices=True,
      initializer_nodes='')

  write_saved_model(saved_model_path, frozen_graph_def,
                    placeholder_tensor, outputs)
467
468


469
470
471
472
def export_inference_graph(input_type,
                           pipeline_config,
                           trained_checkpoint_prefix,
                           output_directory,
                           input_shape=None,
                           output_collection_name='inference_op',
                           additional_output_tensor_names=None,
                           write_inference_graph=False):
  """Exports inference graph for the model specified in the pipeline config.

  After exporting, `pipeline_config.eval_config.use_moving_averages` is set to
  False on the passed-in config and the config is saved into
  `output_directory`.

  Args:
    input_type: Type of input for the graph. Can be one of ['image_tensor',
      'encoded_image_string_tensor', 'tf_example'].
    pipeline_config: pipeline_pb2.TrainAndEvalPipelineConfig proto.
    trained_checkpoint_prefix: Path to the trained checkpoint file.
    output_directory: Path to write outputs.
    input_shape: Sets a fixed shape for an `image_tensor` input. If not
      specified, will default to [None, None, None, 3]. For the other input
      types, decoded images are resized to `input_shape[1:3]`.
    output_collection_name: Name of collection to add output tensors to.
    additional_output_tensor_names: list of additional output
      tensors to include in the frozen graph.
    write_inference_graph: If true, writes inference graph to disk.
  """
  detection_model = model_builder.build(pipeline_config.model,
                                        is_training=False)

  # Only build a graph rewriter hook when the config declares one.
  if pipeline_config.HasField('graph_rewriter'):
    graph_hook = graph_rewriter_builder.build(
        pipeline_config.graph_rewriter, is_training=False)
  else:
    graph_hook = None

  _export_inference_graph(
      input_type=input_type,
      detection_model=detection_model,
      use_moving_averages=pipeline_config.eval_config.use_moving_averages,
      trained_checkpoint_prefix=trained_checkpoint_prefix,
      output_directory=output_directory,
      additional_output_tensor_names=additional_output_tensor_names,
      input_shape=input_shape,
      output_collection_name=output_collection_name,
      graph_hook_fn=graph_hook,
      write_inference_graph=write_inference_graph)

  # The saved config is marked as not using moving averages.
  pipeline_config.eval_config.use_moving_averages = False
  config_util.save_pipeline_config(pipeline_config, output_directory)
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543


def profile_inference_graph(graph):
  """Profiles the inference graph.

  Prints model parameters and computation FLOPs given an inference graph.
  BatchNorms are excluded from the parameter count due to the fact that
  BatchNorms are usually folded. BatchNorm, Initializer, Regularizer
  and BiasAdd are not considered in FLOP count.

  Args:
    graph: the inference graph.
  """
  model_analyzer = tf.contrib.tfprof.model_analyzer

  # Work on copies: the option dicts are module-level objects of the profiler,
  # and writing into them directly would leak the trim settings into every
  # other user of those options in the same process.
  tfprof_vars_option = dict(
      model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS)
  tfprof_flops_option = dict(model_analyzer.FLOAT_OPS_OPTIONS)

  # Batchnorm is usually folded during inference.
  tfprof_vars_option['trim_name_regexes'] = ['.*BatchNorm.*']
  # Initializer and Regularizer are only used in training.
  tfprof_flops_option['trim_name_regexes'] = [
      '.*BatchNorm.*', '.*Initializer.*', '.*Regularizer.*', '.*BiasAdd.*'
  ]

  model_analyzer.print_model_analysis(
      graph,
      tfprof_options=tfprof_vars_option)

  model_analyzer.print_model_analysis(
      graph,
      tfprof_options=tfprof_flops_option)