Commit 3858c82b authored by A. Unique TensorFlower

Merge pull request #7471 from tfboyd:higher_threshold

PiperOrigin-RevId: 264394112
parents a53371eb bee1bbc2
@@ -117,7 +117,7 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
     """
     self._run_and_report_benchmark(hr_at_10_min=0.61)

-  def _run_and_report_benchmark(self, hr_at_10_min=0.630, hr_at_10_max=0.640):
+  def _run_and_report_benchmark(self, hr_at_10_min=0.630, hr_at_10_max=0.645):
     """Run test and report results.

     Note: Target is 0.635, but some runs are below that level. Until we have
...
@@ -41,8 +41,8 @@ class ShakespeareBenchmarkBase(PerfZeroBenchmark):
         flag_methods=[shakespeare_main.define_flags])

   def _run_and_report_benchmark(self,
-                                top_1_train_min=0.923,
-                                top_1_train_max=0.93,
+                                top_1_train_min=0.91,
+                                top_1_train_max=0.94,
                                 warmup=1,
                                 log_steps=100):
     """Report benchmark results by writing to local protobuf file.
...
@@ -280,8 +280,8 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
     FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
     self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
                                    log_steps=FLAGS.log_steps,
-                                   bleu_min=28,
-                                   bleu_max=29)
+                                   bleu_min=27.9,
+                                   bleu_max=29.2)

   def benchmark_8_gpu_static_batch(self):
     """Benchmark 8 gpu.
@@ -305,12 +305,19 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
     self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
                                    log_steps=FLAGS.log_steps,
                                    bleu_min=28,
-                                   bleu_max=29)
+                                   bleu_max=29.2)

   def benchmark_8_gpu_fp16(self):
     """Benchmark 8 gpu with dynamic batch and fp16.

-    Should converge to 28.4 BLEU (uncased). This has not be verified yet."
+    Over 6 runs with eval every 20K steps the average highest value was 28.247
+    (bleu uncased). 28.424 was the highest and 28.09 the lowest. The values are
+    the highest value seen during a run and occurred at a median of iteration
+    11. While this could be interpreted as worse than FP32, if looking at the
+    first iteration at which 28 is passed FP16 performs equal and possibly
+    better. Although not part of the initial test runs, the highest value
+    recorded with the arguments below was 28.9 at iteration 12. Iterations are
+    not epochs, an iteration is a number of steps between evals.
     """
     self._setup()
     FLAGS.num_gpus = 8
@@ -328,7 +335,7 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
     self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
                                    log_steps=FLAGS.log_steps,
                                    bleu_min=28,
-                                   bleu_max=29)
+                                   bleu_max=29.2)

   def benchmark_8_gpu_static_batch_fp16(self):
     """Benchmark 8 gpu with static batch and fp16.
@@ -353,7 +360,7 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
     self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
                                    log_steps=FLAGS.log_steps,
                                    bleu_min=28,
-                                   bleu_max=29)
+                                   bleu_max=29.2)

   def benchmark_xla_8_gpu_static_batch_fp16(self):
     """Benchmark 8 gpu with static batch, XLA, and FP16.
@@ -380,7 +387,7 @@ class TransformerBigKerasAccuracy(TransformerBenchmark):
     self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
                                    log_steps=FLAGS.log_steps,
                                    bleu_min=28,
-                                   bleu_max=29)
+                                   bleu_max=29.2)


 class TransformerKerasBenchmark(TransformerBenchmark):
...
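A note for readers skimming the hunks above: the `*_min`/`*_max` keyword arguments (`hr_at_10_min`/`hr_at_10_max`, `top_1_train_min`/`top_1_train_max`, `bleu_min`/`bleu_max`) define the acceptance window for the converged metric, and this commit only adjusts those windows (mostly widening them). The snippet below is a minimal sketch of such a gate, not the repository's `_run_and_report_benchmark` implementation; the helper name and the report dictionary layout are assumptions made for illustration.

```
# Illustrative sketch only: a min/max convergence gate of the kind the
# benchmark classes above parameterize. Not the actual reporting code.
def check_convergence(metric_name, value, min_value, max_value):
  """Returns a report entry stating whether value landed inside the window."""
  return {
      'metric': metric_name,
      'value': value,
      'min_threshold': min_value,
      'max_threshold': max_value,
      'succeeded': min_value <= value <= max_value,
  }


# Example: widening bleu_max from 29 to 29.2 lets a 29.1 run count as passing.
print(check_convergence('bleu_uncased', 29.1, 28, 29.2))
```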
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+
 r"""Exports an LSTM detection model to use with tf-lite.

 Outputs file:
@@ -85,9 +86,8 @@ python lstm_object_detection/export_tflite_lstd_graph.py \
 """

 import tensorflow as tf
-
+from lstm_object_detection import export_tflite_lstd_graph_lib
 from lstm_object_detection.utils import config_util
-from lstm_object_detection import export_tflite_lstd_graph_lib

 flags = tf.app.flags
 flags.DEFINE_string('output_directory', None, 'Path to write outputs.')
@@ -122,16 +122,12 @@ def main(argv):
   flags.mark_flag_as_required('trained_checkpoint_prefix')

   pipeline_config = config_util.get_configs_from_pipeline_file(
       FLAGS.pipeline_config_path)

   export_tflite_lstd_graph_lib.export_tflite_graph(
-      pipeline_config,
-      FLAGS.trained_checkpoint_prefix,
-      FLAGS.output_directory,
-      FLAGS.add_postprocessing_op,
-      FLAGS.max_detections,
-      FLAGS.max_classes_per_detection,
-      use_regular_nms=FLAGS.use_regular_nms)
+      pipeline_config, FLAGS.trained_checkpoint_prefix, FLAGS.output_directory,
+      FLAGS.add_postprocessing_op, FLAGS.max_detections,
+      FLAGS.max_classes_per_detection, use_regular_nms=FLAGS.use_regular_nms)


 if __name__ == '__main__':
...
@@ -12,26 +12,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 r"""Exports detection models to use with tf-lite.

 See export_tflite_lstd_graph.py for usage.
 """
 import os
 import tempfile

 import numpy as np
 import tensorflow as tf

 from tensorflow.core.framework import attr_value_pb2
 from tensorflow.core.framework import types_pb2
 from tensorflow.core.protobuf import saver_pb2
 from tensorflow.tools.graph_transforms import TransformGraph
+from lstm_object_detection import model_builder
 from object_detection import exporter
 from object_detection.builders import graph_rewriter_builder
 from object_detection.builders import post_processing_builder
 from object_detection.core import box_list
-from lstm_object_detection import model_builder

 _DEFAULT_NUM_CHANNELS = 3
 _DEFAULT_NUM_COORD_BOX = 4
@@ -84,11 +84,11 @@ def append_postprocessing_op(frozen_graph_def,
     num_classes: number of classes in SSD detector
     scale_values: scale values is a dict with following key-value pairs
       {y_scale: 10, x_scale: 10, h_scale: 5, w_scale: 5} that are used in decode
       centersize boxes
     detections_per_class: In regular NonMaxSuppression, number of anchors used
       for NonMaxSuppression per class
-    use_regular_nms: Flag to set postprocessing op to use Regular NMS instead of
-      Fast NMS.
+    use_regular_nms: Flag to set postprocessing op to use Regular NMS instead
+      of Fast NMS.

   Returns:
     transformed_graph_def: Frozen GraphDef with postprocessing custom op
@@ -165,9 +165,9 @@ def export_tflite_graph(pipeline_config,
   is written to output_dir/tflite_graph.pb.

   Args:
-    pipeline_config: Dictionary of configuration objects. Keys are `model`,
-      `train_config`, `train_input_config`, `eval_config`, `eval_input_config`,
-      `lstm_model`. Value are the corresponding config objects.
+    pipeline_config: Dictionary of configuration objects. Keys are `model`, `train_config`,
+      `train_input_config`, `eval_config`, `eval_input_config`, `lstm_model`.
+      Value are the corresponding config objects.
     trained_checkpoint_prefix: a file prefix for the checkpoint containing the
       trained parameters of the SSD model.
     output_dir: A directory to write the tflite graph and anchor file to.
@@ -176,9 +176,9 @@ def export_tflite_graph(pipeline_config,
     max_detections: Maximum number of detections (boxes) to show
     max_classes_per_detection: Number of classes to display per detection
     detections_per_class: In regular NonMaxSuppression, number of anchors used
       for NonMaxSuppression per class
-    use_regular_nms: Flag to set postprocessing op to use Regular NMS instead of
-      Fast NMS.
+    use_regular_nms: Flag to set postprocessing op to use Regular NMS instead
+      of Fast NMS.
     binary_graph_name: Name of the exported graph file in binary format.
     txt_graph_name: Name of the exported graph file in text format.
@@ -197,10 +197,12 @@ def export_tflite_graph(pipeline_config,
   num_classes = model_config.ssd.num_classes
   nms_score_threshold = {
-      model_config.ssd.post_processing.batch_non_max_suppression.score_threshold
+      model_config.ssd.post_processing.batch_non_max_suppression.
+      score_threshold
   }
   nms_iou_threshold = {
-      model_config.ssd.post_processing.batch_non_max_suppression.iou_threshold
+      model_config.ssd.post_processing.batch_non_max_suppression.
+      iou_threshold
   }
   scale_values = {}
   scale_values['y_scale'] = {
@@ -224,7 +226,7 @@ def export_tflite_graph(pipeline_config,
     width = image_resizer_config.fixed_shape_resizer.width
     if image_resizer_config.fixed_shape_resizer.convert_to_grayscale:
       num_channels = 1
-
+    #TODO(richardbrks) figure out how to make with a None defined batch size
     shape = [lstm_config.eval_unroll_length, height, width, num_channels]
   else:
     raise ValueError(
@@ -233,14 +235,14 @@ def export_tflite_graph(pipeline_config,
             image_resizer_config.WhichOneof('image_resizer_oneof')))
   video_tensor = tf.placeholder(
       tf.float32, shape=shape, name='input_video_tensor')
-  detection_model = model_builder.build(
-      model_config, lstm_config, is_training=False)
+  detection_model = model_builder.build(model_config, lstm_config,
+                                        is_training=False)
   preprocessed_video, true_image_shapes = detection_model.preprocess(
       tf.to_float(video_tensor))
   predicted_tensors = detection_model.predict(preprocessed_video,
                                               true_image_shapes)
   # predicted_tensors = detection_model.postprocess(predicted_tensors,
   #                                                 true_image_shapes)
   # The score conversion occurs before the post-processing custom op
@@ -309,7 +311,7 @@ def export_tflite_graph(pipeline_config,
       initializer_nodes='')

   # Add new operation to do post processing in a custom op (TF Lite only)
-
+  #(richardbrks) Do use this or detection_model.postprocess?
   if add_postprocessing_op:
     transformed_graph_def = append_postprocessing_op(
         frozen_graph_def, max_detections, max_classes_per_detection,
...
@@ -13,8 +13,6 @@
 # limitations under the License.
 # ==============================================================================
-"""Export a LSTD model in tflite format."""
-
 import os

 from absl import flags
 import tensorflow as tf
@@ -31,35 +29,34 @@ FLAGS = flags.FLAGS

 def main(_):
   flags.mark_flag_as_required('export_path')
   flags.mark_flag_as_required('frozen_graph_path')
   flags.mark_flag_as_required('pipeline_config_path')

   configs = config_util.get_configs_from_pipeline_file(
       FLAGS.pipeline_config_path)
   lstm_config = configs['lstm_model']

   input_arrays = ['input_video_tensor']
   output_arrays = [
       'TFLite_Detection_PostProcess',
       'TFLite_Detection_PostProcess:1',
       'TFLite_Detection_PostProcess:2',
       'TFLite_Detection_PostProcess:3',
   ]
   input_shapes = {
       'input_video_tensor': [lstm_config.eval_unroll_length, 320, 320, 3],
   }

   converter = tf.lite.TFLiteConverter.from_frozen_graph(
-      FLAGS.frozen_graph_path,
-      input_arrays,
-      output_arrays,
-      input_shapes=input_shapes)
-  converter.allow_custom_ops = True
-  tflite_model = converter.convert()
-  ofilename = os.path.join(FLAGS.export_path)
-  open(ofilename, 'wb').write(tflite_model)
+      FLAGS.frozen_graph_path, input_arrays, output_arrays,
+      input_shapes=input_shapes
+  )
+  converter.allow_custom_ops = True
+  tflite_model = converter.convert()
+  ofilename = os.path.join(FLAGS.export_path)
+  open(ofilename, "wb").write(tflite_model)


 if __name__ == '__main__':
   tf.app.run()
...
 # Exporting a tflite model from a checkpoint

-Starting from a trained model checkpoint, creating a tflite model requires 2
-steps:
+Starting from a trained model checkpoint, creating a tflite model requires 2 steps:

 * exporting a tflite frozen graph from a checkpoint
 * exporting a tflite model from a frozen graph

 ## Exporting a tflite frozen graph from a checkpoint
@@ -20,14 +20,14 @@ python lstm_object_detection/export_tflite_lstd_graph.py \
 --pipeline_config_path ${PIPELINE_CONFIG_PATH} \
 --trained_checkpoint_prefix ${TRAINED_CKPT_PREFIX} \
 --output_directory ${EXPORT_DIR} \
 --add_preprocessing_op
 ```

-After export, you should see the directory ${EXPORT_DIR} containing the
-following files:
+After export, you should see the directory ${EXPORT_DIR} containing the following files:

 * `tflite_graph.pb`
 * `tflite_graph.pbtxt`

 ## Exporting a tflite model from a frozen graph
@@ -40,10 +40,10 @@ FROZEN_GRAPH_PATH={path to exported tflite_graph.pb}
 EXPORT_PATH={path to filename that will be used for export}
 PIPELINE_CONFIG_PATH={path to pipeline config}
 python lstm_object_detection/export_tflite_lstd_model.py \
 --export_path ${EXPORT_PATH} \
 --frozen_graph_path ${FROZEN_GRAPH_PATH} \
 --pipeline_config_path ${PIPELINE_CONFIG_PATH}
 ```

 After export, you should see the file ${EXPORT_PATH} containing the FlatBuffer
 model to be used by an application.
\ No newline at end of file
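For reference, the second step in this README corresponds to the `TFLiteConverter` call that this commit reformats in `export_tflite_lstd_model.py` above. The sketch below restates that conversion as a standalone snippet under a few assumptions: the paths and the unroll length are placeholders, the input/output array names are taken from the script, and `tf.lite.TFLiteConverter.from_frozen_graph` is the TF 1.x API it uses.

```
import tensorflow as tf  # TF 1.x API, matching the script above

# Placeholder paths; the script reads these from --frozen_graph_path/--export_path.
frozen_graph_path = '/tmp/tflite_graph.pb'
export_path = '/tmp/lstd_model.tflite'
unroll_length = 4  # stands in for lstm_config.eval_unroll_length

converter = tf.lite.TFLiteConverter.from_frozen_graph(
    frozen_graph_path,
    input_arrays=['input_video_tensor'],
    output_arrays=[
        'TFLite_Detection_PostProcess',
        'TFLite_Detection_PostProcess:1',
        'TFLite_Detection_PostProcess:2',
        'TFLite_Detection_PostProcess:3',
    ],
    input_shapes={'input_video_tensor': [unroll_length, 320, 320, 3]})
# The detection post-processing op is a TF Lite custom op, so it must be allowed.
converter.allow_custom_ops = True
with open(export_path, 'wb') as f:
  f.write(converter.convert())
```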
@@ -13,9 +13,6 @@
 # limitations under the License.
 # ==============================================================================
-"""Test a tflite model using random input data."""
-
-from __future__ import print_function

 from absl import flags
 import numpy as np
 import tensorflow as tf
@@ -26,28 +23,28 @@ FLAGS = flags.FLAGS

 def main(_):
   flags.mark_flag_as_required('model_path')

   # Load TFLite model and allocate tensors.
   interpreter = tf.lite.Interpreter(model_path=FLAGS.model_path)
   interpreter.allocate_tensors()

   # Get input and output tensors.
   input_details = interpreter.get_input_details()
-  print('input_details:', input_details)
+  print 'input_details:', input_details
   output_details = interpreter.get_output_details()
-  print('output_details:', output_details)
+  print 'output_details:', output_details

   # Test model on random input data.
   input_shape = input_details[0]['shape']
   # change the following line to feed into your own data.
   input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
   interpreter.set_tensor(input_details[0]['index'], input_data)

   interpreter.invoke()
   output_data = interpreter.get_tensor(output_details[0]['index'])
-  print(output_data)
+  print output_data


 if __name__ == '__main__':
   tf.app.run()
...
@@ -59,19 +59,12 @@ cc_library(
     name = "mobile_lstd_tflite_client",
     srcs = ["mobile_lstd_tflite_client.cc"],
     hdrs = ["mobile_lstd_tflite_client.h"],
-    defines = select({
-        "//conditions:default": [],
-        "enable_edgetpu": ["ENABLE_EDGETPU"],
-    }),
     deps = [
         ":mobile_ssd_client",
         ":mobile_ssd_tflite_client",
-        "@com_google_glog//:glog",
         "@com_google_absl//absl/base:core_headers",
+        "@com_google_glog//:glog",
         "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
-    ] + select({
-        "//conditions:default": [],
-        "enable_edgetpu": ["@libedgetpu//libedgetpu:header"],
-    }),
+    ],
     alwayslink = 1,
 )
...
@@ -90,6 +90,13 @@ http_archive(
     sha256 = "79d102c61e2a479a0b7e5fc167bcfaa4832a0c6aad4a75fa7da0480564931bcc",
 )

+#
+# http_archive(
+#     name = "com_google_protobuf",
+#     strip_prefix = "protobuf-master",
+#     urls = ["https://github.com/protocolbuffers/protobuf/archive/master.zip"],
+# )
+
 # Needed by TensorFlow
 http_archive(
     name = "io_bazel_rules_closure",
...
@@ -66,11 +66,6 @@ bool MobileLSTDTfLiteClient::InitializeInterpreter(
     interpreter_->UseNNAPI(false);
   }

-#ifdef ENABLE_EDGETPU
-  interpreter_->SetExternalContext(kTfLiteEdgeTpuContext,
-                                   edge_tpu_context_.get());
-#endif
-
   // Inputs are: normalized_input_image_tensor, raw_inputs/init_lstm_c,
   // raw_inputs/init_lstm_h
   if (interpreter_->inputs().size() != 3) {
...
@@ -76,10 +76,6 @@ class MobileSSDTfLiteClient : public MobileSSDClient {
   std::unique_ptr<::tflite::MutableOpResolver> resolver_;
   std::unique_ptr<::tflite::Interpreter> interpreter_;

-#ifdef ENABLE_EDGETPU
-  std::unique_ptr<edgetpu::EdgeTpuContext> edge_tpu_context_;
-#endif
-
 private:
   // MobileSSDTfLiteClient is neither copyable nor movable.
   MobileSSDTfLiteClient(const MobileSSDTfLiteClient&) = delete;
@@ -107,6 +103,10 @@ class MobileSSDTfLiteClient : public MobileSSDClient {
   bool FloatInference(const uint8_t* input_data);
   bool QuantizedInference(const uint8_t* input_data);
   void GetOutputBoxesAndScoreTensorsFromUInt8();
+
+#ifdef ENABLE_EDGETPU
+  std::unique_ptr<edgetpu::EdgeTpuContext> edge_tpu_context_;
+#endif
 };

 } // namespace tflite
...