Commit b7221961 authored by Yongzhe Wang's avatar Yongzhe Wang Committed by Menglong Zhu
Browse files

Merged commit includes the following changes: (#7249)

* Merged commit includes the following changes:
257930561  by yongzhe:

    Mobile LSTD TfLite Client.

--
257928126  by yongzhe:

    Mobile SSD Tflite client.

--
257921181  by menglong:

    Fix discrepancy between pre_bottleneck = {true, false}

--
257561213  by yongzhe:

    File utils.

--
257449226  by yongzhe:

    Mobile SSD Client.

--
257264654  by yongzhe:

    SSD utils.

--
257235648  by yongzhe:

    Proto bazel build rules.

--
256437262  by Menglong Zhu:

    Fix check for FusedBatchNorm op to only verify it as a prefix.

--
256283755  by yongzhe:

    Bazel build and copybara changes.

--
251947295  by yinxiao:

    Add missing interleaved option in checkpoint restore.

--
251513479  by yongzhe:

    Conversion utils.

--
248783193  by yongzhe:

    Branch protos needed for the lstd client.

--
248200507  by menglong:

    Fix proto namespace in example config

--

P...
parent e21dcdd0
......@@ -98,9 +98,13 @@ bool MobileLSTDTfLiteClient::InitializeInterpreter(
return false;
}
// Outputs are: raw_outputs/box_encodings, raw_outputs/class_predictions,
// Outputs are:
// TFLite_Detection_PostProcess,
// TFLite_Detection_PostProcess:1,
// TFLite_Detection_PostProcess:2,
// TFLite_Detection_PostProcess:3,
// raw_outputs/lstm_c, raw_outputs/lstm_h
if (interpreter_->outputs().size() != 4) {
if (interpreter_->outputs().size() != 6) {
LOG(ERROR) << "Invalid number of interpreter outputs: " <<
interpreter_->outputs().size();
return false;
......@@ -108,12 +112,12 @@ bool MobileLSTDTfLiteClient::InitializeInterpreter(
const std::vector<int> output_tensor_indices = interpreter_->outputs();
const TfLiteTensor& output_lstm_c =
*interpreter_->tensor(output_tensor_indices[2]);
*interpreter_->tensor(output_tensor_indices[4]);
if (!ValidateStateTensor(output_lstm_c, "output lstm_c")) {
return false;
}
const TfLiteTensor& output_lstm_h =
*interpreter_->tensor(output_tensor_indices[3]);
*interpreter_->tensor(output_tensor_indices[5]);
if (!ValidateStateTensor(output_lstm_h, "output lstm_h")) {
return false;
}
......@@ -121,6 +125,8 @@ bool MobileLSTDTfLiteClient::InitializeInterpreter(
// Initialize state with all zeroes.
lstm_c_data_.resize(lstm_state_size_);
lstm_h_data_.resize(lstm_state_size_);
lstm_c_data_uint8_.resize(lstm_state_size_);
lstm_h_data_uint8_.resize(lstm_state_size_);
if (interpreter_->AllocateTensors() != kTfLiteOk) {
LOG(ERROR) << "Failed to allocate tensors";
......@@ -201,6 +207,7 @@ bool MobileLSTDTfLiteClient::QuantizedInference(const uint8_t* input_data) {
CHECK(input_data) << "Input data cannot be null.";
uint8_t* input = interpreter_->typed_input_tensor<uint8_t>(0);
CHECK(input) << "Input tensor cannot be null.";
memcpy(input, input_data, input_size_);
// Copy input LSTM state into TFLite's input tensors.
uint8_t* lstm_c_input = interpreter_->typed_input_tensor<uint8_t>(1);
......@@ -215,14 +222,18 @@ bool MobileLSTDTfLiteClient::QuantizedInference(const uint8_t* input_data) {
CHECK_EQ(interpreter_->Invoke(), kTfLiteOk) << "Invoking interpreter failed.";
// Copy LSTM state out of TFLite's output tensors.
// Outputs are: raw_outputs/box_encodings, raw_outputs/class_predictions,
// Outputs are:
// TFLite_Detection_PostProcess,
// TFLite_Detection_PostProcess:1,
// TFLite_Detection_PostProcess:2,
// TFLite_Detection_PostProcess:3,
// raw_outputs/lstm_c, raw_outputs/lstm_h
uint8_t* lstm_c_output = interpreter_->typed_output_tensor<uint8_t>(2);
uint8_t* lstm_c_output = interpreter_->typed_output_tensor<uint8_t>(4);
CHECK(lstm_c_output) << "Output lstm_c tensor cannot be null.";
std::copy(lstm_c_output, lstm_c_output + lstm_state_size_,
lstm_c_data_uint8_.begin());
uint8_t* lstm_h_output = interpreter_->typed_output_tensor<uint8_t>(3);
uint8_t* lstm_h_output = interpreter_->typed_output_tensor<uint8_t>(5);
CHECK(lstm_h_output) << "Output lstm_h tensor cannot be null.";
std::copy(lstm_h_output, lstm_h_output + lstm_state_size_,
lstm_h_data_uint8_.begin());
......
......@@ -76,7 +76,7 @@ bool MobileSSDClient::BatchDetect(
LOG(ERROR) << "Post Processing not supported.";
return false;
} else {
if (NoPostProcessNoAnchors(detections[batch])) {
if (!NoPostProcessNoAnchors(detections[batch])) {
LOG(ERROR) << "NoPostProcessNoAnchors failed.";
return false;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment