Commit 886d7bc9 authored by Yongzhe Wang's avatar Yongzhe Wang Committed by Hongkun Yu
Browse files

Merged commit includes the following changes: (#7470)

* Merged commit includes the following changes:
263863588  by yongzhe:

    Fix a bug that the SetExternalContext for EdgeTPU wasn't called when initializing LSTD client.

--
263370193  by yongzhe:

    Internal change.

--

PiperOrigin-RevId: 263863588

* Revert changes in seq_dataset_builder_test.py
parent ee584397
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Exports an LSTM detection model to use with tf-lite.
Outputs file:
......@@ -86,8 +85,9 @@ python lstm_object_detection/export_tflite_lstd_graph.py \
"""
import tensorflow as tf
from lstm_object_detection.utils import config_util
from lstm_object_detection import export_tflite_lstd_graph_lib
from lstm_object_detection.utils import config_util
flags = tf.app.flags
flags.DEFINE_string('output_directory', None, 'Path to write outputs.')
......@@ -125,9 +125,13 @@ def main(argv):
FLAGS.pipeline_config_path)
export_tflite_lstd_graph_lib.export_tflite_graph(
pipeline_config, FLAGS.trained_checkpoint_prefix, FLAGS.output_directory,
FLAGS.add_postprocessing_op, FLAGS.max_detections,
FLAGS.max_classes_per_detection, use_regular_nms=FLAGS.use_regular_nms)
pipeline_config,
FLAGS.trained_checkpoint_prefix,
FLAGS.output_directory,
FLAGS.add_postprocessing_op,
FLAGS.max_detections,
FLAGS.max_classes_per_detection,
use_regular_nms=FLAGS.use_regular_nms)
if __name__ == '__main__':
......
......@@ -12,26 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Exports detection models to use with tf-lite.
See export_tflite_lstd_graph.py for usage.
"""
import os
import tempfile
import numpy as np
import tensorflow as tf
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.tools.graph_transforms import TransformGraph
from lstm_object_detection import model_builder
from object_detection import exporter
from object_detection.builders import graph_rewriter_builder
from object_detection.builders import post_processing_builder
from object_detection.core import box_list
from lstm_object_detection import model_builder
_DEFAULT_NUM_CHANNELS = 3
_DEFAULT_NUM_COORD_BOX = 4
......@@ -87,8 +87,8 @@ def append_postprocessing_op(frozen_graph_def,
centersize boxes
detections_per_class: In regular NonMaxSuppression, number of anchors used
for NonMaxSuppression per class
use_regular_nms: Flag to set postprocessing op to use Regular NMS instead
of Fast NMS.
use_regular_nms: Flag to set postprocessing op to use Regular NMS instead of
Fast NMS.
Returns:
transformed_graph_def: Frozen GraphDef with postprocessing custom op
......@@ -165,9 +165,9 @@ def export_tflite_graph(pipeline_config,
is written to output_dir/tflite_graph.pb.
Args:
pipeline_config: Dictionary of configuration objects. Keys are `model`, `train_config`,
`train_input_config`, `eval_config`, `eval_input_config`, `lstm_model`.
Value are the corresponding config objects.
pipeline_config: Dictionary of configuration objects. Keys are `model`,
`train_config`, `train_input_config`, `eval_config`, `eval_input_config`,
`lstm_model`. Value are the corresponding config objects.
trained_checkpoint_prefix: a file prefix for the checkpoint containing the
trained parameters of the SSD model.
output_dir: A directory to write the tflite graph and anchor file to.
......@@ -177,8 +177,8 @@ def export_tflite_graph(pipeline_config,
max_classes_per_detection: Number of classes to display per detection
detections_per_class: In regular NonMaxSuppression, number of anchors used
for NonMaxSuppression per class
use_regular_nms: Flag to set postprocessing op to use Regular NMS instead
of Fast NMS.
use_regular_nms: Flag to set postprocessing op to use Regular NMS instead of
Fast NMS.
binary_graph_name: Name of the exported graph file in binary format.
txt_graph_name: Name of the exported graph file in text format.
......@@ -197,12 +197,10 @@ def export_tflite_graph(pipeline_config,
num_classes = model_config.ssd.num_classes
nms_score_threshold = {
model_config.ssd.post_processing.batch_non_max_suppression.
score_threshold
model_config.ssd.post_processing.batch_non_max_suppression.score_threshold
}
nms_iou_threshold = {
model_config.ssd.post_processing.batch_non_max_suppression.
iou_threshold
model_config.ssd.post_processing.batch_non_max_suppression.iou_threshold
}
scale_values = {}
scale_values['y_scale'] = {
......@@ -226,7 +224,7 @@ def export_tflite_graph(pipeline_config,
width = image_resizer_config.fixed_shape_resizer.width
if image_resizer_config.fixed_shape_resizer.convert_to_grayscale:
num_channels = 1
#TODO(richardbrks) figure out how to make with a None defined batch size
shape = [lstm_config.eval_unroll_length, height, width, num_channels]
else:
raise ValueError(
......@@ -237,8 +235,8 @@ def export_tflite_graph(pipeline_config,
video_tensor = tf.placeholder(
tf.float32, shape=shape, name='input_video_tensor')
detection_model = model_builder.build(model_config, lstm_config,
is_training=False)
detection_model = model_builder.build(
model_config, lstm_config, is_training=False)
preprocessed_video, true_image_shapes = detection_model.preprocess(
tf.to_float(video_tensor))
predicted_tensors = detection_model.predict(preprocessed_video,
......@@ -311,7 +309,7 @@ def export_tflite_graph(pipeline_config,
initializer_nodes='')
# Add new operation to do post processing in a custom op (TF Lite only)
#(richardbrks) Do use this or detection_model.postprocess?
if add_postprocessing_op:
transformed_graph_def = append_postprocessing_op(
frozen_graph_def, max_detections, max_classes_per_detection,
......
......@@ -13,6 +13,8 @@
# limitations under the License.
# ==============================================================================
"""Export a LSTD model in tflite format."""
import os
from absl import flags
import tensorflow as tf
......@@ -49,13 +51,14 @@ def main(_):
}
converter = tf.lite.TFLiteConverter.from_frozen_graph(
FLAGS.frozen_graph_path, input_arrays, output_arrays,
input_shapes=input_shapes
)
FLAGS.frozen_graph_path,
input_arrays,
output_arrays,
input_shapes=input_shapes)
converter.allow_custom_ops = True
tflite_model = converter.convert()
ofilename = os.path.join(FLAGS.export_path)
open(ofilename, "wb").write(tflite_model)
open(ofilename, 'wb').write(tflite_model)
if __name__ == '__main__':
......
# Exporting a tflite model from a checkpoint
Starting from a trained model checkpoint, creating a tflite model requires 2 steps:
Starting from a trained model checkpoint, creating a tflite model requires 2
steps:
* exporting a tflite frozen graph from a checkpoint
* exporting a tflite model from a frozen graph
## Exporting a tflite frozen graph from a checkpoint
With a candidate checkpoint to export, run the following command from
......@@ -23,12 +23,12 @@ python lstm_object_detection/export_tflite_lstd_graph.py \
--add_preprocessing_op
```
After export, you should see the directory ${EXPORT_DIR} containing the following files:
After export, you should see the directory ${EXPORT_DIR} containing the
following files:
* `tflite_graph.pb`
* `tflite_graph.pbtxt`
## Exporting a tflite model from a frozen graph
We then take the exported tflite-compatable tflite model, and convert it to a
......
......@@ -13,6 +13,9 @@
# limitations under the License.
# ==============================================================================
"""Test a tflite model using random input data."""
from __future__ import print_function
from absl import flags
import numpy as np
import tensorflow as tf
......@@ -31,9 +34,9 @@ def main(_):
# Get input and output tensors.
input_details = interpreter.get_input_details()
print 'input_details:', input_details
print('input_details:', input_details)
output_details = interpreter.get_output_details()
print 'output_details:', output_details
print('output_details:', output_details)
# Test model on random input data.
input_shape = input_details[0]['shape']
......@@ -43,7 +46,7 @@ def main(_):
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
print output_data
print(output_data)
if __name__ == '__main__':
......
......@@ -59,12 +59,19 @@ cc_library(
name = "mobile_lstd_tflite_client",
srcs = ["mobile_lstd_tflite_client.cc"],
hdrs = ["mobile_lstd_tflite_client.h"],
defines = select({
"//conditions:default": [],
"enable_edgetpu": ["ENABLE_EDGETPU"],
}),
deps = [
":mobile_ssd_client",
":mobile_ssd_tflite_client",
"@com_google_absl//absl/base:core_headers",
"@com_google_glog//:glog",
"@com_google_absl//absl/base:core_headers",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
],
] + select({
"//conditions:default": [],
"enable_edgetpu": ["@libedgetpu//libedgetpu:header"],
}),
alwayslink = 1,
)
......@@ -66,6 +66,11 @@ bool MobileLSTDTfLiteClient::InitializeInterpreter(
interpreter_->UseNNAPI(false);
}
#ifdef ENABLE_EDGETPU
interpreter_->SetExternalContext(kTfLiteEdgeTpuContext,
edge_tpu_context_.get());
#endif
// Inputs are: normalized_input_image_tensor, raw_inputs/init_lstm_c,
// raw_inputs/init_lstm_h
if (interpreter_->inputs().size() != 3) {
......
......@@ -76,6 +76,10 @@ class MobileSSDTfLiteClient : public MobileSSDClient {
std::unique_ptr<::tflite::MutableOpResolver> resolver_;
std::unique_ptr<::tflite::Interpreter> interpreter_;
#ifdef ENABLE_EDGETPU
std::unique_ptr<edgetpu::EdgeTpuContext> edge_tpu_context_;
#endif
private:
// MobileSSDTfLiteClient is neither copyable nor movable.
MobileSSDTfLiteClient(const MobileSSDTfLiteClient&) = delete;
......@@ -103,10 +107,6 @@ class MobileSSDTfLiteClient : public MobileSSDClient {
bool FloatInference(const uint8_t* input_data);
bool QuantizedInference(const uint8_t* input_data);
void GetOutputBoxesAndScoreTensorsFromUInt8();
#ifdef ENABLE_EDGETPU
std::unique_ptr<edgetpu::EdgeTpuContext> edge_tpu_context_;
#endif
};
} // namespace tflite
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment