"examples/vscode:/vscode.git/clone" did not exist on "0240d4191adfe22e66e43826261795c0f1217651"
Commit 49075e50 authored by Richard Brooks, committed by Menglong Zhu

Lstm object detection improvements (#7379)

* Replace google3.pyglib modules with tf and absl

This now matches train.py and uses publicly available libraries instead.

* Add example pipeline config for SSD Interleaved V2 Model.

Compiled from model_builder_test.py and lstm_ssd_mobilenet_v1_imagenet.config.
Removed data augmentation and transfer learning (i.e. training from a checkpoint) due to errors I was seeing when trying to run with them.

* Remove unused tfrecord creation.

This was also incorrectly specified, as the keys differed from the TFSequenceExample parser.

* correct key specified in docstring

* add tflite frozen graph exporter (cli and lib).

* add tflite model exporter

* add script to test the tflite model

* add model export documentation

* correct docstring

* rename export files to be unique across detection research work

* correct number of channels for grayscale

* add and correct copyright
parent 03b4a0af
@@ -32,3 +32,8 @@ https://scholar.googleusercontent.com/scholar.bib?q=info:rLqvkztmWYgJ:scholar.go
* yinxiao@google.com
* menglong@google.com
* yongzhe@google.com
## Table of Contents
* <a href='g3doc/exporting_models.md'>Exporting a trained model</a>
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# For training on Imagenet Video with LSTM Interleaved Mobilenet V2
[lstm_object_detection.protos.lstm_model] {
train_unroll_length: 4
eval_unroll_length: 4
lstm_state_depth: 320
depth_multipliers: 1.4
depth_multipliers: 0.35
pre_bottleneck: true
low_res: true
train_interleave_method: 'RANDOM_SKIP_SMALL'
eval_interleave_method: 'SKIP3'
}
model {
ssd {
num_classes: 30 # Number of classes in the ImageNet VID dataset.
box_coder {
faster_rcnn_box_coder {
y_scale: 10.0
x_scale: 10.0
height_scale: 5.0
width_scale: 5.0
}
}
matcher {
argmax_matcher {
matched_threshold: 0.5
unmatched_threshold: 0.5
ignore_thresholds: false
negatives_lower_than_unmatched: true
force_match_for_each_row: true
}
}
similarity_calculator {
iou_similarity {
}
}
anchor_generator {
ssd_anchor_generator {
num_layers: 5
min_scale: 0.2
max_scale: 0.95
aspect_ratios: 1.0
aspect_ratios: 2.0
aspect_ratios: 0.5
aspect_ratios: 3.0
aspect_ratios: 0.3333
}
}
image_resizer {
fixed_shape_resizer {
height: 320
width: 320
}
}
box_predictor {
convolutional_box_predictor {
min_depth: 0
max_depth: 0
num_layers_before_predictor: 3
use_dropout: false
dropout_keep_probability: 0.8
kernel_size: 3
box_code_size: 4
apply_sigmoid_to_scores: false
use_depthwise: true
conv_hyperparams {
activation: RELU_6,
regularizer {
l2_regularizer {
weight: 0.00004
}
}
initializer {
truncated_normal_initializer {
stddev: 0.03
mean: 0.0
}
}
batch_norm {
train: true,
scale: true,
center: true,
decay: 0.9997,
epsilon: 0.001,
}
}
}
}
feature_extractor {
type: 'lstm_ssd_interleaved_mobilenet_v2'
conv_hyperparams {
activation: RELU_6,
regularizer {
l2_regularizer {
weight: 0.00004
}
}
initializer {
truncated_normal_initializer {
stddev: 0.03
mean: 0.0
}
}
batch_norm {
train: true,
scale: true,
center: true,
decay: 0.9997,
epsilon: 0.001,
}
}
}
loss {
classification_loss {
weighted_sigmoid {
}
}
localization_loss {
weighted_smooth_l1 {
}
}
hard_example_miner {
num_hard_examples: 3000
iou_threshold: 0.99
loss_type: CLASSIFICATION
max_negatives_per_positive: 3
min_negatives_per_image: 0
}
classification_weight: 1.0
localization_weight: 4.0
}
normalize_loss_by_num_matches: true
post_processing {
batch_non_max_suppression {
score_threshold: -20.0
iou_threshold: 0.5
max_detections_per_class: 100
max_total_detections: 100
}
score_converter: SIGMOID
}
}
}
train_config: {
batch_size: 8
optimizer {
use_moving_average: false
rms_prop_optimizer: {
learning_rate: {
exponential_decay_learning_rate {
initial_learning_rate: 0.002
decay_steps: 200000
decay_factor: 0.95
}
}
momentum_optimizer_value: 0.9
decay: 0.9
epsilon: 1.0
}
}
gradient_clipping_by_norm: 10.0
batch_queue_capacity: 12
prefetch_queue_capacity: 4
}
train_input_reader: {
shuffle_buffer_size: 32
queue_capacity: 12
prefetch_size: 12
min_after_dequeue: 4
label_map_path: "path/to/label_map"
external_input_reader {
[lstm_object_detection.protos.GoogleInputReader.google_input_reader] {
tf_record_video_input_reader: {
input_path: '/data/lstm_detection/tfrecords/test.tfrecord'
data_type: TF_SEQUENCE_EXAMPLE
video_length: 4
}
}
}
}
eval_config: {
metrics_set: "coco_evaluation_all_frames"
use_moving_averages: true
min_score_threshold: 0.5
max_num_boxes_to_visualize: 300
visualize_groundtruth_boxes: true
groundtruth_box_visualization_color: "red"
}
eval_input_reader {
label_map_path: "path/to/label_map"
shuffle: true
num_epochs: 1
num_parallel_batches: 1
num_readers: 1
external_input_reader {
[lstm_object_detection.protos.GoogleInputReader.google_input_reader] {
tf_record_video_input_reader: {
input_path: "path/to/sequence_example/data"
data_type: TF_SEQUENCE_EXAMPLE
video_length: 10
}
}
}
}
eval_input_reader: {
label_map_path: "path/to/label_map"
external_input_reader {
[lstm_object_detection.protos.GoogleInputReader.google_input_reader] {
tf_record_video_input_reader: {
input_path: "path/to/sequence_example/data"
data_type: TF_SEQUENCE_EXAMPLE
video_length: 4
}
}
}
shuffle: true
num_readers: 1
}
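As a quick sanity check, the example pipeline above can be parsed with the repository's config helper before training or export; a minimal sketch, assuming the file is saved as `pipeline.config` (the path and the echoed values are illustrative, taken from the config above):

```python
# Minimal sketch: parse the example pipeline config and read back a few fields.
from lstm_object_detection.utils import config_util

configs = config_util.get_configs_from_pipeline_file('pipeline.config')
model_config = configs['model']       # SSD model settings
lstm_config = configs['lstm_model']   # LSTM-specific settings

assert model_config.WhichOneof('model') == 'ssd'
print('num_classes:', model_config.ssd.num_classes)           # 30 (ImageNet VID)
print('eval_unroll_length:', lstm_config.eval_unroll_length)  # 4
```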
@@ -27,8 +27,6 @@ import functools
 import os
 import tensorflow as tf
 from google.protobuf import text_format
-from google3.pyglib import app
-from google3.pyglib import flags
 from lstm_object_detection import evaluator
 from lstm_object_detection import model_builder
 from lstm_object_detection.inputs import seq_dataset_builder
@@ -107,4 +105,4 @@ def main(unused_argv):
       FLAGS.checkpoint_dir, FLAGS.eval_dir)
 
 if __name__ == '__main__':
-  app.run()
+  tf.app.run()
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Exports an LSTM detection model to use with tf-lite.
Outputs file:
* A tflite compatible frozen graph - $output_directory/tflite_graph.pb
The exported graph has the following input and output nodes.
Inputs:
'input_video_tensor': a float32 tensor of shape
[unroll_length, height, width, 3] containing the normalized input image.
Note that the height and width must be compatible with the height and
width configured in the fixed_shape_resizer options in the pipeline
config proto.
Outputs:
If add_postprocessing_op is true, the frozen graph adds a
TFLite_Detection_PostProcess custom op node with four outputs:
detection_boxes: a float32 tensor of shape [1, num_boxes, 4] with box
locations
detection_classes: a float32 tensor of shape [1, num_boxes]
with class indices
detection_scores: a float32 tensor of shape [1, num_boxes]
with class scores
num_boxes: a float32 tensor of size 1 containing the number of detected boxes
else:
the graph has three outputs:
'raw_outputs/box_encodings': a float32 tensor of shape [1, num_anchors, 4]
containing the encoded box predictions.
'raw_outputs/class_predictions': a float32 tensor of shape
[1, num_anchors, num_classes] containing the class scores for each anchor
after applying score conversion.
'anchors': a float32 constant tensor of shape [num_anchors, 4]
containing the anchor boxes.
Example Usage:
--------------
python lstm_object_detection/export_tflite_lstd_graph.py \
--pipeline_config_path path/to/lstm_pipeline.config \
--trained_checkpoint_prefix path/to/model.ckpt \
--output_directory path/to/exported_model_directory
The expected output would be in the directory
path/to/exported_model_directory (which is created if it does not exist)
with contents:
- tflite_graph.pbtxt
- tflite_graph.pb
Config overrides (see the `config_override` flag) are text protobufs
(also of type pipeline_pb2.TrainEvalPipelineConfig) which are used to override
certain fields in the provided pipeline_config_path. These are useful for
making small changes to the inference graph that differ from the training or
eval config.
Example Usage (in which we change the NMS iou_threshold to be 0.5 and
NMS score_threshold to be 0.0):
python lstm_object_detection/export_tflite_lstd_graph.py \
--pipeline_config_path path/to/lstm_pipeline.config \
--trained_checkpoint_prefix path/to/model.ckpt \
--output_directory path/to/exported_model_directory \
--config_override " \
model{ \
ssd{ \
post_processing { \
batch_non_max_suppression { \
score_threshold: 0.0 \
iou_threshold: 0.5 \
} \
} \
} \
} \
"
"""
import tensorflow as tf
from lstm_object_detection.utils import config_util
from lstm_object_detection import export_tflite_lstd_graph_lib
flags = tf.app.flags
flags.DEFINE_string('output_directory', None, 'Path to write outputs.')
flags.DEFINE_string(
'pipeline_config_path', None,
'Path to a pipeline_pb2.TrainEvalPipelineConfig config '
'file.')
flags.DEFINE_string('trained_checkpoint_prefix', None, 'Checkpoint prefix.')
flags.DEFINE_integer('max_detections', 10,
'Maximum number of detections (boxes) to show.')
flags.DEFINE_integer('max_classes_per_detection', 1,
'Maximum number of classes to output per detection box.')
flags.DEFINE_integer(
'detections_per_class', 100,
'Number of anchors used per class in Regular Non-Max-Suppression.')
flags.DEFINE_bool('add_postprocessing_op', True,
'Add TFLite custom op for postprocessing to the graph.')
flags.DEFINE_bool(
'use_regular_nms', False,
'Flag to set postprocessing op to use Regular NMS instead of Fast NMS.')
flags.DEFINE_string(
'config_override', '', 'pipeline_pb2.TrainEvalPipelineConfig '
'text proto to override pipeline_config_path.')
FLAGS = flags.FLAGS
def main(argv):
del argv # Unused.
flags.mark_flag_as_required('output_directory')
flags.mark_flag_as_required('pipeline_config_path')
flags.mark_flag_as_required('trained_checkpoint_prefix')
pipeline_config = config_util.get_configs_from_pipeline_file(
FLAGS.pipeline_config_path)
export_tflite_lstd_graph_lib.export_tflite_graph(
pipeline_config, FLAGS.trained_checkpoint_prefix, FLAGS.output_directory,
FLAGS.add_postprocessing_op, FLAGS.max_detections,
FLAGS.max_classes_per_detection, use_regular_nms=FLAGS.use_regular_nms)
if __name__ == '__main__':
tf.app.run(main)
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Exports detection models to use with tf-lite.
See export_tflite_lstd_graph.py for usage.
"""
import os
import tempfile
import numpy as np
import tensorflow as tf
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.tools.graph_transforms import TransformGraph
from object_detection import exporter
from object_detection.builders import graph_rewriter_builder
from object_detection.builders import post_processing_builder
from object_detection.core import box_list
from lstm_object_detection import model_builder
_DEFAULT_NUM_CHANNELS = 3
_DEFAULT_NUM_COORD_BOX = 4
def get_const_center_size_encoded_anchors(anchors):
"""Exports center-size encoded anchors as a constant tensor.
Args:
anchors: a float32 tensor of shape [num_anchors, 4] containing the anchor
boxes
Returns:
encoded_anchors: a float32 constant tensor of shape [num_anchors, 4]
containing the anchor boxes.
"""
anchor_boxlist = box_list.BoxList(anchors)
y, x, h, w = anchor_boxlist.get_center_coordinates_and_sizes()
num_anchors = y.get_shape().as_list()
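# Evaluate the anchor tensors in a throwaway session so their concrete values
# can be re-embedded in the graph as a single constant node below.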
with tf.Session() as sess:
y_out, x_out, h_out, w_out = sess.run([y, x, h, w])
encoded_anchors = tf.constant(
np.transpose(np.stack((y_out, x_out, h_out, w_out))),
dtype=tf.float32,
shape=[num_anchors[0], _DEFAULT_NUM_COORD_BOX],
name='anchors')
return encoded_anchors
def append_postprocessing_op(frozen_graph_def,
max_detections,
max_classes_per_detection,
nms_score_threshold,
nms_iou_threshold,
num_classes,
scale_values,
detections_per_class=100,
use_regular_nms=False):
"""Appends postprocessing custom op.
Args:
frozen_graph_def: Frozen GraphDef for SSD model after freezing the
checkpoint
max_detections: Maximum number of detections (boxes) to show
max_classes_per_detection: Number of classes to display per detection
nms_score_threshold: Score threshold used in Non-maximal suppression in
post-processing
nms_iou_threshold: Intersection-over-union threshold used in Non-maximal
suppression in post-processing
num_classes: number of classes in SSD detector
scale_values: a dict with the key-value pairs {y_scale: 10, x_scale: 10,
h_scale: 5, w_scale: 5} that are used to decode center-size boxes
detections_per_class: In regular NonMaxSuppression, number of anchors used
for NonMaxSuppression per class
use_regular_nms: Flag to set postprocessing op to use Regular NMS instead
of Fast NMS.
Returns:
transformed_graph_def: Frozen GraphDef with postprocessing custom op
appended
TFLite_Detection_PostProcess custom op node has four outputs:
detection_boxes: a float32 tensor of shape [1, num_boxes, 4] with box
locations
detection_classes: a float32 tensor of shape [1, num_boxes]
with class indices
detection_scores: a float32 tensor of shape [1, num_boxes]
with class scores
num_boxes: a float32 tensor of size 1 containing the number of detected
boxes
"""
new_output = frozen_graph_def.node.add()
new_output.op = 'TFLite_Detection_PostProcess'
new_output.name = 'TFLite_Detection_PostProcess'
new_output.attr['_output_quantized'].CopyFrom(
attr_value_pb2.AttrValue(b=True))
new_output.attr['_output_types'].list.type.extend([
types_pb2.DT_FLOAT, types_pb2.DT_FLOAT, types_pb2.DT_FLOAT,
types_pb2.DT_FLOAT
])
new_output.attr['_support_output_type_float_in_quantized_op'].CopyFrom(
attr_value_pb2.AttrValue(b=True))
new_output.attr['max_detections'].CopyFrom(
attr_value_pb2.AttrValue(i=max_detections))
new_output.attr['max_classes_per_detection'].CopyFrom(
attr_value_pb2.AttrValue(i=max_classes_per_detection))
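# The NMS thresholds (and the scale_values entries) arrive as single-element
# sets (see export_tflite_graph below); pop() extracts the bare float.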
new_output.attr['nms_score_threshold'].CopyFrom(
attr_value_pb2.AttrValue(f=nms_score_threshold.pop()))
new_output.attr['nms_iou_threshold'].CopyFrom(
attr_value_pb2.AttrValue(f=nms_iou_threshold.pop()))
new_output.attr['num_classes'].CopyFrom(
attr_value_pb2.AttrValue(i=num_classes))
new_output.attr['y_scale'].CopyFrom(
attr_value_pb2.AttrValue(f=scale_values['y_scale'].pop()))
new_output.attr['x_scale'].CopyFrom(
attr_value_pb2.AttrValue(f=scale_values['x_scale'].pop()))
new_output.attr['h_scale'].CopyFrom(
attr_value_pb2.AttrValue(f=scale_values['h_scale'].pop()))
new_output.attr['w_scale'].CopyFrom(
attr_value_pb2.AttrValue(f=scale_values['w_scale'].pop()))
new_output.attr['detections_per_class'].CopyFrom(
attr_value_pb2.AttrValue(i=detections_per_class))
new_output.attr['use_regular_nms'].CopyFrom(
attr_value_pb2.AttrValue(b=use_regular_nms))
new_output.input.extend(
['raw_outputs/box_encodings', 'raw_outputs/class_predictions', 'anchors'])
# Transform the graph to append new postprocessing op
input_names = []
output_names = ['TFLite_Detection_PostProcess']
transforms = ['strip_unused_nodes']
transformed_graph_def = TransformGraph(frozen_graph_def, input_names,
output_names, transforms)
return transformed_graph_def
def export_tflite_graph(pipeline_config,
trained_checkpoint_prefix,
output_dir,
add_postprocessing_op,
max_detections,
max_classes_per_detection,
detections_per_class=100,
use_regular_nms=False,
binary_graph_name='tflite_graph.pb',
txt_graph_name='tflite_graph.pbtxt'):
Exports a tflite compatible graph and anchors for an ssd detection model.
Anchors are embedded in the graph as a constant tensor, and the tflite
compatible graph is written to output_dir/tflite_graph.pb.
Args:
pipeline_config: Dictionary of configuration objects. Keys are `model`, `train_config`,
`train_input_config`, `eval_config`, `eval_input_config`, `lstm_model`.
Values are the corresponding config objects.
trained_checkpoint_prefix: a file prefix for the checkpoint containing the
trained parameters of the SSD model.
output_dir: A directory to write the tflite graph and anchor file to.
add_postprocessing_op: If true, appends the TFLite_Detection_PostProcess
custom op to the frozen graph.
max_detections: Maximum number of detections (boxes) to show
max_classes_per_detection: Number of classes to display per detection
detections_per_class: In regular NonMaxSuppression, number of anchors used
for NonMaxSuppression per class
use_regular_nms: Flag to set postprocessing op to use Regular NMS instead
of Fast NMS.
binary_graph_name: Name of the exported graph file in binary format.
txt_graph_name: Name of the exported graph file in text format.
Raises:
ValueError: if the pipeline config contains a model other than ssd, or uses
an image resizer other than fixed_shape_resizer.
"""
model_config = pipeline_config['model']
lstm_config = pipeline_config['lstm_model']
eval_config = pipeline_config['eval_config']
tf.gfile.MakeDirs(output_dir)
if model_config.WhichOneof('model') != 'ssd':
raise ValueError('Only ssd models are supported in tflite. '
'Found {} in config'.format(
model_config.WhichOneof('model')))
num_classes = model_config.ssd.num_classes
nms_score_threshold = {
model_config.ssd.post_processing.batch_non_max_suppression.
score_threshold
}
nms_iou_threshold = {
model_config.ssd.post_processing.batch_non_max_suppression.
iou_threshold
}
scale_values = {}
scale_values['y_scale'] = {
model_config.ssd.box_coder.faster_rcnn_box_coder.y_scale
}
scale_values['x_scale'] = {
model_config.ssd.box_coder.faster_rcnn_box_coder.x_scale
}
scale_values['h_scale'] = {
model_config.ssd.box_coder.faster_rcnn_box_coder.height_scale
}
scale_values['w_scale'] = {
model_config.ssd.box_coder.faster_rcnn_box_coder.width_scale
}
image_resizer_config = model_config.ssd.image_resizer
image_resizer = image_resizer_config.WhichOneof('image_resizer_oneof')
num_channels = _DEFAULT_NUM_CHANNELS
if image_resizer == 'fixed_shape_resizer':
height = image_resizer_config.fixed_shape_resizer.height
width = image_resizer_config.fixed_shape_resizer.width
if image_resizer_config.fixed_shape_resizer.convert_to_grayscale:
num_channels = 1
# TODO(richardbrks): figure out how to export with a None (undefined) batch size.
shape = [lstm_config.eval_unroll_length, height, width, num_channels]
else:
raise ValueError(
    'Only fixed_shape_resizer is supported with tflite. '
    'Found {}'.format(
        image_resizer_config.WhichOneof('image_resizer_oneof')))
video_tensor = tf.placeholder(
tf.float32, shape=shape, name='input_video_tensor')
detection_model = model_builder.build(model_config, lstm_config,
is_training=False)
preprocessed_video, true_image_shapes = detection_model.preprocess(
tf.to_float(video_tensor))
predicted_tensors = detection_model.predict(preprocessed_video,
true_image_shapes)
# predicted_tensors = detection_model.postprocess(predicted_tensors,
# true_image_shapes)
# The score conversion occurs before the post-processing custom op
_, score_conversion_fn = post_processing_builder.build(
model_config.ssd.post_processing)
class_predictions = score_conversion_fn(
predicted_tensors['class_predictions_with_background'])
with tf.name_scope('raw_outputs'):
# 'raw_outputs/box_encodings': a float32 tensor of shape [1, num_anchors, 4]
# containing the encoded box predictions. Note that these are raw
# predictions and no Non-Max suppression is applied on them and
# no decode center size boxes is applied to them.
tf.identity(predicted_tensors['box_encodings'], name='box_encodings')
# 'raw_outputs/class_predictions': a float32 tensor of shape
# [1, num_anchors, num_classes] containing the class scores for each anchor
# after applying score conversion.
tf.identity(class_predictions, name='class_predictions')
# 'anchors': a float32 tensor of shape
# [num_anchors, 4] containing the anchors as a constant node.
tf.identity(
get_const_center_size_encoded_anchors(predicted_tensors['anchors']),
name='anchors')
# Add global step to the graph, so we know the training step number when we
# evaluate the model.
tf.train.get_or_create_global_step()
# graph rewriter
is_quantized = ('graph_rewriter' in pipeline_config)
if is_quantized:
graph_rewriter_config = pipeline_config['graph_rewriter']
graph_rewriter_fn = graph_rewriter_builder.build(
graph_rewriter_config, is_training=False, is_export=True)
graph_rewriter_fn()
if model_config.ssd.feature_extractor.HasField('fpn'):
exporter.rewrite_nn_resize_op(is_quantized)
# freeze the graph
saver_kwargs = {}
if eval_config.use_moving_averages:
saver_kwargs['write_version'] = saver_pb2.SaverDef.V1
moving_average_checkpoint = tempfile.NamedTemporaryFile()
exporter.replace_variable_values_with_moving_averages(
tf.get_default_graph(), trained_checkpoint_prefix,
moving_average_checkpoint.name)
checkpoint_to_use = moving_average_checkpoint.name
else:
checkpoint_to_use = trained_checkpoint_prefix
saver = tf.train.Saver(**saver_kwargs)
input_saver_def = saver.as_saver_def()
frozen_graph_def = exporter.freeze_graph_with_def_protos(
input_graph_def=tf.get_default_graph().as_graph_def(),
input_saver_def=input_saver_def,
input_checkpoint=checkpoint_to_use,
output_node_names=','.join([
'raw_outputs/box_encodings', 'raw_outputs/class_predictions',
'anchors'
]),
restore_op_name='save/restore_all',
filename_tensor_name='save/Const:0',
clear_devices=True,
output_graph='',
initializer_nodes='')
# Add new operation to do post processing in a custom op (TF Lite only)
# TODO(richardbrks): should this use detection_model.postprocess instead?
if add_postprocessing_op:
transformed_graph_def = append_postprocessing_op(
frozen_graph_def, max_detections, max_classes_per_detection,
nms_score_threshold, nms_iou_threshold, num_classes, scale_values,
detections_per_class, use_regular_nms)
else:
# Return frozen without adding post-processing custom op
transformed_graph_def = frozen_graph_def
binary_graph = os.path.join(output_dir, binary_graph_name)
with tf.gfile.GFile(binary_graph, 'wb') as f:
f.write(transformed_graph_def.SerializeToString())
txt_graph = os.path.join(output_dir, txt_graph_name)
with tf.gfile.GFile(txt_graph, 'w') as f:
f.write(str(transformed_graph_def))
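For reference, the library entry point above can also be called directly from Python instead of through the CLI wrapper; a minimal sketch mirroring the CLI's flag defaults (all paths are placeholders):

```python
# Minimal sketch: drive the exporter library directly (paths are placeholders).
from lstm_object_detection import export_tflite_lstd_graph_lib
from lstm_object_detection.utils import config_util

pipeline_config = config_util.get_configs_from_pipeline_file(
    'path/to/lstm_pipeline.config')
export_tflite_lstd_graph_lib.export_tflite_graph(
    pipeline_config,
    trained_checkpoint_prefix='path/to/model.ckpt',
    output_dir='path/to/exported_model_directory',
    add_postprocessing_op=True,  # append TFLite_Detection_PostProcess
    max_detections=10,
    max_classes_per_detection=1)
```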
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
from absl import flags
import tensorflow as tf
from lstm_object_detection.utils import config_util
flags.DEFINE_string('export_path', None, 'Path to export model.')
flags.DEFINE_string('frozen_graph_path', None, 'Path to frozen graph.')
flags.DEFINE_string(
'pipeline_config_path', '',
'Path to a pipeline_pb2.TrainEvalPipelineConfig config file.')
FLAGS = flags.FLAGS
def main(_):
flags.mark_flag_as_required('export_path')
flags.mark_flag_as_required('frozen_graph_path')
flags.mark_flag_as_required('pipeline_config_path')
configs = config_util.get_configs_from_pipeline_file(
FLAGS.pipeline_config_path)
lstm_config = configs['lstm_model']
input_arrays = ['input_video_tensor']
output_arrays = [
'TFLite_Detection_PostProcess',
'TFLite_Detection_PostProcess:1',
'TFLite_Detection_PostProcess:2',
'TFLite_Detection_PostProcess:3',
]
input_shapes = {
'input_video_tensor': [lstm_config.eval_unroll_length, 320, 320, 3],
}
converter = tf.lite.TFLiteConverter.from_frozen_graph(
FLAGS.frozen_graph_path, input_arrays, output_arrays,
input_shapes=input_shapes
)
converter.allow_custom_ops = True
tflite_model = converter.convert()
ofilename = os.path.join(FLAGS.export_path)
with open(ofilename, 'wb') as f:
  f.write(tflite_model)
if __name__ == '__main__':
tf.app.run()
# Exporting a tflite model from a checkpoint
Starting from a trained model checkpoint, creating a tflite model requires two steps:
* exporting a tflite frozen graph from a checkpoint
* exporting a tflite model from a frozen graph
## Exporting a tflite frozen graph from a checkpoint
With a candidate checkpoint to export, run the following command from
tensorflow/models/research:
```bash
# from tensorflow/models/research
PIPELINE_CONFIG_PATH={path to pipeline config}
TRAINED_CKPT_PREFIX={path to model.ckpt}
EXPORT_DIR={path to folder that will be used for export}
python lstm_object_detection/export_tflite_lstd_graph.py \
--pipeline_config_path ${PIPELINE_CONFIG_PATH} \
--trained_checkpoint_prefix ${TRAINED_CKPT_PREFIX} \
--output_directory ${EXPORT_DIR} \
--add_postprocessing_op
```
After export, you should see the directory ${EXPORT_DIR} containing the following files:
* `tflite_graph.pb`
* `tflite_graph.pbtxt`
## Exporting a tflite model from a frozen graph
We then take the exported TFLite-compatible frozen graph and convert it to a
TFLite FlatBuffer file by running the following:
```bash
# from tensorflow/models/research
FROZEN_GRAPH_PATH={path to exported tflite_graph.pb}
EXPORT_PATH={path to filename that will be used for export}
PIPELINE_CONFIG_PATH={path to pipeline config}
python lstm_object_detection/export_tflite_lstd_model.py \
--export_path ${EXPORT_PATH} \
--frozen_graph_path ${FROZEN_GRAPH_PATH} \
--pipeline_config_path ${PIPELINE_CONFIG_PATH}
```
After export, you should see the file ${EXPORT_PATH} containing the FlatBuffer
model to be used by an application.
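The resulting FlatBuffer can be loaded and exercised with the TFLite Python interpreter; a condensed sketch of what `lstm_object_detection/test_tflite_model.py` (added in this change) does, with a placeholder model path:

```python
import numpy as np
import tensorflow as tf

# Load the exported FlatBuffer and run one random input through it.
interpreter = tf.lite.Interpreter(model_path='path/to/model.tflite')
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
input_data = np.random.random_sample(
    input_details[0]['shape']).astype(np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
print(interpreter.get_tensor(interpreter.get_output_details()[0]['index']))
```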
@@ -33,68 +33,6 @@ from object_detection.protos import preprocessor_pb2
class DatasetBuilderTest(tf.test.TestCase):
def _create_tf_record(self):
path = os.path.join(self.get_temp_dir(), 'tfrecord')
writer = tf.python_io.TFRecordWriter(path)
image_tensor = np.random.randint(255, size=(16, 16, 3)).astype(np.uint8)
with self.test_session():
encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).eval()
sequence_example = example_pb2.SequenceExample(
context=feature_pb2.Features(
feature={
'image/format':
feature_pb2.Feature(
bytes_list=feature_pb2.BytesList(
value=['jpeg'.encode('utf-8')])),
'image/height':
feature_pb2.Feature(
int64_list=feature_pb2.Int64List(value=[16])),
'image/width':
feature_pb2.Feature(
int64_list=feature_pb2.Int64List(value=[16])),
}),
feature_lists=feature_pb2.FeatureLists(
feature_list={
'image/encoded':
feature_pb2.FeatureList(feature=[
feature_pb2.Feature(
bytes_list=feature_pb2.BytesList(
value=[encoded_jpeg])),
]),
'image/object/bbox/xmin':
feature_pb2.FeatureList(feature=[
feature_pb2.Feature(
float_list=feature_pb2.FloatList(value=[0.0])),
]),
'image/object/bbox/xmax':
feature_pb2.FeatureList(feature=[
feature_pb2.Feature(
float_list=feature_pb2.FloatList(value=[1.0]))
]),
'image/object/bbox/ymin':
feature_pb2.FeatureList(feature=[
feature_pb2.Feature(
float_list=feature_pb2.FloatList(value=[0.0])),
]),
'image/object/bbox/ymax':
feature_pb2.FeatureList(feature=[
feature_pb2.Feature(
float_list=feature_pb2.FloatList(value=[1.0]))
]),
'image/object/class/label':
feature_pb2.FeatureList(feature=[
feature_pb2.Feature(
int64_list=feature_pb2.Int64List(value=[2]))
]),
}))
writer.write(sequence_example.SerializeToString())
writer.close()
return path
def _get_model_configs_from_proto(self):
"""Creates a model text proto for testing.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from absl import flags
import numpy as np
import tensorflow as tf
flags.DEFINE_string('model_path', None, 'Path to model.')
FLAGS = flags.FLAGS
def main(_):
flags.mark_flag_as_required('model_path')
# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=FLAGS.model_path)
interpreter.allocate_tensors()
# Get input and output tensors.
input_details = interpreter.get_input_details()
print('input_details:', input_details)
output_details = interpreter.get_output_details()
print('output_details:', output_details)
# Test model on random input data.
input_shape = input_details[0]['shape']
# change the following line to feed into your own data.
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)
if __name__ == '__main__':
tf.app.run()
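Once a FlatBuffer has been exported, the script above can be invoked as `python lstm_object_detection/test_tflite_model.py --model_path path/to/model.tflite` (the path is a placeholder).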
@@ -37,7 +37,7 @@ def get_configs_from_pipeline_file(pipeline_config_path):
   Returns:
     Dictionary of configuration objects. Keys are `model`, `train_config`,
-    `train_input_config`, `eval_config`, `eval_input_config`, `lstm_confg`.
+    `train_input_config`, `eval_config`, `eval_input_config`, `lstm_model`.
     Value are the corresponding config objects.
   """
   pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()