"magic_pdf/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "78bdf53ee429c5c8a9d791e4f64bc0f0c8d26f03"
Commit 13e7c85d authored by Yongzhe Wang, committed by Menglong Zhu

Merged commit includes the following changes: (#7358)

261196859  by yongzhe:

    Integrate EdgeTPU API into the Mobile SSD tflite client.

    Build command with EdgeTPU enabled:
    bazel build mobile_ssd_tflite_client  --define enable_edgetpu=true

    Build command with EdgeTPU disabled:
    bazel build mobile_ssd_tflite_client

--
259096620  by Menglong Zhu:

    Remove unused proto imports.

--

PiperOrigin-RevId: 261196859
parent aee49bbd
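Note on how the build flag reaches the code: passing --define enable_edgetpu=true matches the new enable_edgetpu config_setting, the select() on defines then adds ENABLE_EDGETPU when compiling mobile_ssd_tflite_client, and that macro gates every EdgeTPU-specific block added below. A minimal illustrative sketch of the gate (not a file in this commit):

// Illustrative sketch, not part of the commit: ENABLE_EDGETPU is only defined
// when the build matches the enable_edgetpu config_setting (i.e. the target is
// built with --define enable_edgetpu=true), so EdgeTPU-specific code is
// compiled in or out at build time.
#ifdef ENABLE_EDGETPU
#pragma message("Building with EdgeTPU support")
#else
#pragma message("Building without EdgeTPU support")
#endif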
@@ -23,22 +23,35 @@ cc_library(
     ],
 )
 
+config_setting(
+    name = "enable_edgetpu",
+    define_values = {"enable_edgetpu": "true"},
+    visibility = ["//visibility:public"],
+)
+
 cc_library(
     name = "mobile_ssd_tflite_client",
     srcs = ["mobile_ssd_tflite_client.cc"],
     hdrs = ["mobile_ssd_tflite_client.h"],
+    defines = select({
+        "//conditions:default": [],
+        "enable_edgetpu": ["ENABLE_EDGETPU"],
+    }),
     deps = [
         ":mobile_ssd_client",
-        "//protos:anchor_generation_options_cc_proto",
-        "//utils:file_utils",
-        "//utils:ssd_utils",
-        "@com_google_absl//absl/memory",
         "@com_google_glog//:glog",
+        "@com_google_absl//absl/memory",
         "@org_tensorflow//tensorflow/lite:arena_planner",
         "@org_tensorflow//tensorflow/lite:framework",
         "@org_tensorflow//tensorflow/lite/delegates/nnapi:nnapi_delegate",
         "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
-    ],
+        "//protos:anchor_generation_options_cc_proto",
+        "//utils:file_utils",
+        "//utils:ssd_utils",
+    ] + select({
+        "//conditions:default": [],
+        "enable_edgetpu": ["@libedgetpu//libedgetpu:header"],
+    }),
     alwayslink = 1,
 )
@@ -51,6 +64,7 @@ cc_library(
         ":mobile_ssd_tflite_client",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_glog//:glog",
+        "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
     ],
     alwayslink = 1,
 )
 workspace(name = "lstm_object_detection")
 
 load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
+load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
 
 http_archive(
     name = "bazel_skylib",
@@ -118,3 +119,9 @@ http_archive(
 load("@org_tensorflow//tensorflow:workspace.bzl", "tf_workspace")
 tf_workspace(tf_repo_name = "org_tensorflow")
+
+git_repository(
+    name = "libedgetpu",
+    remote = "sso://coral.googlesource.com/edgetpu-native",
+    commit = "83e47d1bcf22686fae5150ebb99281f6134ef062",
+)
@@ -37,7 +37,7 @@ constexpr int GetScoreIndex(const int layer) { return (2 * layer + 1); }
 
 MobileSSDTfLiteClient::MobileSSDTfLiteClient() {}
 
-std::unique_ptr<::tflite::OpResolver>
+std::unique_ptr<::tflite::MutableOpResolver>
 MobileSSDTfLiteClient::CreateOpResolver() {
   return absl::make_unique<::tflite::ops::builtin::BuiltinOpResolver>();
 }
@@ -77,6 +77,12 @@ bool MobileSSDTfLiteClient::InitializeClient(
   resolver_ = CreateOpResolver();
+#ifdef ENABLE_EDGETPU
+  edge_tpu_context_ =
+      edgetpu::EdgeTpuManager::GetSingleton()->NewEdgeTpuContext();
+  resolver_->AddCustom(edgetpu::kCustomOp, edgetpu::RegisterCustomOp());
+#endif
+
   ::tflite::InterpreterBuilder(*model_, *resolver_)(&interpreter_);
   if (!interpreter_) {
     LOG(ERROR) << "Failed to build interpreter";
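For orientation, the calls added above follow the standard pattern for running an EdgeTPU-compiled model through stock TFLite: register the edgetpu custom op on a mutable resolver before the interpreter is built, then attach the EdgeTpuContext. A self-contained sketch of that pattern follows; BuildEdgeTpuInterpreter and its arguments are illustrative only, not part of this change, and it assumes the libedgetpu headers pinned in the WORKSPACE above.

#include <memory>

#include "libedgetpu/libedgetpu.h"  // edgetpu::EdgeTpuManager, edgetpu::kCustomOp
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

// Illustrative helper (not in the commit): build a TFLite interpreter that can
// execute an EdgeTPU-compiled model, mirroring the calls added to
// InitializeClient() and InitializeInterpreter().
std::unique_ptr<tflite::Interpreter> BuildEdgeTpuInterpreter(
    const tflite::FlatBufferModel& model,
    edgetpu::EdgeTpuContext* edgetpu_context) {
  tflite::ops::builtin::BuiltinOpResolver resolver;
  // The EdgeTPU runtime is exposed as a custom op; it must be registered
  // before InterpreterBuilder resolves the graph's ops.
  resolver.AddCustom(edgetpu::kCustomOp, edgetpu::RegisterCustomOp());

  std::unique_ptr<tflite::Interpreter> interpreter;
  if (tflite::InterpreterBuilder(model, resolver)(&interpreter) != kTfLiteOk) {
    return nullptr;
  }
  // Hand the device context to the interpreter so the custom op can reach the
  // EdgeTPU, then allocate tensors as usual.
  interpreter->SetExternalContext(kTfLiteEdgeTpuContext, edgetpu_context);
  if (interpreter->AllocateTensors() != kTfLiteOk) {
    return nullptr;
  }
  return interpreter;
}

A caller would obtain the context the same way the client now does, via edgetpu::EdgeTpuManager::GetSingleton()->NewEdgeTpuContext().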
@@ -178,6 +184,11 @@ bool MobileSSDTfLiteClient::InitializeInterpreter(
   }
 
   interpreter_->UseNNAPI(false);
+#ifdef ENABLE_EDGETPU
+  interpreter_->SetExternalContext(kTfLiteEdgeTpuContext,
+                                   edge_tpu_context_.get());
+#endif
+
   if (options.num_threads() > 0) {
     interpreter_->SetNumThreads(options.num_threads());
   }
...
@@ -25,6 +25,9 @@ limitations under the License.
 #include "tensorflow/lite/model.h"
 #include "mobile_ssd_client.h"
 #include "protos/anchor_generation_options.pb.h"
+#ifdef ENABLE_EDGETPU
+#include "libedgetpu/libedgetpu.h"
+#endif  // ENABLE_EDGETPU
 
 namespace lstm_object_detection {
 namespace tflite {
@@ -40,7 +43,7 @@ class MobileSSDTfLiteClient : public MobileSSDClient {
   // By default CreateOpResolver will create
   // tflite::ops::builtin::BuiltinOpResolver. Overriding the function allows the
   // client to use custom op resolvers.
-  virtual std::unique_ptr<::tflite::OpResolver> CreateOpResolver();
+  virtual std::unique_ptr<::tflite::MutableOpResolver> CreateOpResolver();
 
   bool InitializeClient(const protos::ClientOptions& options) override;
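The return-type change here is what makes the EdgeTPU registration in InitializeClient() possible: a MutableOpResolver can still have ops added to it after CreateOpResolver() returns. Subclasses that need their own ops keep working; a hedged sketch (the subclass, op name, and Register_MY_CUSTOM_OP are hypothetical, not part of this change):

#include "absl/memory/memory.h"
#include "mobile_ssd_tflite_client.h"
#include "tensorflow/lite/kernels/register.h"

// Hypothetical custom kernel registration, declared here only for illustration.
TfLiteRegistration* Register_MY_CUSTOM_OP();

// Hypothetical subclass: because the base now returns a MutableOpResolver,
// both this override and the base class can keep registering ops (e.g. the
// EdgeTPU custom op) on the same resolver instance.
class MyCustomTfLiteClient
    : public lstm_object_detection::tflite::MobileSSDTfLiteClient {
 public:
  std::unique_ptr<::tflite::MutableOpResolver> CreateOpResolver() override {
    auto resolver =
        absl::make_unique<::tflite::ops::builtin::BuiltinOpResolver>();
    resolver->AddCustom("MyCustomOp", Register_MY_CUSTOM_OP());
    return resolver;
  }
};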
@@ -70,7 +73,7 @@ class MobileSSDTfLiteClient : public MobileSSDClient {
   virtual bool IsQuantizedModel() const;
 
   std::unique_ptr<::tflite::FlatBufferModel> model_;
-  std::unique_ptr<::tflite::OpResolver> resolver_;
+  std::unique_ptr<::tflite::MutableOpResolver> resolver_;
   std::unique_ptr<::tflite::Interpreter> interpreter_;
 
  private:
@@ -100,6 +103,10 @@ class MobileSSDTfLiteClient : public MobileSSDClient {
   bool FloatInference(const uint8_t* input_data);
   bool QuantizedInference(const uint8_t* input_data);
   void GetOutputBoxesAndScoreTensorsFromUInt8();
+
+#ifdef ENABLE_EDGETPU
+  std::unique_ptr<edgetpu::EdgeTpuContext> edge_tpu_context_;
+#endif
 };
 
 }  // namespace tflite
...
@@ -19,7 +19,6 @@ package lstm_object_detection.tflite.protos;
 
 import "protos/anchor_generation_options.proto";
 import "protos/box_encodings.proto";
-import "protos/labelmap.proto";
 
 // Next ID: 17
 message ClientOptions {
@@ -55,7 +54,6 @@ message ClientOptions {
   // Number of keypoints.
   optional uint32 num_keypoints = 10 [default = 0];
 
   // Optional anchor generations options. This can be used to generate
   // anchors for an SSD model. It is utilized in
   // MobileSSDTfLiteClient::LoadAnchors()
...