Commit 66d00a87 authored by yongzhe2160, committed by Menglong Zhu

Merged commit includes the following changes: (#7220)

* Merged commit includes the following changes:
257930561  by yongzhe:

    Mobile LSTD TfLite Client.

--
257928126  by yongzhe:

    Mobile SSD Tflite client.

--
257921181  by menglong:

    Fix discrepancy between pre_bottleneck = {true, false}

--
257561213  by yongzhe:

    File utils.

--
257449226  by yongzhe:

    Mobile SSD Client.

--
257264654  by yongzhe:

    SSD utils.

--
257235648  by yongzhe:

    Proto bazel build rules.

--
256437262  by Menglong Zhu:

    Fix check for FusedBatchNorm op to only verify it as a prefix.

--
256283755  by yongzhe:

    Bazel build and copybara changes.

--
251947295  by yinxiao:

    Add missing interleaved option in checkpoint restore.

--
251513479  by yongzhe:

    Conversion utils.

--
248783193  by yongzhe:

    Branch protos needed for the lstd client.

--
248200507  by menglong:

    Fix proto namespace in example config

--

PiperOrigin-RevId: 257930561

* Delete BUILD
parent 395f6d2d
@@ -31,3 +31,4 @@ https://scholar.googleusercontent.com/scholar.bib?q=info:rLqvkztmWYgJ:scholar.go
* masonliuw@gmail.com
* yinxiao@google.com
* menglong@google.com
* yongzhe@google.com
@@ -15,7 +15,7 @@
# For training on Imagenet Video with LSTM Mobilenet V1
-[object_detection.protos.lstm_model] {
+[lstm_object_detection.protos.lstm_model] {
train_unroll_length: 4
eval_unroll_length: 4
}
@@ -439,7 +439,7 @@ class GroupedConvLSTMCell(tf.contrib.rnn.RNNCell):
bottleneck_concat = lstm_utils.quantizable_concat(
[inputs, h_list[k]],
axis=3,
-is_training=False,
+is_training=self._is_training,
is_quantized=self._is_quantized,
scope='bottleneck_%d/quantized_concat' % k)
@@ -238,11 +238,24 @@ class LSTMSSDMetaArch(ssd_meta_arch.SSDMetaArch):
`classification`/`detection`/`interleaved`/`lstm`.
"""
if fine_tune_checkpoint_type not in [
-'classification', 'detection', 'lstm'
+'classification', 'detection', 'interleaved', 'lstm',
+'interleaved_pretrain'
]:
raise ValueError('Not supported fine_tune_checkpoint_type: {}'.format(
fine_tune_checkpoint_type))
self._restored_networks += 1
base_network_scope = self.get_base_network_scope()
if base_network_scope:
scope_to_replace = '{0}_{1}'.format(base_network_scope,
self._restored_networks)
interleaved_model = False
for variable in tf.global_variables():
if scope_to_replace in variable.op.name:
interleaved_model = True
break
variables_to_restore = {}
for variable in tf.global_variables():
var_name = variable.op.name
@@ -250,7 +263,8 @@ class LSTMSSDMetaArch(ssd_meta_arch.SSDMetaArch):
continue
# Remove FeatureExtractor prefix for classification checkpoints.
if fine_tune_checkpoint_type == 'classification':
if (fine_tune_checkpoint_type == 'classification' or
fine_tune_checkpoint_type == 'interleaved_pretrain'):
var_name = (
re.split('^' + self._extract_features_scope + '/', var_name)[-1])
@@ -260,6 +274,25 @@ class LSTMSSDMetaArch(ssd_meta_arch.SSDMetaArch):
fine_tune_checkpoint_type == 'detection'):
var_name = var_name.replace('FeatureMaps',
self.get_base_network_scope())
# Load interleaved checkpoint specifically.
if interleaved_model: # Interleaved LSTD.
if 'interleaved' in fine_tune_checkpoint_type:
variables_to_restore[var_name] = variable
else:
# Restore non-base layers from the first checkpoint only.
if self._restored_networks == 1:
if base_network_scope + '_' not in var_name: # LSTM and FeatureMap
variables_to_restore[var_name] = variable
if scope_to_replace in var_name:
var_name = var_name.replace(scope_to_replace, base_network_scope)
variables_to_restore[var_name] = variable
else:
# Restore from the first model of interleaved checkpoints
if 'interleaved' in fine_tune_checkpoint_type:
var_name = var_name.replace(self.get_base_network_scope(),
self.get_base_network_scope() + '_1', 1)
variables_to_restore[var_name] = variable
return variables_to_restore
@@ -149,7 +149,7 @@ class LSTMSSDInterleavedMobilenetV2FeatureExtractorTest(
pad_to_multiple)
preprocessed_image = feature_extractor.preprocess(image_placeholder)
_ = feature_extractor.extract_features(preprocessed_image, unroll_length=1)
-self.assertTrue(any(op.type == 'FusedBatchNorm'
+self.assertTrue(any(op.type.startswith('FusedBatchNorm')
for op in tf.get_default_graph().get_operations()))
def test_variables_for_tflite(self):
......
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"])
cc_library(
name = "mobile_ssd_client",
srcs = ["mobile_ssd_client.cc"],
hdrs = ["mobile_ssd_client.h"],
deps = [
"//protos:box_encodings_cc_proto",
"//protos:detections_cc_proto",
"//protos:labelmap_cc_proto",
"//protos:mobile_ssd_client_options_cc_proto",
"//utils:conversion_utils",
"//utils:ssd_utils",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:span",
"@com_google_glog//:glog",
"@gemmlowp",
],
)
cc_library(
name = "mobile_ssd_tflite_client",
srcs = ["mobile_ssd_tflite_client.cc"],
hdrs = ["mobile_ssd_tflite_client.h"],
deps = [
":mobile_ssd_client",
"//protos:anchor_generation_options_cc_proto",
"//utils:file_utils",
"//utils:ssd_utils",
"@com_google_absl//absl/memory",
"@com_google_glog//:glog",
"@org_tensorflow//tensorflow/lite:arena_planner",
"@org_tensorflow//tensorflow/lite:framework",
"@org_tensorflow//tensorflow/lite/delegates/nnapi:nnapi_delegate",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
],
alwayslink = 1,
)
cc_library(
name = "mobile_lstd_tflite_client",
srcs = ["mobile_lstd_tflite_client.cc"],
hdrs = ["mobile_lstd_tflite_client.h"],
deps = [
":mobile_ssd_client",
":mobile_ssd_tflite_client",
"@com_google_absl//absl/base:core_headers",
"@com_google_glog//:glog",
],
alwayslink = 1,
)
workspace(name = "lstm_object_detection")
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
http_archive(
name = "bazel_skylib",
sha256 = "bbccf674aa441c266df9894182d80de104cabd19be98be002f6d478aaa31574d",
strip_prefix = "bazel-skylib-2169ae1c374aab4a09aa90e65efe1a3aad4e279b",
urls = ["https://github.com/bazelbuild/bazel-skylib/archive/2169ae1c374aab4a09aa90e65efe1a3aad4e279b.tar.gz"],
)
load("@bazel_skylib//lib:versions.bzl", "versions")
versions.check(minimum_bazel_version = "0.23.0")
# ABSL cpp library.
http_archive(
name = "com_google_absl",
urls = [
"https://github.com/abseil/abseil-cpp/archive/a02f62f456f2c4a7ecf2be3104fe0c6e16fbad9a.tar.gz",
],
sha256 = "d437920d1434c766d22e85773b899c77c672b8b4865d5dc2cd61a29fdff3cf03",
strip_prefix = "abseil-cpp-a02f62f456f2c4a7ecf2be3104fe0c6e16fbad9a",
)
# GoogleTest/GoogleMock framework. Used by most unit-tests.
http_archive(
name = "com_google_googletest",
urls = ["https://github.com/google/googletest/archive/master.zip"],
strip_prefix = "googletest-master",
)
# gflags needed by glog
http_archive(
name = "com_github_gflags_gflags",
sha256 = "6e16c8bc91b1310a44f3965e616383dbda48f83e8c1eaa2370a215057b00cabe",
strip_prefix = "gflags-77592648e3f3be87d6c7123eb81cbad75f9aef5a",
urls = [
"https://mirror.bazel.build/github.com/gflags/gflags/archive/77592648e3f3be87d6c7123eb81cbad75f9aef5a.tar.gz",
"https://github.com/gflags/gflags/archive/77592648e3f3be87d6c7123eb81cbad75f9aef5a.tar.gz",
],
)
# glog
http_archive(
name = "com_google_glog",
sha256 = "f28359aeba12f30d73d9e4711ef356dc842886968112162bc73002645139c39c",
strip_prefix = "glog-0.4.0",
urls = ["https://github.com/google/glog/archive/v0.4.0.tar.gz"],
)
http_archive(
name = "zlib",
build_file = "@com_google_protobuf//:third_party/zlib.BUILD",
sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1",
strip_prefix = "zlib-1.2.11",
urls = ["https://zlib.net/zlib-1.2.11.tar.gz"],
)
http_archive(
name = "gemmlowp",
sha256 = "6678b484d929f2d0d3229d8ac4e3b815a950c86bb9f17851471d143f6d4f7834",
strip_prefix = "gemmlowp-12fed0cd7cfcd9e169bf1925bc3a7a58725fdcc3",
urls = [
"http://mirror.tensorflow.org/github.com/google/gemmlowp/archive/12fed0cd7cfcd9e169bf1925bc3a7a58725fdcc3.zip",
"https://github.com/google/gemmlowp/archive/12fed0cd7cfcd9e169bf1925bc3a7a58725fdcc3.zip",
],
)
#-----------------------------------------------------------------------------
# proto
#-----------------------------------------------------------------------------
# proto_library, cc_proto_library and java_proto_library rules implicitly depend
# on @com_google_protobuf//:proto, @com_google_protobuf//:cc_toolchain and
# @com_google_protobuf//:java_toolchain, respectively.
# This statement defines the @com_google_protobuf repo.
http_archive(
name = "com_google_protobuf",
strip_prefix = "protobuf-3.8.0",
urls = ["https://github.com/google/protobuf/archive/v3.8.0.zip"],
sha256 = "1e622ce4b84b88b6d2cdf1db38d1a634fe2392d74f0b7b74ff98f3a51838ee53",
)
# java_lite_proto_library rules implicitly depend on
# @com_google_protobuf_javalite//:javalite_toolchain, which is the JavaLite proto
# runtime (base classes and common utilities).
http_archive(
name = "com_google_protobuf_javalite",
strip_prefix = "protobuf-384989534b2246d413dbcd750744faab2607b516",
urls = ["https://github.com/google/protobuf/archive/384989534b2246d413dbcd750744faab2607b516.zip"],
sha256 = "79d102c61e2a479a0b7e5fc167bcfaa4832a0c6aad4a75fa7da0480564931bcc",
)
#
# http_archive(
# name = "com_google_protobuf",
# strip_prefix = "protobuf-master",
# urls = ["https://github.com/protocolbuffers/protobuf/archive/master.zip"],
# )
# Needed by TensorFlow
http_archive(
name = "io_bazel_rules_closure",
sha256 = "e0a111000aeed2051f29fcc7a3f83be3ad8c6c93c186e64beb1ad313f0c7f9f9",
strip_prefix = "rules_closure-cf1e44edb908e9616030cc83d085989b8e6cd6df",
urls = [
"http://mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/cf1e44edb908e9616030cc83d085989b8e6cd6df.tar.gz",
"https://github.com/bazelbuild/rules_closure/archive/cf1e44edb908e9616030cc83d085989b8e6cd6df.tar.gz", # 2019-04-04
],
)
# TensorFlow r1.14-rc0
http_archive(
name = "org_tensorflow",
strip_prefix = "tensorflow-1.14.0-rc0",
sha256 = "76404a6157a45e8d7a07e4f5690275256260130145924c2a7c73f6eda2a3de10",
urls = ["https://github.com/tensorflow/tensorflow/archive/v1.14.0-rc0.zip"],
)
load("@org_tensorflow//tensorflow:workspace.bzl", "tf_workspace")
tf_workspace(tf_repo_name = "org_tensorflow")
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mobile_lstd_tflite_client.h"
#include <glog/logging.h>
namespace lstm_object_detection {
namespace tflite {
std::unique_ptr<MobileLSTDTfLiteClient> MobileLSTDTfLiteClient::Create() {
auto client = absl::make_unique<MobileLSTDTfLiteClient>();
if (!client->InitializeClient(CreateDefaultOptions())) {
LOG(ERROR) << "Failed to initialize client";
return nullptr;
}
return client;
}
protos::ClientOptions MobileLSTDTfLiteClient::CreateDefaultOptions() {
const int kMaxDetections = 100;
const int kClassesPerDetection = 1;
const double kScoreThreshold = -2.0;
const double kIouThreshold = 0.5;
protos::ClientOptions options;
options.set_max_detections(kMaxDetections);
options.set_max_categories(kClassesPerDetection);
options.set_score_threshold(kScoreThreshold);
options.set_iou_threshold(kIouThreshold);
options.set_agnostic_mode(false);
options.set_quantize(false);
options.set_num_keypoints(0);
return options;
}
std::unique_ptr<MobileLSTDTfLiteClient> MobileLSTDTfLiteClient::Create(
const protos::ClientOptions& options) {
auto client = absl::make_unique<MobileLSTDTfLiteClient>();
if (!client->InitializeClient(options)) {
LOG(ERROR) << "Failed to initialize client";
return nullptr;
}
return client;
}
bool MobileLSTDTfLiteClient::InitializeInterpreter(
const protos::ClientOptions& options) {
if (options.prefer_nnapi_delegate()) {
LOG(ERROR) << "NNAPI not supported.";
return false;
} else {
interpreter_->UseNNAPI(false);
}
// Inputs are: normalized_input_image_tensor, raw_inputs/init_lstm_c,
// raw_inputs/init_lstm_h
if (interpreter_->inputs().size() != 3) {
LOG(ERROR) << "Invalid number of interpreter inputs: " <<
interpreter_->inputs().size();
return false;
}
const std::vector<int> input_tensor_indices = interpreter_->inputs();
const TfLiteTensor& input_lstm_c =
*interpreter_->tensor(input_tensor_indices[1]);
if (input_lstm_c.dims->size != 4) {
LOG(ERROR) << "Invalid input lstm_c dimensions: " <<
input_lstm_c.dims->size;
return false;
}
if (input_lstm_c.dims->data[0] != 1) {
LOG(ERROR) << "Invalid input lstm_c batch size: " <<
input_lstm_c.dims->data[0];
return false;
}
lstm_state_width_ = input_lstm_c.dims->data[1];
lstm_state_height_ = input_lstm_c.dims->data[2];
lstm_state_depth_ = input_lstm_c.dims->data[3];
lstm_state_size_ = lstm_state_width_ * lstm_state_height_ * lstm_state_depth_;
const TfLiteTensor& input_lstm_h =
*interpreter_->tensor(input_tensor_indices[2]);
if (!ValidateStateTensor(input_lstm_h, "input lstm_h")) {
return false;
}
// Outputs are: raw_outputs/box_encodings, raw_outputs/class_predictions,
// raw_outputs/lstm_c, raw_outputs/lstm_h
if (interpreter_->outputs().size() != 4) {
LOG(ERROR) << "Invalid number of interpreter outputs: " <<
interpreter_->outputs().size();
return false;
}
const std::vector<int> output_tensor_indices = interpreter_->outputs();
const TfLiteTensor& output_lstm_c =
*interpreter_->tensor(output_tensor_indices[2]);
if (!ValidateStateTensor(output_lstm_c, "output lstm_c")) {
return false;
}
const TfLiteTensor& output_lstm_h =
*interpreter_->tensor(output_tensor_indices[3]);
if (!ValidateStateTensor(output_lstm_h, "output lstm_h")) {
return false;
}
// Initialize state with all zeroes.
lstm_c_data_.resize(lstm_state_size_);
lstm_h_data_.resize(lstm_state_size_);
if (interpreter_->AllocateTensors() != kTfLiteOk) {
LOG(ERROR) << "Failed to allocate tensors";
return false;
}
return true;
}
bool MobileLSTDTfLiteClient::ValidateStateTensor(const TfLiteTensor& tensor,
const std::string& name) {
if (tensor.dims->size != 4) {
LOG(ERROR) << "Invalid " << name << " dimensions: " << tensor.dims->size;
return false;
}
if (tensor.dims->data[0] != 1) {
LOG(ERROR) << "Invalid " << name << " batch size: " << tensor.dims->data[0];
return false;
}
if (tensor.dims->data[1] != lstm_state_width_ ||
tensor.dims->data[2] != lstm_state_height_ ||
tensor.dims->data[3] != lstm_state_depth_) {
LOG(ERROR) << "Invalid " << name << " dimensions: [" <<
tensor.dims->data[0] << ", " << tensor.dims->data[1] << ", " <<
tensor.dims->data[2] << ", " << tensor.dims->data[3] << "]";
return false;
}
return true;
}
bool MobileLSTDTfLiteClient::ComputeOutputLayerCount() {
// Outputs are: raw_outputs/box_encodings, raw_outputs/class_predictions,
// raw_outputs/lstm_c, raw_outputs/lstm_h
CHECK_EQ(interpreter_->outputs().size(), 4);
num_output_layers_ = 1;
return true;
}
bool MobileLSTDTfLiteClient::FloatInference(const uint8_t* input_data) {
// Inputs are: normalized_input_image_tensor, raw_inputs/init_lstm_c,
// raw_inputs/init_lstm_h
CHECK(input_data) << "Input data cannot be null.";
float* input = interpreter_->typed_input_tensor<float>(0);
CHECK(input) << "Input tensor cannot be null.";
// Normalize the uint8 input image with mean_value_, std_value_.
NormalizeInputImage(input_data, input);
// Copy input LSTM state into TFLite's input tensors.
float* lstm_c_input = interpreter_->typed_input_tensor<float>(1);
CHECK(lstm_c_input) << "Input lstm_c tensor cannot be null.";
std::copy(lstm_c_data_.begin(), lstm_c_data_.end(), lstm_c_input);
float* lstm_h_input = interpreter_->typed_input_tensor<float>(2);
CHECK(lstm_h_input) << "Input lstm_h tensor cannot be null.";
std::copy(lstm_h_data_.begin(), lstm_h_data_.end(), lstm_h_input);
// Run inference on inputs.
CHECK_EQ(interpreter_->Invoke(), kTfLiteOk) << "Invoking interpreter failed.";
// Copy LSTM state out of TFLite's output tensors.
// Outputs are: raw_outputs/box_encodings, raw_outputs/class_predictions,
// raw_outputs/lstm_c, raw_outputs/lstm_h
float* lstm_c_output = interpreter_->typed_output_tensor<float>(2);
CHECK(lstm_c_output) << "Output lstm_c tensor cannot be null.";
std::copy(lstm_c_output, lstm_c_output + lstm_state_size_,
lstm_c_data_.begin());
float* lstm_h_output = interpreter_->typed_output_tensor<float>(3);
CHECK(lstm_h_output) << "Output lstm_h tensor cannot be null.";
std::copy(lstm_h_output, lstm_h_output + lstm_state_size_,
lstm_h_data_.begin());
return true;
}
bool MobileLSTDTfLiteClient::QuantizedInference(const uint8_t* input_data) {
// Inputs are: normalized_input_image_tensor, raw_inputs/init_lstm_c,
// raw_inputs/init_lstm_h
CHECK(input_data) << "Input data cannot be null.";
uint8_t* input = interpreter_->typed_input_tensor<uint8_t>(0);
CHECK(input) << "Input tensor cannot be null.";
// Copy input LSTM state into TFLite's input tensors.
uint8_t* lstm_c_input = interpreter_->typed_input_tensor<uint8_t>(1);
CHECK(lstm_c_input) << "Input lstm_c tensor cannot be null.";
std::copy(lstm_c_data_uint8_.begin(), lstm_c_data_uint8_.end(), lstm_c_input);
uint8_t* lstm_h_input = interpreter_->typed_input_tensor<uint8_t>(2);
CHECK(lstm_h_input) << "Input lstm_h tensor cannot be null.";
std::copy(lstm_h_data_uint8_.begin(), lstm_h_data_uint8_.end(), lstm_h_input);
// Run inference on inputs.
CHECK_EQ(interpreter_->Invoke(), kTfLiteOk) << "Invoking interpreter failed.";
// Copy LSTM state out of TFLite's output tensors.
// Outputs are: raw_outputs/box_encodings, raw_outputs/class_predictions,
// raw_outputs/lstm_c, raw_outputs/lstm_h
uint8_t* lstm_c_output = interpreter_->typed_output_tensor<uint8_t>(2);
CHECK(lstm_c_output) << "Output lstm_c tensor cannot be null.";
std::copy(lstm_c_output, lstm_c_output + lstm_state_size_,
lstm_c_data_uint8_.begin());
uint8_t* lstm_h_output = interpreter_->typed_output_tensor<uint8_t>(3);
CHECK(lstm_h_output) << "Output lstm_h tensor cannot be null.";
std::copy(lstm_h_output, lstm_h_output + lstm_state_size_,
lstm_h_data_uint8_.begin());
return true;
}
bool MobileLSTDTfLiteClient::Inference(const uint8_t* input_data) {
if (input_data == nullptr) {
LOG(ERROR) << "input_data cannot be null for inference.";
return false;
}
if (IsQuantizedModel())
return QuantizedInference(input_data);
else
return FloatInference(input_data);
return true;
}
} // namespace tflite
} // namespace lstm_object_detection
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_LSTM_OBJECT_DETECTION_TFLITE_MOBILE_LSTD_TFLITE_CLIENT_H_
#define TENSORFLOW_MODELS_LSTM_OBJECT_DETECTION_TFLITE_MOBILE_LSTD_TFLITE_CLIENT_H_
#include <memory>
#include <vector>
#include <cstdint>
#include "mobile_ssd_client.h"
#include "mobile_ssd_tflite_client.h"
namespace lstm_object_detection {
namespace tflite {
// Client for LSTD MobileNet TfLite model.
class MobileLSTDTfLiteClient : public MobileSSDTfLiteClient {
public:
MobileLSTDTfLiteClient() = default;
// Create with default options.
static std::unique_ptr<MobileLSTDTfLiteClient> Create();
static std::unique_ptr<MobileLSTDTfLiteClient> Create(
const protos::ClientOptions& options);
~MobileLSTDTfLiteClient() override = default;
static protos::ClientOptions CreateDefaultOptions();
protected:
bool InitializeInterpreter(const protos::ClientOptions& options) override;
bool ComputeOutputLayerCount() override;
bool Inference(const uint8_t* input_data) override;
private:
// MobileLSTDTfLiteClient is neither copyable nor movable.
MobileLSTDTfLiteClient(const MobileLSTDTfLiteClient&) = delete;
MobileLSTDTfLiteClient& operator=(const MobileLSTDTfLiteClient&) = delete;
bool ValidateStateTensor(const TfLiteTensor& tensor, const std::string& name);
// Helper functions used by Inference functions.
bool FloatInference(const uint8_t* input_data);
bool QuantizedInference(const uint8_t* input_data);
// LSTM model parameters.
int lstm_state_width_ = 0;
int lstm_state_height_ = 0;
int lstm_state_depth_ = 0;
int lstm_state_size_ = 0;
// LSTM state stored between float inference runs.
std::vector<float> lstm_c_data_;
std::vector<float> lstm_h_data_;
// LSTM state stored between uint8 inference runs.
std::vector<uint8_t> lstm_c_data_uint8_;
std::vector<uint8_t> lstm_h_data_uint8_;
};
} // namespace tflite
} // namespace lstm_object_detection
#endif // TENSORFLOW_MODELS_LSTM_OBJECT_DETECTION_TFLITE_MOBILE_LSTD_TFLITE_CLIENT_H_
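A minimal usage sketch for the client above (illustrative only, not part of the commit; the file paths, frame buffers, and helper name are hypothetical, and the ExternalFiles setters mirror the accessors used in mobile_ssd_tflite_client.cc):
// Illustrative usage sketch (not part of the commit). Paths and buffers are
// hypothetical placeholders.
#include <cstdint>
#include <memory>
#include <vector>
#include "mobile_lstd_tflite_client.h"
namespace lstm_object_detection {
namespace tflite {
bool RunLstdOnFrames(const std::vector<std::vector<uint8_t>>& rgb_frames,
                     int width, int height) {
  protos::ClientOptions options = MobileLSTDTfLiteClient::CreateDefaultOptions();
  // Assumed proto setters, mirroring the external_files accessors used in
  // MobileSSDTfLiteClient::InitializeClient().
  options.mutable_external_files()->set_model_file_name("/path/to/lstd.tflite");
  options.mutable_external_files()->set_label_map_file_name("/path/to/labelmap.pb");
  std::unique_ptr<MobileLSTDTfLiteClient> client =
      MobileLSTDTfLiteClient::Create(options);
  if (client == nullptr) return false;
  // The LSTM state lives inside the client, so frames must be fed in temporal
  // order; each call updates the stored lstm_c/lstm_h state for the next frame.
  for (const auto& frame : rgb_frames) {
    protos::DetectionResults detections;
    // width and height must match the model's input dimensions
    // (see MobileSSDClient::SetInputDims).
    if (!client->Detect(frame.data(), width, height,
                        /*bytes_per_pixel=*/3,
                        /*bytes_per_row=*/width * 3, &detections)) {
      return false;
    }
    // ... consume detections ...
  }
  return true;
}
}  // namespace tflite
}  // namespace lstm_object_detection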
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mobile_ssd_client.h"
#include <stdlib.h>
#include <map>
#include <glog/logging.h>
#include "absl/memory/memory.h"
#include "utils/conversion_utils.h"
#include "utils/ssd_utils.h"
namespace lstm_object_detection {
namespace tflite {
bool MobileSSDClient::InitializeClient(const protos::ClientOptions& options) {
options_ = options;
return true;
}
bool MobileSSDClient::Detect(const uint8_t* pixels, int width, int height,
int bytes_per_pixel, int bytes_per_row,
protos::DetectionResults* detections) {
SetInputDims(width, height);
// Grayscale input images are only compatible with grayscale models, and
// color input images are only compatible with color models.
CHECK((bytes_per_pixel == 1 && input_depth_ == 1) ||
(bytes_per_pixel >= 3 && input_depth_ >= 3));
if (HasPadding(width, height, bytes_per_pixel, bytes_per_row)) {
std::vector<uint8_t> unpadded_pixels =
RemovePadding(pixels, width, height, bytes_per_pixel, bytes_per_row);
return Detect(&unpadded_pixels[0], detections);
} else {
return Detect(pixels, detections);
}
}
bool MobileSSDClient::Detect(const uint8_t* pixels,
protos::DetectionResults* detections) {
return BatchDetect(pixels, 1, absl::MakeSpan(&detections, 1));
}
bool MobileSSDClient::BatchDetect(
const uint8_t* pixels, int batch_size,
absl::Span<protos::DetectionResults*> detections) {
if (detections.size() != batch_size) {
LOG(ERROR) << "Batch size does not match output cardinality.";
return false;
}
if (batch_size != batch_size_) {
if (!SetBatchSize(batch_size)) {
LOG(ERROR) << "Couldn't set batch size.";
return false;
}
}
if (!Inference(pixels)) {
LOG(ERROR) << "Couldn't inference.";
return false;
}
for (int batch = 0; batch < batch_size; ++batch) {
if (RequiresPostProcessing()) {
LOG(ERROR) << "Post Processing not supported.";
return false;
} else {
if (NoPostProcessNoAnchors(detections[batch])) {
LOG(ERROR) << "NoPostProcessNoAnchors failed.";
return false;
}
}
}
return true;
}
bool MobileSSDClient::SetBatchSize(int batch_size) {
batch_size_ = batch_size;
AllocateBuffers();
if (batch_size != 1) {
LOG(ERROR)
<< "Only single batch inference supported by default. All child "
"classes that support batched inference should override this method "
"and not return an error if the batch size is supported. (E.g. "
"MobileSSDTfLiteClient).";
return false;
}
return true;
}
bool MobileSSDClient::NoPostProcessNoAnchors(
protos::DetectionResults* detections) {
LOG(ERROR) << "not yet implemented";
return false;
}
bool MobileSSDClient::RequiresPostProcessing() const {
return anchors_.y_size() > 0;
}
void MobileSSDClient::SetInputDims(int width, int height) {
CHECK_EQ(width, input_width_);
CHECK_EQ(height, input_height_);
}
int MobileSSDClient::GetNumberOfLabels() const { return labelmap_.item_size(); }
std::string MobileSSDClient::GetLabelDisplayName(const int class_index) const {
if (class_index < 0 || class_index >= GetNumberOfLabels()) {
return "";
}
return labelmap_.item(class_index).display_name();
}
std::string MobileSSDClient::GetLabelName(const int class_index) const {
if (class_index < 0 || class_index >= GetNumberOfLabels()) {
return "";
}
return labelmap_.item(class_index).name();
}
int MobileSSDClient::GetLabelId(const int class_index) const {
if (class_index < 0 || class_index >= GetNumberOfLabels() ||
!labelmap_.item(class_index).has_id()) {
return -1;
}
return labelmap_.item(class_index).id();
}
void MobileSSDClient::SetLabelDisplayNameInResults(
protos::DetectionResults* detections) {
for (auto& det : *detections->mutable_detection()) {
for (const auto& class_index : det.class_index()) {
det.add_display_name(GetLabelDisplayName(class_index));
}
}
}
void MobileSSDClient::SetLabelNameInResults(
protos::DetectionResults* detections) {
for (auto& det : *detections->mutable_detection()) {
for (const auto& class_index : det.class_index()) {
det.add_class_name(GetLabelName(class_index));
}
}
}
void MobileSSDClient::InitParams(const bool agnostic_mode,
const bool quantize,
const int num_keypoints) {
num_keypoints_ = num_keypoints;
code_size_ = 4 + 2 * num_keypoints;
num_boxes_ = output_locations_size_ / code_size_;
if (agnostic_mode) {
num_classes_ = output_scores_size_ / num_boxes_;
} else {
num_classes_ = (output_scores_size_ / num_boxes_) - 1;
}
quantize_ = quantize;
AllocateBuffers();
}
void MobileSSDClient::AllocateBuffers() {
// Allocate the output vectors
output_locations_.resize(output_locations_size_ * batch_size_);
output_scores_.resize(output_scores_size_ * batch_size_);
if (quantize_) {
quantized_output_pointers_ =
absl::make_unique<std::vector<std::unique_ptr<std::vector<uint8_t>>>>(
batch_size_ * num_output_layers_ * 2);
for (int batch = 0; batch < batch_size_; ++batch) {
for (int i = 0; i < num_output_layers_; ++i) {
quantized_output_pointers_->at(2 * (i + batch * num_output_layers_)) =
absl::make_unique<std::vector<uint8_t>>(output_locations_sizes_[i]);
quantized_output_pointers_->at(2 * (i + batch * num_output_layers_) +
1) =
absl::make_unique<std::vector<uint8_t>>(output_scores_sizes_[i]);
}
}
quantized_output_pointers_array_.reset(
new uint8_t*[batch_size_ * num_output_layers_ * 2]);
for (int i = 0; i < batch_size_ * num_output_layers_ * 2; ++i) {
quantized_output_pointers_array_[i] =
quantized_output_pointers_->at(i)->data();
}
gemm_context_.set_max_num_threads(1);
} else {
output_pointers_[0] = output_locations_.data();
output_pointers_[1] = output_scores_.data();
}
}
} // namespace tflite
} // namespace lstm_object_detection
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_LSTM_OBJECT_DETECTION_TFLITE_MOBILE_SSD_CLIENT_H_
#define TENSORFLOW_MODELS_LSTM_OBJECT_DETECTION_TFLITE_MOBILE_SSD_CLIENT_H_
#include <memory>
#include <vector>
#include <cstdint>
#include "absl/types/span.h"
#include "public/gemmlowp.h"
#include "protos/box_encodings.pb.h"
#include "protos/detections.pb.h"
#include "protos/labelmap.pb.h"
#include "protos/mobile_ssd_client_options.pb.h"
namespace lstm_object_detection {
namespace tflite {
// MobileSSDClient base class. Not thread-safe.
class MobileSSDClient {
public:
MobileSSDClient() = default;
virtual ~MobileSSDClient() = default;
// Runs detection on the image represented by 'pixels', described by the
// associated 'width', 'height', 'bytes_per_pixel' and 'bytes_per_row'. All
// these integers must be positive, 'bytes_per_row' must be sufficiently
// large, and for 'bytes_per_pixel' only the values 1, 3, and 4 may be passed.
// Depending on the implementation, many combinations may not be allowed.
bool Detect(const uint8_t* pixels, int width, int height, int bytes_per_pixel,
int bytes_per_row, protos::DetectionResults* detections);
// Same as before, but a contiguous bytewise encoding of 'pixels' is assumed.
// That encoding can be assigned directly to the input layer of the neural
// network.
bool Detect(const uint8_t* pixels, protos::DetectionResults* detections);
// Runs batched inference on the provided buffer. 'pixels' is assumed to be a
// contiguous buffer of width * height * depth * batch_size pixels. This
// populates the detections result with batch_size DetectionResults, where the
// first result corresponds to the first image contained within the pixels
// block. Note that not all models generalize correctly to multi-batch
// inference, and in some cases the addition of extra batches may corrupt the
// output of the model. For example, if a network performs operations across
// batches, BatchDetect([A, B]) may not equal [Detect(A), Detect(B)].
bool BatchDetect(const uint8_t* pixels, int batch_size,
absl::Span<protos::DetectionResults*> detections);
// Sets the dimensions of the input image on the fly, to be effective for the
// next Detect() call.
void SetInputDims(int width, int height);
// Returns the width of the input image which is always positive. Usually a
// constant or the width last set via 'SetInputDims()'.
int GetInputWidth() const { return input_width_; }
// Returns the height of the input image which is always positive. Usually a
// constant or the height last set via 'SetInputDims()'.
int GetInputHeight() const { return input_height_; }
// Returns the depth of the input image, which is the same as bytes per pixel.
// This will be 3 (for RGB images), 4 (for RGBA images), or 1 (for grayscale
// images).
int GetInputDepth() const { return input_depth_; }
// Returns the number of possible detection labels or classes. If
// agnostic_mode is on, then this method must return 1.
int GetNumberOfLabels() const;
// Returns the human readable class label given a predicted class index. The
// range of 'class_index' is determined by 'GetNumberOfLabels()'. Returns an
// empty string if the label display name is undefined or 'class_index' is out
// of range.
std::string GetLabelDisplayName(const int class_index) const;
// Returns the Knowledge Graph MID class label given a predicted class index.
// The range of 'class_index' is determined by 'GetNumberOfLabels()'. Returns
// an empty string if the label name is undefined or 'class_index' is out of
// range.
std::string GetLabelName(const int class_index) const;
// Returns the class/label ID for a given predicted class index. The range of
// 'class_index' is determined by 'GetNumberOfLabels()'. Returns -1 in case
// 'class_index' is out of range.
int GetLabelId(const int class_index) const;
// Explicitly sets human readable string class name to each detection using
// the `display_name` field.
void SetLabelDisplayNameInResults(protos::DetectionResults* detections);
// Explicitly sets string class name to each detection using the `class_name`
// fields.
void SetLabelNameInResults(protos::DetectionResults* detections);
protected:
// Initializes the client from options.
virtual bool InitializeClient(const protos::ClientOptions& options);
// Initializes various model specific parameters.
virtual void InitParams() {
InitParams(false, false, 0);
}
virtual void InitParams(const bool agnostic_mode,
const bool quantize,
const int num_keypoints);
virtual void InitParams(const bool agnostic_mode, const bool quantize,
const int num_keypoints,
const protos::BoxCoder& coder) {
InitParams(agnostic_mode, quantize, num_keypoints);
*options_.mutable_box_coder() = coder;
}
virtual void AllocateBuffers();
// Sets the batch size for inference. If reimplemented, the override is
// responsible for calling the parent (the returned status code may be ignored).
virtual bool SetBatchSize(int batch_size);
// Perform client specific inference on input_data.
virtual bool Inference(const uint8_t* input_data) = 0;
// Directly populates the results when no post-processing should take place
// and no anchors are present. This is only possible when the TensorFlow
// graph contains the customized post-processing ops.
virtual bool NoPostProcessNoAnchors(protos::DetectionResults* detections);
// Returns true iff the model returns raw output and needs its results
// post-processed (including non-maximum suppression). If false, then anchors
// do not need to be present and LoadAnchors() can be implemented empty. Note
// that almost all models require post-processing.
bool RequiresPostProcessing() const;
// Load client specific labelmap proto file.
virtual void LoadLabelMap() = 0;
// Anchors for the model.
protos::CenterSizeEncoding anchors_;
// Labelmap for the model.
protos::StringIntLabelMapProto labelmap_;
// Options for the model.
protos::ClientOptions options_;
// Buffers for storing the model predictions
float* output_pointers_[2];
// The dimension of output_locations is [batch_size x num_anchors x 4]
std::vector<float> output_locations_;
// The dimension of output_scores is:
// If background class is included:
// [batch_size x num_anchors x (num_classes + 1)]
// If background class is NOT included:
// [batch_size x num_anchors x num_classes]
std::vector<float> output_scores_;
void* transient_data_;
// Total location and score sizes.
int output_locations_size_;
int output_scores_size_;
// Output location and score sizes for each output layer.
std::vector<int> output_locations_sizes_;
std::vector<int> output_scores_sizes_;
// Preprocessing related parameters
float mean_value_;
float std_value_;
std::vector<int> location_zero_points_;
std::vector<float> location_scales_;
std::vector<int> score_zero_points_;
std::vector<float> score_scales_;
int num_output_layers_ = 1;
// Model related parameters
int input_size_;
int num_classes_;
int num_boxes_;
int input_width_;
int input_height_;
int input_depth_ = 3; // Default value is set for backward compatibility.
int code_size_;
int batch_size_ = 1; // Default value is set for backwards compatibility.
// The number of keypoints per detection. Specific to faces for now.
int num_keypoints_;
// Whether to use the quantized model.
bool quantize_;
// The indices of restricted classes (empty if none was passed in the config).
std::vector<int> restricted_class_indices_;
// Buffers for storing quantized model predictions
std::unique_ptr<std::vector<std::unique_ptr<std::vector<uint8_t>>>>
quantized_output_pointers_;
std::unique_ptr<uint8_t*[]> quantized_output_pointers_array_;
gemmlowp::GemmContext gemm_context_;
};
} // namespace tflite
} // namespace lstm_object_detection
#endif // TENSORFLOW_MODELS_LSTM_OBJECT_DETECTION_TFLITE_MOBILE_SSD_CLIENT_H_
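A short sketch of driving the batched API declared above (illustrative only; the helper name is hypothetical and the concrete client is assumed to support batch sizes greater than one):
// Illustrative sketch of packing the result span for BatchDetect().
#include <cstdint>
#include <vector>
#include "absl/types/span.h"
#include "mobile_ssd_client.h"
namespace lstm_object_detection {
namespace tflite {
bool DetectBatch(MobileSSDClient* client, const uint8_t* pixels, int batch_size,
                 std::vector<protos::DetectionResults>* results) {
  // One DetectionResults per image; BatchDetect() requires the span size to
  // match batch_size.
  results->resize(batch_size);
  std::vector<protos::DetectionResults*> result_ptrs;
  result_ptrs.reserve(batch_size);
  for (auto& r : *results) result_ptrs.push_back(&r);
  // 'pixels' must hold batch_size images laid out contiguously, each
  // width * height * depth bytes.
  return client->BatchDetect(pixels, batch_size, absl::MakeSpan(result_ptrs));
}
}  // namespace tflite
}  // namespace lstm_object_detection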
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mobile_ssd_tflite_client.h"
#include <glog/logging.h>
#include "tensorflow/lite/arena_planner.h"
#include "tensorflow/lite/context.h"
#include "tensorflow/lite/kernels/register.h"
#include "utils/file_utils.h"
#include "utils/ssd_utils.h"
namespace lstm_object_detection {
namespace tflite {
namespace {
constexpr int kInputBatch = 1;
constexpr int kInputDepth = 1;
constexpr int kNumBoundingBoxCoordinates = 4; // xmin, ymin, width, height
constexpr int GetBoxIndex(const int layer) { return (2 * layer); }
constexpr int GetScoreIndex(const int layer) { return (2 * layer + 1); }
} // namespace
MobileSSDTfLiteClient::MobileSSDTfLiteClient() {}
std::unique_ptr<::tflite::OpResolver>
MobileSSDTfLiteClient::CreateOpResolver() {
return absl::make_unique<::tflite::ops::builtin::BuiltinOpResolver>();
}
bool MobileSSDTfLiteClient::InitializeClient(
const protos::ClientOptions& options) {
if (!MobileSSDClient::InitializeClient(options)) {
return false;
}
if (options.has_external_files()) {
if (options.external_files().model_file_name().empty() &&
options.external_files().model_file_content().empty()) {
LOG(ERROR)
<< "MobileSSDClient: both `external_files.model_file_name` and "
"`external_files.model_file_content` are empty which is invalid.";
}
if (!options_.external_files().model_file_content().empty()) {
model_ = ::tflite::FlatBufferModel::BuildFromBuffer(
options_.external_files().model_file_content().data(),
options_.external_files().model_file_content().size());
} else {
const char* tflite_model_filename = reinterpret_cast<const char*>(
options_.external_files().model_file_name().c_str());
model_ = ::tflite::FlatBufferModel::BuildFromFile(tflite_model_filename);
}
} else {
LOG(ERROR) << "Embedded model is not supported.";
return false;
}
if (!model_) {
LOG(ERROR) << "Failed to load model";
return false;
}
LoadLabelMap();
resolver_ = CreateOpResolver();
::tflite::InterpreterBuilder(*model_, *resolver_)(&interpreter_);
if (!interpreter_) {
LOG(ERROR) << "Failed to build interpreter";
return false;
}
if (!InitializeInterpreter(options)) {
LOG(ERROR) << "Failed to initialize interpreter";
return false;
}
if (RequiresPostProcessing() && !ComputeOutputSize()) {
LOG(ERROR) << "Failed to compute output size";
return false;
}
// Initializes number of boxes, number of keypoints, quantized model flag and
// allocates output arrays based on output size computed by
// ComputeOutputSize()
agnostic_mode_ = options.agnostic_mode();
if (!restricted_class_indices_.empty()) {
LOG(ERROR) << "Restricted class unsupported.";
return false;
}
// Default num_keypoints will be overridden by value specified by
// GetNumberOfKeypoints()
const int num_keypoints = GetNumberOfKeypoints();
// Other parameters are not needed and do not make sense when the model
// contains the post-processing ops. Avoid init altogether in this case.
if (RequiresPostProcessing()) {
InitParams(IsAgnosticMode(), IsQuantizedModel(), num_keypoints,
GetBoxCoder());
}
SetImageNormalizationParams();
// Getting shape of input tensors. This also checks for size consistency with
// anchors. It also makes input_width_ and input_height_ available to
// LoadAnchors
if (!SetInputShape()) {
LOG(ERROR) << "Failed to set input shape";
return false;
}
// Output sizes are compared to expected sizes based on the number of anchors,
// number of classes, number of key points and number of values used to
// represent a bounding box.
if (RequiresPostProcessing() && !CheckOutputSizes()) {
LOG(ERROR) << "Check for output size failed";
return false;
}
SetZeroPointsAndScaleFactors(quantize_);
LOG(INFO) << "Model initialized:"
<< " input_size: " << input_size_
<< ", output_locations_size: " << output_locations_size_
<< ", preprocessing mean value: " << mean_value_
<< ", preprocessing std value: " << std_value_;
return true;
}
void MobileSSDTfLiteClient::SetImageNormalizationParams() {
mean_value_ = 127.5f;
std_value_ = 127.5f;
}
int MobileSSDTfLiteClient::GetNumberOfKeypoints() const {
return options_.num_keypoints();
}
bool MobileSSDTfLiteClient::SetInputShape() {
// inputs() maps the input tensor index to the index in TFLite's tensor list.
const int input_tensor_index = interpreter_->inputs()[0];
const TfLiteTensor* input_tensor = interpreter_->tensor(input_tensor_index);
if ((input_tensor->type != kTfLiteUInt8) &&
(input_tensor->type != kTfLiteFloat32)) {
LOG(ERROR) << "Unsupported tensor input type: " << input_tensor->type;
return false;
}
if (input_tensor->dims->size != 4) {
LOG(ERROR) << "Expected input tensor dimension size to be 4, got "
<< input_tensor->dims->size;
return false;
}
input_depth_ = input_tensor->dims->data[3];
input_width_ = input_tensor->dims->data[2];
input_height_ = input_tensor->dims->data[1];
input_size_ = input_height_ * input_width_ * input_depth_ * batch_size_;
return true;
}
bool MobileSSDTfLiteClient::InitializeInterpreter(
const protos::ClientOptions& options) {
if (options.prefer_nnapi_delegate()) {
LOG(ERROR) << "NNAPI not supported.";
return false;
}
interpreter_->UseNNAPI(false);
if (options.num_threads() > 0) {
interpreter_->SetNumThreads(options.num_threads());
}
if (interpreter_->inputs().size() != 1) {
LOG(ERROR) << "Invalid number of interpreter inputs: "
<< interpreter_->inputs().size();
return false;
}
if (interpreter_->AllocateTensors() != kTfLiteOk) {
LOG(ERROR) << "Failed to allocate tensors!";
return false;
}
return true;
}
bool MobileSSDTfLiteClient::CheckOutputSizes() {
int expected_output_locations_size =
anchors_.y_size() * (kNumBoundingBoxCoordinates + 2 * num_keypoints_);
if (output_locations_size_ != expected_output_locations_size) {
LOG(ERROR)
<< "The dimension of output_locations must be [num_anchors x 4]. Got "
<< output_locations_size_ << " but expected "
<< expected_output_locations_size;
return false;
}
// Include background class score when not in agnostic mode
int expected_output_scores_size =
anchors_.y_size() * (labelmap_.item_size() + (IsAgnosticMode() ? 0 : 1));
if (output_scores_size_ != expected_output_scores_size) {
LOG(ERROR)
<< "The dimension of output_scores is: "
"[num_anchors x (num_classes + 1)] if background class is included. "
"[num_anchors x num_classes] if background class is not included. "
"Got "
<< output_scores_size_ << " but expected "
<< expected_output_scores_size;
return false;
}
return true;
}
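// Illustrative arithmetic for the checks above, using hypothetical counts that
// are not taken from any shipped model (sketch only).
constexpr int kExampleNumAnchors = 2034;
constexpr int kExampleNumKeypoints = 0;
constexpr int kExampleNumClasses = 90;
constexpr int kExampleLocationsSize =
    kExampleNumAnchors * (kNumBoundingBoxCoordinates + 2 * kExampleNumKeypoints);
constexpr int kExampleScoresSize =
    kExampleNumAnchors * (kExampleNumClasses + 1);  // background class included
static_assert(kExampleLocationsSize == 8136, "locations size arithmetic");
static_assert(kExampleScoresSize == 185094, "scores size arithmetic");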
bool MobileSSDTfLiteClient::IsQuantizedModel() const {
const int input_tensor_index = interpreter_->inputs()[0];
const TfLiteTensor* input_tensor = interpreter_->tensor(input_tensor_index);
return input_tensor->type == kTfLiteUInt8;
}
void MobileSSDTfLiteClient::SetZeroPointsAndScaleFactors(
bool is_quantized_model) {
// Sets the initial scales to 1 and zero_points to 0. These values are only
// overwritten in the quantized-model case.
location_zero_points_.assign(num_output_layers_, 0);
location_scales_.assign(num_output_layers_, 1);
score_zero_points_.assign(num_output_layers_, 0);
score_scales_.assign(num_output_layers_, 1);
// Set scale and zero_point for quantized model
if (is_quantized_model) {
for (int layer = 0; layer < num_output_layers_; ++layer) {
const int location_tensor_index =
interpreter_->outputs()[GetBoxIndex(layer)];
const TfLiteTensor* location_tensor =
interpreter_->tensor(location_tensor_index);
location_zero_points_[layer] = location_tensor->params.zero_point;
location_scales_[layer] = location_tensor->params.scale;
// Class Scores
const int score_tensor_index =
interpreter_->outputs()[GetScoreIndex(layer)];
const TfLiteTensor* score_tensor =
interpreter_->tensor(score_tensor_index);
score_zero_points_[layer] = score_tensor->params.zero_point;
score_scales_[layer] = score_tensor->params.scale;
}
}
}
bool MobileSSDTfLiteClient::ComputeOutputLocationsSize(
const TfLiteTensor* location_tensor, int layer) {
const int location_tensor_size = location_tensor->dims->size;
if (location_tensor_size == 3) {
const int location_code_size = location_tensor->dims->data[2];
const int location_num_anchors = location_tensor->dims->data[1];
output_locations_sizes_[layer] = location_code_size * location_num_anchors;
} else if (location_tensor_size == 4) {
const int location_depth = location_tensor->dims->data[3];
const int location_width = location_tensor->dims->data[2];
const int location_height = location_tensor->dims->data[1];
output_locations_sizes_[layer] =
location_depth * location_width * location_height;
} else {
LOG(ERROR) << "Expected location_tensor_size of 3 or 4, got "
<< location_tensor_size;
return false;
}
return true;
}
bool MobileSSDTfLiteClient::ComputeOutputScoresSize(
const TfLiteTensor* score_tensor, int layer) {
const int score_tensor_size = score_tensor->dims->size;
if (score_tensor_size == 3) {
const int score_num_classes = score_tensor->dims->data[2];
const int score_num_anchors = score_tensor->dims->data[1];
output_scores_sizes_[layer] = score_num_classes * score_num_anchors;
} else if (score_tensor_size == 4) {
const int score_depth = score_tensor->dims->data[3];
const int score_width = score_tensor->dims->data[2];
const int score_height = score_tensor->dims->data[1];
output_scores_sizes_[layer] = score_depth * score_width * score_height;
} else {
LOG(ERROR) << "Expected score_tensor_size of 3 or 4, got "
<< score_tensor_size;
return false;
}
return true;
}
bool MobileSSDTfLiteClient::ComputeOutputLayerCount() {
// Compute number of layers in the output model
const int num_outputs = interpreter_->outputs().size();
if (num_outputs == 0) {
LOG(ERROR) << "Number of outputs cannot be zero.";
return false;
}
if (num_outputs % 2 != 0) {
LOG(ERROR) << "Number of outputs must be evenly divisible by 2. Actual "
"number of outputs: "
<< num_outputs;
return false;
}
num_output_layers_ = num_outputs / 2;
return true;
}
bool MobileSSDTfLiteClient::ComputeOutputSize() {
if (!ComputeOutputLayerCount()) {
return false;
}
// Allocate output arrays for box location and class scores
output_locations_sizes_.resize(num_output_layers_);
output_scores_sizes_.resize(num_output_layers_);
output_locations_size_ = 0;
output_scores_size_ = 0;
// This loop calculates the total size of data occupied by the output as well
// as the size for every layer of the model. For the quantized case, it also
// stores the offset and scale factor needed to transform the data back to
// floating point values.
for (int layer = 0; layer < num_output_layers_; ++layer) {
// Calculate sizes of Box locations output
const int location_tensor_index =
interpreter_->outputs()[GetBoxIndex(layer)];
const TfLiteTensor* location_tensor =
interpreter_->tensor(location_tensor_index);
if (!ComputeOutputLocationsSize(location_tensor, layer)) {
return false;
}
output_locations_size_ += output_locations_sizes_[layer];
// Class Scores
const int score_tensor_index =
interpreter_->outputs()[GetScoreIndex(layer)];
const TfLiteTensor* score_tensor = interpreter_->tensor(score_tensor_index);
if (!ComputeOutputScoresSize(score_tensor, layer)) {
return false;
}
output_scores_size_ += output_scores_sizes_[layer];
}
return true;
}
void MobileSSDTfLiteClient::NormalizeInputImage(const uint8_t* input_data,
float* normalized_input_data) {
float reciprocal_std_value_ = (1.0f / std_value_);
for (int i = 0; i < input_size_; i++, input_data++, normalized_input_data++) {
*normalized_input_data =
reciprocal_std_value_ * (static_cast<float>(*input_data) - mean_value_);
}
}
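// Single-pixel illustration of the normalization above (a sketch, assuming the
// default mean_value_ = std_value_ = 127.5 set by SetImageNormalizationParams);
// it maps uint8 pixels into [-1, 1], e.g. 0 -> -1.0f and 255 -> 1.0f.
inline float NormalizeExamplePixel(uint8_t pixel) {
  return (static_cast<float>(pixel) - 127.5f) / 127.5f;
}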
void MobileSSDTfLiteClient::GetOutputBoxesAndScoreTensorsFromFloat() {
float* output_score_pointer = output_scores_.data();
float* output_location_pointer = output_locations_.data();
for (int batch = 0; batch < batch_size_; ++batch) {
for (int layer = 0; layer < num_output_layers_; ++layer) {
// Write output location data
const float* location_data =
interpreter_->typed_output_tensor<float>(GetBoxIndex(layer)) +
batch * output_locations_sizes_[layer];
memcpy(output_location_pointer, location_data,
output_locations_sizes_[layer] * sizeof(float));
output_location_pointer += output_locations_sizes_[layer];
// Write output class scores
const float* score_data =
interpreter_->typed_output_tensor<float>(GetScoreIndex(layer)) +
batch * output_scores_sizes_[layer];
memcpy(output_score_pointer, score_data,
output_scores_sizes_[layer] * sizeof(float));
output_score_pointer += output_scores_sizes_[layer];
}
}
}
void MobileSSDTfLiteClient::GetOutputBoxesAndScoreTensorsFromUInt8() {
// The box locations and scores are converted back to floating point from
// their quantized versions by shifting and scaling the output tensors on an
// element-wise basis.
auto output_score_it = output_scores_.begin();
auto output_location_it = output_locations_.begin();
for (int batch = 0; batch < batch_size_; ++batch) {
for (int layer = 0; layer < num_output_layers_; ++layer) {
// Write output location data
const auto location_scale = location_scales_[layer];
const auto location_zero_point = location_zero_points_[layer];
const auto* location_data =
interpreter_->typed_output_tensor<uint8_t>(GetBoxIndex(layer));
for (int j = 0; j < output_locations_sizes_[layer];
++j, ++output_location_it) {
*output_location_it =
location_scale *
(static_cast<int>(
location_data[j + batch * output_locations_sizes_[layer]]) -
location_zero_point);
}
// write output class scores
const auto score_scale = score_scales_[layer];
const auto score_zero_point = score_zero_points_[layer];
const auto* score_data =
interpreter_->typed_output_tensor<uint8_t>(GetScoreIndex(layer));
for (int j = 0; j < output_scores_sizes_[layer]; ++j, ++output_score_it) {
*output_score_it =
score_scale *
(static_cast<int>(
score_data[j + batch * output_scores_sizes_[layer]]) -
score_zero_point);
}
}
}
}
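// The loops above apply the usual affine dequantization,
// real = scale * (quantized - zero_point). A one-value sketch follows; the
// scale and zero point passed in below are hypothetical, in practice they come
// from the output tensor's quantization params.
inline float DequantizeExampleValue(uint8_t quantized, float scale,
                                    int zero_point) {
  return scale * (static_cast<int>(quantized) - zero_point);
}
// e.g. DequantizeExampleValue(200, 0.0078125f, 128) == 0.5625f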
bool MobileSSDTfLiteClient::FloatInference(const uint8_t* input_data) {
auto* input = interpreter_->typed_input_tensor<float>(0);
if (input == nullptr) {
LOG(ERROR) << "Input tensor cannot be null for inference.";
return false;
}
// The non-quantized model assumes float input
// So we normalize the uint8 input image using mean_value_
// and std_value_
NormalizeInputImage(input_data, input);
// Applies the model to the data. The results will be stored in the output tensors.
if (interpreter_->Invoke() != kTfLiteOk) {
LOG(ERROR) << "Invoking interpreter resulted in non-okay status.";
return false;
}
// Parse outputs
if (RequiresPostProcessing()) {
GetOutputBoxesAndScoreTensorsFromFloat();
}
return true;
}
bool MobileSSDTfLiteClient::QuantizedInference(const uint8_t* input_data) {
auto* input = interpreter_->typed_input_tensor<uint8_t>(0);
if (input == nullptr) {
LOG(ERROR) << "Input tensor cannot be null for inference.";
return false;
}
memcpy(input, input_data, input_size_);
// Applies the model to the data. The results will be stored in the output tensors.
if (interpreter_->Invoke() != kTfLiteOk) {
LOG(ERROR) << "Invoking interpreter resulted in non-okay status.";
return false;
}
// Parse outputs
if (RequiresPostProcessing()) {
GetOutputBoxesAndScoreTensorsFromUInt8();
}
return true;
}
bool MobileSSDTfLiteClient::Inference(const uint8_t* input_data) {
if (input_data == nullptr) {
LOG(ERROR) << "input_data cannot be null for inference.";
return false;
}
if (IsQuantizedModel())
return QuantizedInference(input_data);
else
return FloatInference(input_data);
return true;
}
bool MobileSSDTfLiteClient::NoPostProcessNoAnchors(
protos::DetectionResults* detections) {
const float* boxes = interpreter_->typed_output_tensor<float>(0);
const float* classes = interpreter_->typed_output_tensor<float>(1);
const float* confidences = interpreter_->typed_output_tensor<float>(2);
int num_detections =
static_cast<int>(interpreter_->typed_output_tensor<float>(3)[0]);
int max_detections = options_.max_detections() > 0 ? options_.max_detections()
: num_detections;
std::vector<int> sorted_indices;
sorted_indices.resize(num_detections);
for (int i = 0; i < num_detections; ++i) sorted_indices[i] = i;
std::sort(sorted_indices.begin(), sorted_indices.end(),
[&confidences](const int i, const int j) {
return confidences[i] > confidences[j];
});
for (int i = 0;
i < num_detections && detections->detection_size() < max_detections;
++i) {
const int index = sorted_indices[i];
if (confidences[index] < options_.score_threshold()) {
break;
}
const int class_index = classes[index];
protos::Detection* detection = detections->add_detection();
detection->add_score(confidences[index]);
detection->add_class_index(class_index);
// For some reason it is not OK to add class/label names here; they appear
// to mess up the drishti graph.
// detection->add_display_name(GetLabelDisplayName(class_index));
// detection->add_class_name(GetLabelName(class_index));
protos::BoxCornerEncoding* box = detection->mutable_box();
box->add_ymin(boxes[4 * index]);
box->add_xmin(boxes[4 * index + 1]);
box->add_ymax(boxes[4 * index + 2]);
box->add_xmax(boxes[4 * index + 3]);
}
return true;
}
bool MobileSSDTfLiteClient::SetBatchSize(int batch_size) {
if (!this->MobileSSDClient::SetBatchSize(batch_size)) {
LOG(ERROR) << "Error in SetBatchSize()";
return false;
}
input_size_ = input_height_ * input_width_ * input_depth_ * batch_size_;
for (int input : interpreter_->inputs()) {
auto* old_dims = interpreter_->tensor(input)->dims;
std::vector<int> new_dims(old_dims->data, old_dims->data + old_dims->size);
new_dims[0] = batch_size;
if (interpreter_->ResizeInputTensor(input, new_dims) != kTfLiteOk) {
LOG(ERROR) << "Unable to resize input for new batch size";
return false;
}
}
if (interpreter_->AllocateTensors() != kTfLiteOk) {
LOG(ERROR) << "Unable to reallocate tensors";
return false;
}
return true;
}
void MobileSSDTfLiteClient::LoadLabelMap() {
if (options_.has_external_files()) {
if (options_.external_files().has_label_map_file_content() ||
options_.external_files().has_label_map_file_name()) {
CHECK(LoadLabelMapFromFileOrBytes(
options_.external_files().label_map_file_name(),
options_.external_files().label_map_file_content(), &labelmap_));
} else {
LOG(ERROR) << "MobileSSDTfLiteClient: both "
"'external_files.label_map_file_content` and "
"'external_files.label_map_file_name` are empty"
" which is invalid.";
}
}
}
} // namespace tflite
} // namespace lstm_object_detection
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_LSTM_OBJECT_DETECTION_TFLITE_MOBILE_SSD_TFLITE_CLIENT_H_
#define TENSORFLOW_MODELS_LSTM_OBJECT_DETECTION_TFLITE_MOBILE_SSD_TFLITE_CLIENT_H_
#include <memory>
#include <unordered_set>
#include "absl/memory/memory.h"
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/model.h"
#include "mobile_ssd_client.h"
#include "protos/anchor_generation_options.pb.h"
namespace lstm_object_detection {
namespace tflite {
class MobileSSDTfLiteClient : public MobileSSDClient {
public:
MobileSSDTfLiteClient();
explicit MobileSSDTfLiteClient(
std::unique_ptr<::tflite::OpResolver> resolver);
~MobileSSDTfLiteClient() override = default;
protected:
// By default CreateOpResolver will create
// tflite::ops::builtin::BuiltinOpResolver. Overriding the function allows the
// client to use custom op resolvers.
virtual std::unique_ptr<::tflite::OpResolver> CreateOpResolver();
bool InitializeClient(const protos::ClientOptions& options) override;
virtual bool InitializeInterpreter(const protos::ClientOptions& options);
virtual bool ComputeOutputLayerCount();
bool Inference(const uint8_t* input_data) override;
bool NoPostProcessNoAnchors(protos::DetectionResults* detections) override;
// Use with caution. Not all models work correctly when resized to larger
// batch sizes. This will resize the input tensor to have the given batch size
// and propagate the batch dimension throughout the graph.
bool SetBatchSize(int batch_size) override;
// This can be overridden in a subclass to load the label map from a file
void LoadLabelMap() override;
// This can be overridden in a subclass to return customized box coder.
virtual const protos::BoxCoder GetBoxCoder() { return protos::BoxCoder(); }
virtual void SetImageNormalizationParams();
void NormalizeInputImage(const uint8_t* input_data,
float* normalized_input_data);
void GetOutputBoxesAndScoreTensorsFromFloat();
virtual bool IsQuantizedModel() const;
std::unique_ptr<::tflite::FlatBufferModel> model_;
std::unique_ptr<::tflite::OpResolver> resolver_;
std::unique_ptr<::tflite::Interpreter> interpreter_;
private:
// MobileSSDTfLiteClient is neither copyable nor movable.
MobileSSDTfLiteClient(const MobileSSDTfLiteClient&) = delete;
MobileSSDTfLiteClient& operator=(const MobileSSDTfLiteClient&) = delete;
// Helper functions used by InitializeClient().
virtual int GetNumberOfKeypoints() const;
// Returns true if the client is in class-agnostic mode. This function can be
// overridden in a subclass to return an ad-hoc value (e.g. hard-coded).
virtual bool IsAgnosticMode() const { return agnostic_mode_; }
bool CheckOutputSizes();
bool ComputeOutputSize();
bool SetInputShape();
void SetZeroPointsAndScaleFactors(bool is_quantized_model);
bool ComputeOutputLocationsSize(const TfLiteTensor* location_tensor,
int layer);
bool ComputeOutputScoresSize(const TfLiteTensor* score_tensor, int layer);
// The agnostic_mode_ field should never be directly read. Always use its
// virtual accessor method: IsAgnosticMode().
bool agnostic_mode_;
// Helper functions used by Inference().
bool FloatInference(const uint8_t* input_data);
bool QuantizedInference(const uint8_t* input_data);
void GetOutputBoxesAndScoreTensorsFromUInt8();
};
} // namespace tflite
} // namespace lstm_object_detection
#endif // TENSORFLOW_MODELS_LSTM_OBJECT_DETECTION_TFLITE_MOBILE_SSD_TFLITE_CLIENT_H_
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"])
proto_library(
name = "box_encodings_proto",
srcs = ["box_encodings.proto"],
)
cc_proto_library(
name = "box_encodings_cc_proto",
deps = [":box_encodings_proto"],
)
proto_library(
name = "detections_proto",
srcs = ["detections.proto"],
deps = [":box_encodings_proto"],
)
cc_proto_library(
name = "detections_cc_proto",
deps = [":detections_proto"],
)
proto_library(
name = "labelmap_proto",
srcs = ["labelmap.proto"],
)
cc_proto_library(
name = "labelmap_cc_proto",
deps = [":labelmap_proto"],
)
proto_library(
name = "mobile_ssd_client_options_proto",
srcs = ["mobile_ssd_client_options.proto"],
deps = [
":anchor_generation_options_proto",
":box_encodings_proto",
":labelmap_proto",
],
)
cc_proto_library(
name = "mobile_ssd_client_options_cc_proto",
deps = [":mobile_ssd_client_options_proto"],
)
proto_library(
name = "anchor_generation_options_proto",
srcs = ["anchor_generation_options.proto"],
)
cc_proto_library(
name = "anchor_generation_options_cc_proto",
deps = [":anchor_generation_options_proto"],
)
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto2";
package lstm_object_detection.tflite.protos;
// This is derived from TensorFlow's SsdAnchorGenerator proto, which is used to
// configure TensorFlow's anchor generator.
// object_detection/protos/ssd_anchor_generator.proto
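// An illustrative text-format configuration (all values are examples only,
// not recommendations):
//   image_width: 320
//   image_height: 320
//   base_anchor_width: 256
//   base_anchor_height: 256
//   min_anchor_scale: 0.2
//   max_anchor_scale: 0.95
//   anchor_aspect_ratios: [1.0, 2.0, 0.5]
//   anchor_strides: [16, 32]
//   anchor_offsets: [8, 16]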
message AnchorGenerationOptions {
// The input image width in pixels
optional int32 image_width = 1;
// The input image height in pixels
optional int32 image_height = 2;
// The base anchor width in pixels
optional int32 base_anchor_width = 3;
// The base anchor height in pixels
optional int32 base_anchor_height = 4;
// The minimum anchor scaling (should be < 1.0)
optional float min_anchor_scale = 5;
// The maximum anchor scaling
optional float max_anchor_scale = 6;
// List of aspect ratios to generate anchors for. Aspect ratio is specified as
// (width/height)
repeated float anchor_aspect_ratios = 7 [packed = true];
// List of strides in pixels for each layer
repeated int32 anchor_strides = 8 [packed = true];
// List of offsets in pixels for each layer
repeated int32 anchor_offsets = 9 [packed = true];
}
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto2";
package lstm_object_detection.tflite.protos;
// The bounding box representation using the box center location and
// width/height. Also includes optional keypoint coordinates.
// This is a common representation in modern object detection systems.
message CenterSizeEncoding {
// Encoded anchor box center.
repeated float y = 1;
repeated float x = 2;
// Encoded anchor box height.
repeated float h = 3;
// Encoded anchor box width.
repeated float w = 4;
// Encoded keypoint coordinates.
repeated float keypoint_y = 5;
repeated float keypoint_x = 6;
}
// The scaling factors for decoding predicted offsets with CenterSizeEncoding.
// For example, given a prediction and an anchor in CenterSizeEncoding, the
// decoded location is:
// y = prediction.y / coder.y_scale() * anchor.h + anchor.y;
// x = prediction.x / coder.x_scale() * anchor.w + anchor.x;
// h = exp(prediction.h / coder.h_scale()) * anchor.h;
// w = exp(prediction.w / coder.w_scale()) * anchor.w;
// keypoint_y = prediction.keypoint_y / coder.keypoint_y_scale() * anchor.h
// + anchor.y;
// keypoint_x = prediction.keypoint_x / coder.keypoint_x_scale() * anchor.w
// + anchor.x;
// See mobile_ssd::DecodeCenterSizeBoxes for more details.
// This coder is compatible with models trained using
// object_detection.protos.FasterRcnnBoxCoder and
// object_detection.protos.KeypointBoxCoder.
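// A worked numeric example (illustrative values only): with the default
// scales y_scale = 10 and h_scale = 5, an anchor of y = 0.5, h = 0.2, and a
// prediction of y = 1.0, h = 0.0, the decoded values are
//   y = 1.0 / 10 * 0.2 + 0.5 = 0.52
//   h = exp(0.0 / 5) * 0.2 = 0.20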
message CenterSizeOffsetCoder {
// Scale factor for encoded box center offset.
optional float y_scale = 1 [default = 10.0];
optional float x_scale = 2 [default = 10.0];
// Scale factor for encoded box height offset.
optional float h_scale = 3 [default = 5.0];
// Scale factor for encoded box width offset.
optional float w_scale = 4 [default = 5.0];
// Scale factor for encoded keypoint coordinate offset.
optional float keypoint_y_scale = 5 [default = 10.0];
optional float keypoint_x_scale = 6 [default = 10.0];
}
// The canonical representation of a bounding box.
message BoxCornerEncoding {
// Box corners.
repeated float ymin = 1;
repeated float xmin = 2;
repeated float ymax = 3;
repeated float xmax = 4;
// Keypoint coordinates.
repeated float keypoint_y = 5;
repeated float keypoint_x = 6;
}
// The scaling value used to adjust predicted bounding box corners.
// For example, given a prediction in BoxCornerEncoding and an anchor in
// CenterSizeEncoding, the decoded location is:
// ymin = prediction.ymin * coder.stddev + anchor.y - anchor.h / 2
// xmin = prediction.xmin * coder.stddev + anchor.x - anchor.w / 2
// ymax = prediction.ymax * coder.stddev + anchor.y + anchor.h / 2
// xmax = prediction.xmax * coder.stddev + anchor.x + anchor.w / 2
// This coder doesn't support keypoints.
// See mobile_ssd::DecodeBoxCornerBoxes for more details.
// This coder is compatible with models trained using
// object_detection.protos.MeanStddevBoxCoder.
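// A worked numeric example (illustrative values only): with the default
// stddev = 0.01, an anchor of y = 0.5, h = 0.2, and a prediction of
// ymin = 1.0, the decoded value is
//   ymin = 1.0 * 0.01 + 0.5 - 0.2 / 2 = 0.41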
message BoxCornerOffsetCoder {
// The standard deviation used to encode and decode boxes.
optional float stddev = 1 [default = 0.01];
}
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto2";
package lstm_object_detection.tflite.protos;
import "protos/box_encodings.proto";
// DetectionResults is a list of Detection.
message DetectionResults {
repeated Detection detection = 1;
}
// Detection consists of a bounding box, class confidences and indices.
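// An illustrative Detection in text format (values are examples only; the
// class follows the sample labelmap in labelmap.proto):
//   box { ymin: 0.1 xmin: 0.2 ymax: 0.4 xmax: 0.5 }
//   score: 0.9
//   class_index: 2
//   class_name: "/m/02dl1y"
//   display_name: "Hat"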
message Detection {
// Each detection message consists of only one bounding box.
optional BoxCornerEncoding box = 1;
// A box can be associated with multiple confidences for multiple classes.
repeated float score = 2;
repeated int32 class_index = 3;
// Optional, for readability and easier access for external modules.
// A unique name that identifies the class, e.g. a MID.
repeated string class_name = 4;
// A human readable name of the class.
repeated string display_name = 5;
}
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This proto defines the labelmap used in the detection models, which maps
// the numerical class index outputs to KG MIDs or human-readable object
// class names.
//
// An example labelmap looks like the following:
// item {
// name: "/m/0frqm"
// id: 1
// display_name: "Envelope"
// }
// item {
// name: "/m/02dl1y"
// id: 2
// display_name: "Hat"
// }
// item {
// name: "/m/01krhy"
// id: 3
// display_name: "Tiara"
// }
syntax = "proto2";
package lstm_object_detection.tflite.protos;
message StringIntLabelMapItem {
optional string name = 1;
optional int32 id = 2;
repeated float embedding = 3 [packed = true];
optional string display_name = 4;
// Optional list of children used to represent a hierarchy.
//
// E.g.:
//
// item {
// name: "/m/02xwb" # Fruit
// child_name: "/m/014j1m" # Apple
// child_name: "/m/0388q" # Grape
// ...
// }
// item {
// name: "/m/014j1m" # Apple
// ...
// }
repeated string child_name = 5;
}
message StringIntLabelMapProto {
repeated StringIntLabelMapItem item = 1;
}
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto2";
package lstm_object_detection.tflite.protos;
import "protos/anchor_generation_options.proto";
import "protos/box_encodings.proto";
import "protos/labelmap.proto";
// Next ID: 17
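// An illustrative text-format configuration (all values, names, and paths
// are examples only, not recommendations):
//   mobile_ssd_client_name: "my_ssd_client"
//   max_detections: 20
//   score_threshold: 0.5
//   iou_threshold: 0.4
//   num_threads: 2
//   external_files {
//     model_file_name: "/path/to/model.tflite"
//     label_map_file_name: "/path/to/labelmap.pbtxt"
//   }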
message ClientOptions {
// The name of the Mobile SSD Client.
optional string mobile_ssd_client_name = 1;
// The maximum number of detections to return.
optional uint32 max_detections = 2 [default = 10];
// The maximum number of categories to return per detection.
optional uint32 max_categories = 3 [default = 1];
// The global score threshold below which detections are rejected.
optional float score_threshold = 4 [default = 0.0];
// The threshold on intersection-over-union used by non-maximum suppression.
optional float iou_threshold = 5 [default = 0.3];
// Optional whitelist of class names. If non-empty, detections whose class
// name is not in this set will be filtered out. Duplicate or unknown class
// names are ignored.
repeated string class_name_whitelist = 6;
// Run SSD in single-class (class-agnostic) mode.
optional bool agnostic_mode = 7 [default = false];
// Fully convolutional mode, which requires on-the-fly anchor generation.
optional bool fully_conv = 8 [default = false];
// Quantized model.
optional bool quantize = 9 [default = false];
// Number of keypoints.
optional uint32 num_keypoints = 10 [default = 0];
// Optional anchor generation options. These can be used to generate
// anchors for an SSD model. They are utilized in
// MobileSSDTfLiteClient::LoadAnchors().
optional AnchorGenerationOptions anchor_generation_options = 12;
// Optional box coder specifications. This can be used for models trained
// with a customized box coder. If unspecified, it will use
// CenterSizeOffsetCoder and its default parameters.
optional BoxCoder box_coder = 13;
// The external model files used to create the detector.
// This is an alternative to registered models, where you specify an external
// model via the following:
// - model using model_file_name or model_file_content
// - labelmap using label_map_file_name or label_map_file_content
// - anchors using the anchor_generation_options proto (TODO: add support for
// filename as well)
optional ExternalFiles external_files = 16;
message ExternalFiles {
// Path to the model file in FlatBuffer format.
optional string model_file_name = 1;
// Content of the model file. If provided, this takes precedence over the
// model_file_name field.
optional bytes model_file_content = 2;
// Path to the label map file.
optional string label_map_file_name = 4;
// Content of the label map file. If provided, this takes precedence over
// the label_map_file_name field.
optional bytes label_map_file_content = 3;
// Path to the anchor file.
optional string anchor_file_name = 5;
// Content of the anchor file. If provided, this takes precedence over
// the anchor_file_name field.
optional bytes anchor_file_content = 6;
}
// Whether to use the NNAPI delegate for hardware acceleration.
// If it fails, inference falls back to normal CPU execution.
optional bool prefer_nnapi_delegate = 14;
// Number of threads to be used by the TFLite interpreter for SSD inference.
// Does single-threaded inference by default.
optional int32 num_threads = 15 [default = 1];
extensions 1000 to max;
}
message BoxCoder {
oneof box_coder_oneof {
CenterSizeOffsetCoder center_size_offset_coder = 1;
BoxCornerOffsetCoder box_corner_offset_coder = 2;
}
}
message ModelData {
oneof source {
string model_file = 1;
bytes embedded_model = 2;
}
}
# This file is necessary for the Portable Proto library
allow_all: true
# Other configuration options:
optimize_mode: LITE_RUNTIME