Commit 901c4cc4 authored by Vinh Nguyen's avatar Vinh Nguyen
Browse files

Merge remote-tracking branch 'upstream/master' into amp_resnet50

parents ef30de93 824ff2d6
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test a tflite model using random input data."""
from __future__ import print_function
from absl import flags
import numpy as np
import tensorflow as tf
# Command-line flag: path to the .tflite model file to load.
# Declared with a None default; main() marks it required, so running
# without --model_path fails validation instead of passing None through.
flags.DEFINE_string('model_path', None, 'Path to model.')
# Module-level handle used to read parsed flag values (FLAGS.model_path).
FLAGS = flags.FLAGS
def main(_):
  """Runs one inference on a TFLite model with random input and prints it.

  Args:
    _: Unused positional argument supplied by the app runner.

  Raises:
    flags.IllegalFlagValueError: if --model_path was not supplied.
  """
  # Enforce --model_path; flags are already parsed by the time main() runs,
  # so this validates immediately and fails fast with a clear error.
  flags.mark_flag_as_required('model_path')

  # Load the TFLite model and allocate its input/output tensors.
  interpreter = tf.lite.Interpreter(model_path=FLAGS.model_path)
  interpreter.allocate_tensors()

  # Print the model's input and output tensor metadata for inspection.
  input_details = interpreter.get_input_details()
  print('input_details:', input_details)
  output_details = interpreter.get_output_details()
  print('output_details:', output_details)

  # Test the model on random input data shaped to match the first input.
  input_shape = input_details[0]['shape']
  # Change the following line to feed in your own data.
  input_data = np.array(
      np.random.random_sample(input_shape), dtype=np.float32)
  interpreter.set_tensor(input_details[0]['index'], input_data)
  interpreter.invoke()

  # Read back and print the first output tensor.
  output_data = interpreter.get_tensor(output_details[0]['index'])
  print(output_data)


if __name__ == '__main__':
  # NOTE(review): tf.app.run is TF1-only; under TF2 this should be
  # `from absl import app; app.run(main)` — confirm target TF version.
  tf.app.run()
......@@ -59,12 +59,19 @@ cc_library(
name = "mobile_lstd_tflite_client",
srcs = ["mobile_lstd_tflite_client.cc"],
hdrs = ["mobile_lstd_tflite_client.h"],
defines = select({
"//conditions:default": [],
"enable_edgetpu": ["ENABLE_EDGETPU"],
}),
deps = [
":mobile_ssd_client",
":mobile_ssd_tflite_client",
"@com_google_absl//absl/base:core_headers",
"@com_google_glog//:glog",
"@com_google_absl//absl/base:core_headers",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
],
] + select({
"//conditions:default": [],
"enable_edgetpu": ["@libedgetpu//libedgetpu:header"],
}),
alwayslink = 1,
)
......@@ -90,13 +90,6 @@ http_archive(
sha256 = "79d102c61e2a479a0b7e5fc167bcfaa4832a0c6aad4a75fa7da0480564931bcc",
)
#
# http_archive(
# name = "com_google_protobuf",
# strip_prefix = "protobuf-master",
# urls = ["https://github.com/protocolbuffers/protobuf/archive/master.zip"],
# )
# Needed by TensorFlow
http_archive(
name = "io_bazel_rules_closure",
......
......@@ -66,6 +66,11 @@ bool MobileLSTDTfLiteClient::InitializeInterpreter(
interpreter_->UseNNAPI(false);
}
#ifdef ENABLE_EDGETPU
interpreter_->SetExternalContext(kTfLiteEdgeTpuContext,
edge_tpu_context_.get());
#endif
// Inputs are: normalized_input_image_tensor, raw_inputs/init_lstm_c,
// raw_inputs/init_lstm_h
if (interpreter_->inputs().size() != 3) {
......
......@@ -26,7 +26,7 @@ limitations under the License.
#include "mobile_ssd_client.h"
#include "protos/anchor_generation_options.pb.h"
#ifdef ENABLE_EDGETPU
#include "libedgetpu/libedgetpu.h"
#include "libedgetpu/edgetpu.h"
#endif // ENABLE_EDGETPU
namespace lstm_object_detection {
......@@ -76,6 +76,10 @@ class MobileSSDTfLiteClient : public MobileSSDClient {
std::unique_ptr<::tflite::MutableOpResolver> resolver_;
std::unique_ptr<::tflite::Interpreter> interpreter_;
#ifdef ENABLE_EDGETPU
std::unique_ptr<edgetpu::EdgeTpuContext> edge_tpu_context_;
#endif
private:
// MobileSSDTfLiteClient is neither copyable nor movable.
MobileSSDTfLiteClient(const MobileSSDTfLiteClient&) = delete;
......@@ -103,10 +107,6 @@ class MobileSSDTfLiteClient : public MobileSSDClient {
bool FloatInference(const uint8_t* input_data);
bool QuantizedInference(const uint8_t* input_data);
void GetOutputBoxesAndScoreTensorsFromUInt8();
#ifdef ENABLE_EDGETPU
std::unique_ptr<edgetpu::EdgeTpuContext> edge_tpu_context_;
#endif
};
} // namespace tflite
......
......@@ -37,7 +37,7 @@ def get_configs_from_pipeline_file(pipeline_config_path):
Returns:
Dictionary of configuration objects. Keys are `model`, `train_config`,
`train_input_config`, `eval_config`, `eval_input_config`, `lstm_confg`.
`train_input_config`, `eval_config`, `eval_input_config`, `lstm_model`.
Value are the corresponding config objects.
"""
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment