Commit b7221961 authored by Yongzhe Wang's avatar Yongzhe Wang Committed by Menglong Zhu
Browse files

Merged commit includes the following changes: (#7249)

* Merged commit includes the following changes:
257930561  by yongzhe:

    Mobile LSTD TfLite Client.

--
257928126  by yongzhe:

    Mobile SSD TfLite client.

--
257921181  by menglong:

    Fix discrepancy between pre_bottleneck = {true, false}

--
257561213  by yongzhe:

    File utils.

--
257449226  by yongzhe:

    Mobile SSD Client.

--
257264654  by yongzhe:

    SSD utils.

--
257235648  by yongzhe:

    Proto bazel build rules.

--
256437262  by Menglong Zhu:

    Fix check for FusedBatchNorm op to only verify it as a prefix.

--
256283755  by yongzhe:

    Bazel build and copybara changes.

--
251947295  by yinxiao:

    Add missing interleaved option in checkpoint restore.

--
251513479  by yongzhe:

    Conversion utils.

--
248783193  by yongzhe:

    Branch protos needed for the lstd client.

--
248200507  by menglong:

    Fix proto namespace in example config.

--

P...
parent e21dcdd0
...@@ -98,9 +98,13 @@ bool MobileLSTDTfLiteClient::InitializeInterpreter( ...@@ -98,9 +98,13 @@ bool MobileLSTDTfLiteClient::InitializeInterpreter(
return false; return false;
} }
// Outputs are: raw_outputs/box_encodings, raw_outputs/class_predictions, // Outputs are:
// TFLite_Detection_PostProcess,
// TFLite_Detection_PostProcess:1,
// TFLite_Detection_PostProcess:2,
// TFLite_Detection_PostProcess:3,
// raw_outputs/lstm_c, raw_outputs/lstm_h // raw_outputs/lstm_c, raw_outputs/lstm_h
if (interpreter_->outputs().size() != 4) { if (interpreter_->outputs().size() != 6) {
LOG(ERROR) << "Invalid number of interpreter outputs: " << LOG(ERROR) << "Invalid number of interpreter outputs: " <<
interpreter_->outputs().size(); interpreter_->outputs().size();
return false; return false;
...@@ -108,12 +112,12 @@ bool MobileLSTDTfLiteClient::InitializeInterpreter( ...@@ -108,12 +112,12 @@ bool MobileLSTDTfLiteClient::InitializeInterpreter(
const std::vector<int> output_tensor_indices = interpreter_->outputs(); const std::vector<int> output_tensor_indices = interpreter_->outputs();
const TfLiteTensor& output_lstm_c = const TfLiteTensor& output_lstm_c =
*interpreter_->tensor(output_tensor_indices[2]); *interpreter_->tensor(output_tensor_indices[4]);
if (!ValidateStateTensor(output_lstm_c, "output lstm_c")) { if (!ValidateStateTensor(output_lstm_c, "output lstm_c")) {
return false; return false;
} }
const TfLiteTensor& output_lstm_h = const TfLiteTensor& output_lstm_h =
*interpreter_->tensor(output_tensor_indices[3]); *interpreter_->tensor(output_tensor_indices[5]);
if (!ValidateStateTensor(output_lstm_h, "output lstm_h")) { if (!ValidateStateTensor(output_lstm_h, "output lstm_h")) {
return false; return false;
} }
...@@ -121,6 +125,8 @@ bool MobileLSTDTfLiteClient::InitializeInterpreter( ...@@ -121,6 +125,8 @@ bool MobileLSTDTfLiteClient::InitializeInterpreter(
// Initialize state with all zeroes. // Initialize state with all zeroes.
lstm_c_data_.resize(lstm_state_size_); lstm_c_data_.resize(lstm_state_size_);
lstm_h_data_.resize(lstm_state_size_); lstm_h_data_.resize(lstm_state_size_);
lstm_c_data_uint8_.resize(lstm_state_size_);
lstm_h_data_uint8_.resize(lstm_state_size_);
if (interpreter_->AllocateTensors() != kTfLiteOk) { if (interpreter_->AllocateTensors() != kTfLiteOk) {
LOG(ERROR) << "Failed to allocate tensors"; LOG(ERROR) << "Failed to allocate tensors";
...@@ -201,6 +207,7 @@ bool MobileLSTDTfLiteClient::QuantizedInference(const uint8_t* input_data) { ...@@ -201,6 +207,7 @@ bool MobileLSTDTfLiteClient::QuantizedInference(const uint8_t* input_data) {
CHECK(input_data) << "Input data cannot be null."; CHECK(input_data) << "Input data cannot be null.";
uint8_t* input = interpreter_->typed_input_tensor<uint8_t>(0); uint8_t* input = interpreter_->typed_input_tensor<uint8_t>(0);
CHECK(input) << "Input tensor cannot be null."; CHECK(input) << "Input tensor cannot be null.";
memcpy(input, input_data, input_size_);
// Copy input LSTM state into TFLite's input tensors. // Copy input LSTM state into TFLite's input tensors.
uint8_t* lstm_c_input = interpreter_->typed_input_tensor<uint8_t>(1); uint8_t* lstm_c_input = interpreter_->typed_input_tensor<uint8_t>(1);
...@@ -215,14 +222,18 @@ bool MobileLSTDTfLiteClient::QuantizedInference(const uint8_t* input_data) { ...@@ -215,14 +222,18 @@ bool MobileLSTDTfLiteClient::QuantizedInference(const uint8_t* input_data) {
CHECK_EQ(interpreter_->Invoke(), kTfLiteOk) << "Invoking interpreter failed."; CHECK_EQ(interpreter_->Invoke(), kTfLiteOk) << "Invoking interpreter failed.";
// Copy LSTM state out of TFLite's output tensors. // Copy LSTM state out of TFLite's output tensors.
// Outputs are: raw_outputs/box_encodings, raw_outputs/class_predictions, // Outputs are:
// TFLite_Detection_PostProcess,
// TFLite_Detection_PostProcess:1,
// TFLite_Detection_PostProcess:2,
// TFLite_Detection_PostProcess:3,
// raw_outputs/lstm_c, raw_outputs/lstm_h // raw_outputs/lstm_c, raw_outputs/lstm_h
uint8_t* lstm_c_output = interpreter_->typed_output_tensor<uint8_t>(2); uint8_t* lstm_c_output = interpreter_->typed_output_tensor<uint8_t>(4);
CHECK(lstm_c_output) << "Output lstm_c tensor cannot be null."; CHECK(lstm_c_output) << "Output lstm_c tensor cannot be null.";
std::copy(lstm_c_output, lstm_c_output + lstm_state_size_, std::copy(lstm_c_output, lstm_c_output + lstm_state_size_,
lstm_c_data_uint8_.begin()); lstm_c_data_uint8_.begin());
uint8_t* lstm_h_output = interpreter_->typed_output_tensor<uint8_t>(3); uint8_t* lstm_h_output = interpreter_->typed_output_tensor<uint8_t>(5);
CHECK(lstm_h_output) << "Output lstm_h tensor cannot be null."; CHECK(lstm_h_output) << "Output lstm_h tensor cannot be null.";
std::copy(lstm_h_output, lstm_h_output + lstm_state_size_, std::copy(lstm_h_output, lstm_h_output + lstm_state_size_,
lstm_h_data_uint8_.begin()); lstm_h_data_uint8_.begin());
......
...@@ -76,7 +76,7 @@ bool MobileSSDClient::BatchDetect( ...@@ -76,7 +76,7 @@ bool MobileSSDClient::BatchDetect(
LOG(ERROR) << "Post Processing not supported."; LOG(ERROR) << "Post Processing not supported.";
return false; return false;
} else { } else {
if (NoPostProcessNoAnchors(detections[batch])) { if (!NoPostProcessNoAnchors(detections[batch])) {
LOG(ERROR) << "NoPostProcessNoAnchors failed."; LOG(ERROR) << "NoPostProcessNoAnchors failed.";
return false; return false;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment