Unverified Commit 95d1d067 authored by karun's avatar karun Committed by GitHub
Browse files

tflite handlers for the custom ops (#10736)


Co-authored-by: default avatarArun Kandoor <akandoor@google.com>
parent 0028cbed
......@@ -114,3 +114,89 @@ pybind_extension(
"@pybind11",
],
)
# QRNN pooling custom-op kernel.
cc_library(
    name = "tflite_qrnn_pooling",
    srcs = ["tflite_qrnn_pooling.cc"],
    hdrs = ["tflite_qrnn_pooling.h"],
    copts = tflite_copts(),
    deps = [
        "//third_party/absl/base:core_headers",
        "//third_party/tensorflow/lite/kernels:builtin_ops",
        "//third_party/tensorflow_models/seq_flow_lite/tflite_ops:quantization_util",
    ],
    # Op registration happens via static initializers; keep them linked.
    alwayslink = 1,
)

# Header-only double-buffered / dynamic caches used by the decoder handlers.
cc_library(
    name = "tflite_decoder_cache",
    hdrs = ["tflite_decoder_cache.h"],
    deps = [
        "//third_party/tensorflow/lite/c:common",
    ],
    alwayslink = 1,
)

# TFLite custom-op handlers wrapping the decoder caches.
cc_library(
    name = "tflite_decoder_handler",
    srcs = ["tflite_decoder_handler.cc"],
    hdrs = ["tflite_decoder_handler.h"],
    copts = tflite_copts(),
    deps = [
        ":tflite_decoder_cache",
        "//third_party/flatbuffers",
        "//third_party/tensorflow/lite/c:common",
        "//third_party/tensorflow/lite/kernels:builtin_ops",
        "//third_party/tensorflow/lite/kernels:kernel_util",
        "//third_party/tensorflow/lite/kernels/internal:tensor",
        "//third_party/tensorflow_models/seq_flow_lite/tflite_ops:quantization_util",
    ],
    alwayslink = 1,
)

cc_test(
    name = "tflite_decoder_handler_test",
    size = "small",
    srcs = ["tflite_decoder_handler_test.cc"],
    deps = [
        ":tflite_decoder_handler",
        "//testing/base/public:gunit",
        "//third_party/flatbuffers",
        "//third_party/tensorflow/lite:framework",
        "//third_party/tensorflow/lite/c:common",
        "//third_party/tensorflow/lite/kernels:test_util",
    ],
)

# Beam-search decoding utility operating on TFLite tensors.
cc_library(
    name = "beam_search",
    srcs = ["beam_search.cc"],
    hdrs = ["beam_search.h"],
    copts = tflite_copts(),
    deps = [
        "//base",
        "//third_party/absl/strings",
        "//third_party/tensorflow/lite/c:common",
        "//third_party/tensorflow/lite/kernels/internal:tensor",
        "//third_party/tensorflow/lite/kernels/internal:types",
        "//third_party/tensorflow_models/seq_flow_lite/tflite_ops:quantization_util",
    ],
)

cc_test(
    name = "beam_search_test",
    srcs = ["beam_search_test.cc"],
    copts = tflite_copts(),
    deps = [
        ":beam_search",
        "//testing/base/public:gunit_main",
        "//third_party/absl/strings",
        "//third_party/tensorflow/lite/c:c_api_types",
        "//third_party/tensorflow/lite/c:common",
        "//third_party/tensorflow/lite/kernels/internal:legacy_reference_base",
        "//third_party/tensorflow/lite/kernels/internal:optimized_base",
        "//third_party/tensorflow/lite/kernels/internal:tensor",
        "//third_party/tensorflow/lite/kernels/internal:types",
        "//third_party/tensorflow_models/seq_flow_lite/tflite_ops:quantization_util",
    ],
)
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/beam_search.h"
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>
#include "base/logging.h"
#include "third_party/absl/strings/str_join.h"
#include "third_party/tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "third_party/tensorflow/lite/kernels/internal/types.h"
#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h"
namespace seq_flow_lite {
namespace ops {
namespace custom {
namespace {
// Bit position at which an 8-bit quantized logit is packed into a uint32;
// the logit occupies the top byte so integer comparison orders by logit value.
constexpr int kLeftShiftNumBits = 24;
// Mask selecting the low 24 bits of a packed value, which hold the class index.
constexpr int kClassIndexMask = (1 << kLeftShiftNumBits) - 1;
// Tracks finished sequences within the beams.
//
// Keeps at most `beam_size` terminated sequences keyed by score, best first.
class SequenceTracker {
 public:
  // `eos_id` is appended to every tracked sequence to mark termination.
  explicit SequenceTracker(int beam_size, int eos_id)
      : beam_size_(beam_size),
        eos_id_(eos_id),
        min_terminated_scores_(-kInfinite) {}
  // Records the token sequence [begin, end) with `score` if it qualifies for
  // the top beam_size_ terminated sequences.
  void AddSequence(const int32_t *begin, const int32_t *end, float score);
  // Number of terminated sequences currently tracked.
  int NumSequences() { return terminated_topk_.size(); }
  // Returns the tracked sequences, best score first.
  std::vector<std::vector<int32_t>> GetTopBeams();
  // Lowest tracked score (-kInfinite before any sequence is added).
  float MinTrackedScore() { return min_terminated_scores_; }
  // Highest tracked score (-kInfinite when empty).
  float MaxTrackedScore() {
    return terminated_topk_.empty() ? -kInfinite
                                    : terminated_topk_.begin()->first;
  }

 private:
  // Sentinel used as "minus infinity" in score comparisons.
  static constexpr float kInfinite = 1e7;
  const int beam_size_;
  const int eos_id_;
  // TODO(akandoor): Consider using std::vector and heap accessors instead.
  // Score -> sequence, sorted descending by score. NOTE(review): sequences
  // with identical scores overwrite each other (see TODO in AddSequence).
  std::map<float, std::vector<int32_t>, std::greater<float>> terminated_topk_;
  float min_terminated_scores_;
};
void PrintBeam(const int32 *array_new, int cur_step) {
LOG(INFO) << absl::StrJoin(array_new, array_new + cur_step, ", ");
}
// Comparator for the top-K heaps on (score, packed index) pairs. Ordering by
// ascending score keeps the *lowest* score at the heap top under
// std::push_heap/pop_heap, so the worst candidate can be evicted cheaply.
// Takes const references: the original non-const lvalue refs could not bind
// temporaries and falsely advertised mutation.
bool HeapCompare(const std::pair<float, int> &a,
                 const std::pair<float, int> &b) {
  return a.first > b.first;
}
} // namespace
// Records [begin, end) + EOS as a terminated sequence if it fits in the top
// beam_size_ by score; evicts the current worst when full.
void SequenceTracker::AddSequence(const int32_t *begin, const int32_t *end,
                                  float score) {
  // Accept if there is room, or if the score beats the worst tracked one.
  if (NumSequences() < beam_size_ || score > min_terminated_scores_) {
    // TODO(akandoor): Handle duplicate scores.
    if (NumSequences() >= beam_size_) {
      // Map is sorted descending by score, so the last entry is the worst.
      terminated_topk_.erase(std::prev(terminated_topk_.end()));
    }
    // TODO(prabhumk): This can potentially slow things down. Fix this.
    terminated_topk_[score] = std::vector<int32_t>(begin, end);
    // Pushing EOS_ID to terminate the sequence.
    terminated_topk_[score].push_back(eos_id_);
    // rbegin() of the descending map is the smallest tracked score.
    min_terminated_scores_ = terminated_topk_.rbegin()->first;
  }
}
// Returns every tracked terminated sequence, best score first (the map is
// ordered by descending score).
std::vector<std::vector<int32_t>> SequenceTracker::GetTopBeams() {
  std::vector<std::vector<int32_t>> beams;
  beams.reserve(terminated_topk_.size());
  for (const auto &[score, sequence] : terminated_topk_) {
    beams.push_back(sequence);
  }
  return beams;
}
// Lazily fills log_lookup_table_ with log(dequantize(v)) for every possible
// quantized byte value, so the top-K scan avoids a per-element logf call.
void BeamSearch::PopulateLogLookupTable(const TfLiteTensor &tensor) {
  if (log_lookup_table_populated_) return;
  for (int quantized = 0; quantized < 256; ++quantized) {
    const float dequantized =
        ::seq_flow_lite::PodDequantizeValue(tensor, quantized);
    log_lookup_table_[quantized] = logf(dequantized);
  }
  log_lookup_table_populated_ = true;
}
// Lazily fills exp_lookup_table_ so that entry [max_uint8 - v] holds
// expf(-scale * v); used to build softmax denominators without expf calls.
void BeamSearch::PopulateSoftmaxLookupTable(const TfLiteTensor &tensor) {
  if (exp_lookup_table_populated_) return;
  constexpr int32_t kMaxUint8 = std::numeric_limits<uint8_t>::max();
  for (int32_t quantized = 0; quantized <= kMaxUint8; ++quantized) {
    exp_lookup_table_[kMaxUint8 - quantized] =
        expf(-tensor.params.scale * quantized);
  }
  exp_lookup_table_populated_ = true;
}
// Returns the reciprocal of the length penalty ((5 + step) / 6)^alpha from
// https://arxiv.org/abs/1609.08144. Uses std::pow: `std::powf` is not part of
// standard C++ (<cmath> provides float overloads of std::pow instead), so the
// original failed to compile on conforming standard libraries.
float BeamSearch::InverseLengthPenalty(int step) {
  return 1.0f / std::pow((5.f + step) / 6.f, alpha_);
}
// Finds the K best cumulative log-probabilities over all valid beams for a
// float decoder output, leaving (score, beam * num_classes_ + class) pairs in
// topk_heap_ sorted best first.
void BeamSearch::FindTopKFloat(const TfLiteTensor &tensor, int valid_beams,
                               int K) {
  topk_heap_.clear();
  const float *probabilities = ::tflite::GetTensorData<float>(&tensor);
  for (int beam = 0; beam < valid_beams; ++beam) {
    for (int cls = 0; cls < num_classes_; ++cls) {
      const int index = beam * num_classes_ + cls;
      const float score =
          beam_log_probabilities_[beam] + logf(probabilities[index]);
      topk_heap_.emplace_back(score, index);
      std::push_heap(topk_heap_.begin(), topk_heap_.end(), HeapCompare);
      // Keep at most K candidates; the worst sits at the heap top.
      if (topk_heap_.size() > K) {
        std::pop_heap(topk_heap_.begin(), topk_heap_.end(), HeapCompare);
        topk_heap_.pop_back();
      }
    }
  }
  std::sort(topk_heap_.begin(), topk_heap_.end(), HeapCompare);
}
// Quantized counterpart of FindTopKFloat: per-element log-probabilities come
// from the precomputed lookup table instead of calling logf.
void BeamSearch::FindTopKQuantized(const TfLiteTensor &tensor, int valid_beams,
                                   int K) {
  PopulateLogLookupTable(tensor);
  topk_heap_.clear();
  const uint8 *probabilities = ::tflite::GetTensorData<uint8_t>(&tensor);
  for (int beam = 0; beam < valid_beams; ++beam) {
    const float base_log_prob = beam_log_probabilities_[beam];
    for (int cls = 0; cls < num_classes_; ++cls) {
      const int index = beam * num_classes_ + cls;
      const float score =
          base_log_prob + log_lookup_table_[probabilities[index]];
      topk_heap_.emplace_back(score, index);
      std::push_heap(topk_heap_.begin(), topk_heap_.end(), HeapCompare);
      // Evict the worst candidate once more than K are held.
      if (topk_heap_.size() > K) {
        std::pop_heap(topk_heap_.begin(), topk_heap_.end(), HeapCompare);
        topk_heap_.pop_back();
      }
    }
  }
  std::sort(topk_heap_.begin(), topk_heap_.end(), HeapCompare);
}
// Sets the boolean mask of valid logit indices. The size is validated
// *before* mutating logits_mask_, so a failing CHECK cannot leave the mask in
// a partially updated state (the original assigned first and checked after).
void BeamSearch::SetMaskForLogits(const std::vector<bool> &mask) {
  CHECK_EQ(mask.size(), static_cast<size_t>(num_classes_))
      << "Mask size should be same as num_classes";
  logits_mask_.assign(mask.begin(), mask.end());
}
// Finds the top `topk_k` cumulative log-probabilities across valid beams when
// the decoder emits quantized *logits* (not probabilities): log-softmax is
// computed per beam via the exp lookup table, honoring logits_mask_.
void BeamSearch::FindTopKQuantizedFromLogits(const TfLiteTensor &tensor,
                                             int valid_beams, int topk_k) {
  PopulateSoftmaxLookupTable(tensor);
  topk_heap_.clear();
  const uint8_t *logits = ::tflite::GetTensorData<uint8_t>(&tensor);
  for (int j = 0; j < valid_beams; ++j) {
    const uint8_t *beam_logits = logits + j * num_classes_;
    uint8_t max_val = std::numeric_limits<uint8_t>::min();
    // Finding max quantized value in the current beam.
    for (int k = 0; k < num_classes_; ++k) {
      if (!logits_mask_[k]) continue;
      max_val = std::max(max_val, beam_logits[k]);
    }
    float sum_exp = 0.0f;
    const int32_t max_uint8 = std::numeric_limits<uint8>::max();
    // Offset into table to compute exp(scale*(x - xmax)) instead of
    // exp(scale*(x)) to prevent overflow.
    const float *table_offset = &exp_lookup_table_[max_uint8 - max_val];
    // Calculate sum(exp(scale*(x - x_max))).
    for (int k = 0; k < num_classes_; ++k) {
      if (!logits_mask_[k]) continue;
      sum_exp += table_offset[beam_logits[k]];
    }
    CHECK(sum_exp) << "Invalid logits or Mask provided.";
    const float log_sum_exp = std::log(sum_exp);
    // log_softmax(x) = scale*x - (scale*max + log(sum_exp)); the parenthesized
    // part is constant per beam, so fold it once.
    const float precomputed = (tensor.params.scale * max_val + log_sum_exp);
    for (int k = 0; k < num_classes_; ++k) {
      if (!logits_mask_[k]) continue;
      const int index = j * num_classes_ + k;
      const float log_prob = tensor.params.scale * beam_logits[k] - precomputed;
      const float beam_log_prob = (beam_log_probabilities_[j] + log_prob);
      topk_heap_.push_back(std::pair<float, int>(beam_log_prob, index));
      std::push_heap(topk_heap_.begin(), topk_heap_.end(), HeapCompare);
      // Bound the heap at topk_k entries, evicting the lowest score.
      if (topk_heap_.size() > topk_k) {
        std::pop_heap(topk_heap_.begin(), topk_heap_.end(), HeapCompare);
        topk_heap_.pop_back();
      }
    }
  }
  // Leave results sorted best first for consumers.
  std::sort(topk_heap_.begin(), topk_heap_.end(), HeapCompare);
}
void BeamSearch::FindTopKQuantizedFromLogitsV1(const TfLiteTensor &tensor,
int valid_beams, int topk_k) {
PopulateSoftmaxLookupTable(tensor);
topk_heap_.clear();
std::vector<uint32_t> curr_beam_topk(topk_k);
const uint8 *logits = ::tflite::GetTensorData<uint8_t>(&tensor);
for (int j = 0; j < valid_beams; ++j) {
// Resetting the topk logits vector for each beam.
curr_beam_topk.clear();
const uint8_t *beam_logits = logits + j * num_classes_;
uint8_t max_val = std::numeric_limits<uint8_t>::min();
// Finding max quantized value in the current beam.
for (int k = 0; k < num_classes_; ++k) {
if (!logits_mask_[k]) continue;
max_val = std::max(max_val, beam_logits[k]);
}
float sum_exp = 0.0f;
const int32_t max_uint8 = std::numeric_limits<uint8>::max();
// Offset into table to compute exp(scale*(x - xmax)) instead of
// exp(scale*(x)) to prevent overflow.
const float *table_offset = &exp_lookup_table_[max_uint8 - max_val];
// Calculate sum(exp(scale*(x - x_max))).
for (int k = 0; k < num_classes_; ++k) {
if (!logits_mask_[k]) continue;
sum_exp += table_offset[beam_logits[k]];
}
CHECK(sum_exp) << "Invalid logits or mask provided.";
const float log_sum_exp = std::log(sum_exp);
const float precomputed = (tensor.params.scale * max_val + log_sum_exp);
// Computing indices for topk logits in the current beam.
for (uint32_t k = 0; k < num_classes_; ++k) {
if (!logits_mask_[k]) continue;
// Pushing logits uint8 value to MSB and storing index in the 24 LSB.
const uint32_t val =
(beam_logits[k] << kLeftShiftNumBits) | (k & kClassIndexMask);
curr_beam_topk.push_back(val);
std::push_heap(curr_beam_topk.begin(), curr_beam_topk.end(),
std::greater<>());
if (curr_beam_topk.size() > topk_k) {
std::pop_heap(curr_beam_topk.begin(), curr_beam_topk.end(),
std::greater<>());
curr_beam_topk.pop_back();
}
}
// Updating topk across all beams.
for (uint32_t k = 0; k < std::min(topk_k, num_classes_); ++k) {
const uint32_t curr_beam_index = curr_beam_topk[k] & kClassIndexMask;
const uint32_t index = j * num_classes_ + curr_beam_index;
const float log_prob =
tensor.params.scale * beam_logits[curr_beam_index] - precomputed;
const float beam_log_prob = (beam_log_probabilities_[j] + log_prob);
topk_heap_.push_back(std::pair<float, int>(beam_log_prob, index));
std::push_heap(topk_heap_.begin(), topk_heap_.end(), HeapCompare);
if (topk_heap_.size() > topk_k) {
std::pop_heap(topk_heap_.begin(), topk_heap_.end(), HeapCompare);
topk_heap_.pop_back();
}
}
}
std::sort(topk_heap_.begin(), topk_heap_.end(), HeapCompare);
}
// Runs the decode loop for up to `num_steps` timesteps and returns the best
// terminated sequences, best first. Terminates early when no beam survives or
// when no alive beam can possibly beat the best terminated score.
std::vector<std::vector<int32_t>> BeamSearch::Process(int num_steps) {
  // Encode();
  std::vector<int32> input_indices(beam_size_, sos_id_);
  // Favor beam index 0 for the first sos input.
  beam_log_probabilities_[0] = 0.0f;
  SequenceTracker sequence_tracker(beam_size_, eos_id_);
  std::vector<int32_t> selected_beam(beam_size_, 0);
  // Double buffer of decoded token ids; each buffer holds beam_size_ rows of
  // num_steps tokens. arrays[i & 1] is written on step i, the other buffer
  // holds the previous step's beams.
  std::vector<std::vector<int32_t>> arrays;
  arrays.emplace_back(num_steps * beam_size_);
  arrays.emplace_back(num_steps * beam_size_);
  int32_t *array_new = nullptr;
  int valid_beam_entries = 1;
  const float inverse_max_length_penalty = InverseLengthPenalty(num_steps);
  for (int i = 0; i < num_steps; ++i) {
    TfLiteTensor *decoder_output = Decode(i + 1, selected_beam, input_indices);
    // Decoder output must be shaped [beam_size, 1, num_classes].
    CHECK_EQ(decoder_output->dims->size, 3);
    CHECK_EQ(decoder_output->dims->data[0], beam_size_);
    CHECK_EQ(decoder_output->dims->data[1], 1);
    CHECK_EQ(decoder_output->dims->data[2], num_classes_);
    const float inverse_length_penalty = InverseLengthPenalty(i + 1);
    // Pick the 2 * beam_size best candidates for this step into topk_heap_.
    if (decoder_output->type == kTfLiteUInt8) {
      if (compute_topk_with_logits_) {
        FindTopKQuantizedFromLogitsV1(*decoder_output, valid_beam_entries,
                                      beam_size_ * 2);
      } else {
        FindTopKQuantized(*decoder_output, valid_beam_entries, beam_size_ * 2);
      }
    } else if (decoder_output->type == kTfLiteFloat32) {
      LOG(ERROR) << "TopK is not optimized in this path.";
      CHECK_EQ(compute_topk_with_logits_, false)
          << "TopK with logits for Float is not supported";
      FindTopKFloat(*decoder_output, valid_beam_entries, beam_size_ * 2);
    } else {
      CHECK(false) << "Invalid data type: " << decoder_output->type;
    }
    // Swap the double buffer on every step.
    const int32_t offset = i & 0x1;
    const int32_t *array_old = arrays[1 - offset].data();
    array_new = arrays[offset].data();
    valid_beam_entries = 0;
    for (int src = 0; src < beam_size_ * 2; ++src) {
      // Heap entries pack beam * num_classes_ + class into .second.
      const int new_class = (topk_heap_[src].second % num_classes_);
      if (new_class == eos_id_) {
        // Beam terminated: hand its history to the tracker with a
        // length-normalized score.
        const int old_beam = topk_heap_[src].second / num_classes_;
        sequence_tracker.AddSequence(
            array_old + old_beam * num_steps,
            array_old + old_beam * num_steps + i,
            topk_heap_[src].first * inverse_length_penalty);
      } else if (valid_beam_entries < beam_size_) {
        // Compact surviving candidates to the front of topk_heap_.
        if (valid_beam_entries != src) {
          topk_heap_[valid_beam_entries] = topk_heap_[src];
        }
        valid_beam_entries++;
      }
    }
    if (valid_beam_entries == 0) {
      break;
    }
    // Stop early if even the best alive beam, under the maximum-length
    // penalty, cannot beat the best terminated sequence.
    const float max_alive_score =
        topk_heap_[0].first * inverse_max_length_penalty;
    if (max_alive_score < sequence_tracker.MaxTrackedScore()) {
      break;
    }
    // Commit the surviving beams: copy history and append the new token.
    for (int j = 0; j < valid_beam_entries; ++j) {
      beam_log_probabilities_[j] = topk_heap_[j].first;
      const int new_class = (topk_heap_[j].second % num_classes_);
      input_indices[j] = new_class;
      const int old_beam = topk_heap_[j].second / num_classes_;
      memcpy(array_new + j * num_steps, array_old + old_beam * num_steps,
             i * sizeof(int32));
      array_new[j * num_steps + i] = new_class;
      if (debug_log_) PrintBeam(array_new + j * num_steps, i + 1);
      selected_beam[j] = old_beam;
    }
  }
  if (sequence_tracker.NumSequences() == 0) {
    // No terminated sequence, the best alive sequence is the optimal one.
    // NOTE(review): array_new is nullptr when num_steps == 0 — presumably
    // callers always pass num_steps > 0; confirm.
    sequence_tracker.AddSequence(array_new, array_new + num_steps, 0.0f);
  }
  return sequence_tracker.GetTopBeams();
}
} // namespace custom
} // namespace ops
} // namespace seq_flow_lite
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
#define THIRD_PARTY_TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <vector>
#include "third_party/tensorflow/lite/c/common.h"
namespace seq_flow_lite {
namespace ops {
namespace custom {
class BeamSearchTestPeer;

// Implements Beam search util for decoding operations. The derived class
// should implement the Decode method to complete the actual decoding
// operation which outputs the probabilities for each beam and class.
class BeamSearch {
 public:
  // `alpha` controls the length penalty ((5 + len) / 6)^alpha; `use_logits`
  // selects top-K computation from raw logits instead of probabilities.
  // (Parameter was previously misspelled `use_logtis`; renaming is safe since
  // C++ arguments are positional.)
  BeamSearch(int beam_size, int num_classes, int sos_id, int eos_id,
             float alpha = 0.6, bool use_logits = false)
      : beam_size_(beam_size),
        num_classes_(num_classes),
        sos_id_(sos_id),
        eos_id_(eos_id),
        alpha_(alpha),
        beam_log_probabilities_(beam_size, 0.0f),
        logits_mask_(num_classes, true),
        compute_topk_with_logits_(use_logits),
        debug_log_(false) {
    // Holds up to 2x beam-size candidates while pruning.
    topk_heap_.reserve(2 * beam_size_);
  }
  // Virtual method that should be overridden to perform decode operations.
  virtual TfLiteTensor* Decode(int timestep,
                               std::vector<int32_t>& selected_beams,
                               std::vector<int32_t>& input_indices) = 0;
  virtual ~BeamSearch() {}
  // Runs decoding process for num_steps.
  std::vector<std::vector<int32_t>> Process(int num_steps);
  int NumBeams() { return beam_size_; }
  int NumClasses() { return num_classes_; }
  void SetNumClasses(int num_classes) { num_classes_ = num_classes; }
  // Sets boolean mask of size num_classes to process only valid logit indices.
  // Example mask: {true, true, false, true, false} would result in processing
  // logits at indices 0, 1 and 3.
  void SetMaskForLogits(const std::vector<bool>& mask);

 private:
  friend class BeamSearchTestPeer;
  // Floating point version of finding top_k classes from decoder output.
  void FindTopKFloat(const TfLiteTensor& tensor, int valid_beams, int K);
  // Quantized version of finding top_k classes from decoder output probs.
  void FindTopKQuantized(const TfLiteTensor& tensor, int valid_beams, int K);
  // Quantized version of finding top_k classes from decoder output logits.
  void FindTopKQuantizedFromLogits(const TfLiteTensor& tensor, int valid_beams,
                                   int topk_k);
  // Optimized version for FindTopKQuantizedFromLogits.
  void FindTopKQuantizedFromLogitsV1(const TfLiteTensor& tensor,
                                     int valid_beams, int topk_k);
  // Length penalty is given by = (5+len(decode)/6) ^ -\alpha.
  // Pls refer to https://arxiv.org/abs/1609.08144.
  float InverseLengthPenalty(int step);
  // Populates log probabilities for int values 0-255.
  void PopulateLogLookupTable(const TfLiteTensor& tensor);
  // Populates exp probabilities for int values 0-255.
  void PopulateSoftmaxLookupTable(const TfLiteTensor& tensor);
  // (score, beam * num_classes + class) candidates, sorted best first after
  // each FindTopK* call.
  std::vector<std::pair<float, int32_t>> topk_heap_;
  const int beam_size_;
  int num_classes_;
  // Start of sequence ID.
  const int sos_id_;
  // End of sequence ID.
  const int eos_id_;
  // Alpha to be used in length penalty computation.
  const float alpha_;
  std::vector<float> beam_log_probabilities_;
  // Mask for valid logits. Used when computing TopK with logits.
  std::vector<bool> logits_mask_;
  // Computes TopK using logits instead of probabilities.
  bool compute_topk_with_logits_ = false;
  float log_lookup_table_[256];
  bool log_lookup_table_populated_ = false;
  float exp_lookup_table_[256];
  bool exp_lookup_table_populated_ = false;
  bool debug_log_;
};
} // namespace custom
} // namespace ops
} // namespace seq_flow_lite
#endif // THIRD_PARTY_TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/beam_search.h"
#include <cstdint>
#include <functional>
#include <iostream>
#include <memory>
#include <vector>
#include "testing/base/public/gmock.h"
#include "testing/base/public/gunit.h"
#include "third_party/absl/strings/str_join.h"
#include "third_party/tensorflow/lite/c/c_api_types.h"
#include "third_party/tensorflow/lite/c/common.h"
#include "third_party/tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "third_party/tensorflow/lite/kernels/internal/reference/dequantize.h"
#include "third_party/tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "third_party/tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "third_party/tensorflow/lite/kernels/internal/types.h"
#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h"
namespace seq_flow_lite {
namespace ops {
namespace custom {
// Asserts that `test_output` matches `reference_output` elementwise to within
// a small mean relative error over the flattened shape.
void CheckOutputData(const float* test_output, const float* reference_output,
                     const ::tflite::RuntimeShape& shape_common) {
  const int num_elements = shape_common.FlatSize();
  double total_abs_diff = 0;
  float max_ref_magnitude = 0;
  for (int i = 0; i < num_elements; i++) {
    total_abs_diff += std::abs(test_output[i] - reference_output[i]);
    max_ref_magnitude =
        std::max(max_ref_magnitude, std::abs(reference_output[i]));
  }
  // Identical outputs need no tolerance check (and avoid a 0/0 division).
  if (total_abs_diff != 0.f) {
    const float mean_diff = static_cast<float>(total_abs_diff / num_elements);
    const float relative_error = std::abs(mean_diff) / max_ref_magnitude;
    ASSERT_LT(relative_error, 1e-5f);
  }
}
// Concrete BeamSearch used in tests: "decoding" averages the cache rows of
// the beams selected at the previous step into the next cache, and emits the
// result either as float or as uint8-quantized probabilities.
class BeamSearchImpl : public BeamSearch {
 public:
  BeamSearchImpl(int beam_size, int num_classes, int sos_id, int eos_id,
                 bool use_logits = false, bool quantize = false)
      : BeamSearch(beam_size, num_classes, sos_id, eos_id,
                   /*alpha=*/0.6, use_logits) {
    CreateDecoderOutputTensor({beam_size, 1, num_classes}, quantize);
    InitializeCache();
  }
  // Fills the decoder output tensor for `timestep` by averaging the selected
  // beams' current-cache rows into the next cache, then returns the tensor.
  TfLiteTensor* Decode(int timestep, std::vector<int32_t>& selected_beams,
                       std::vector<int32_t>& indices) override {
    const float* cur_cache = CurrentCache(timestep);
    float* next_cache = NextCache(timestep);
    if (decoder_output_->type == kTfLiteUInt8) {
      auto data_ptr = ::tflite::GetTensorData<uint8_t>(decoder_output_.get());
      for (int beam = 0, index = 0; beam < NumBeams(); ++beam) {
        const float* selected =
            cur_cache + (selected_beams[beam] * NumClasses());
        for (int j = 0; j < NumClasses(); ++j, index++) {
          next_cache[index] = (selected[j] + next_cache[index]) / 2;
          data_ptr[index] = ::seq_flow_lite::PodQuantize(
              next_cache[index], decoder_output_->params.zero_point,
              1.0f / decoder_output_->params.scale);
        }
      }
    } else {
      auto data_ptr = ::tflite::GetTensorData<float>(decoder_output_.get());
      for (int beam = 0, index = 0; beam < NumBeams(); ++beam) {
        const float* selected =
            cur_cache + (selected_beams[beam] * NumClasses());
        for (int j = 0; j < NumClasses(); ++j, index++) {
          next_cache[index] = (selected[j] + next_cache[index]) / 2;
          data_ptr[index] = next_cache[index];
        }
      }
    }
    return decoder_output_.get();
  }

 private:
  // Allocates the decoder output tensor with the given dims, optionally
  // uint8-quantized (scale 1/255, zero point 0).
  void CreateDecoderOutputTensor(const std::vector<int>& dims,
                                 bool quantize = false) {
    decoder_output_.reset(new TfLiteTensor);
    decoder_output_->dims = TfLiteIntArrayCreate(dims.size());
    int tensor_size = 1;
    for (int i = 0; i < dims.size(); ++i) {
      decoder_output_->dims->data[i] = dims[i];
      tensor_size *= dims[i];
    }
    if (quantize) {
      decoder_output_->type = kTfLiteUInt8;
      decoder_output_->bytes = tensor_size * sizeof(uint8_t);
      decoder_output_->params.scale = 1.0 / 255.0;
      decoder_output_->params.zero_point = 0;
    } else {
      decoder_output_->type = kTfLiteFloat32;
      decoder_output_->bytes = tensor_size * sizeof(float);
    }
    decoder_output_->data.raw = new char[decoder_output_->bytes];
  }
  // Frees the dims array, the payload, and the tensor struct itself.
  struct DeleteTensor {
    void operator()(TfLiteTensor* t) const {
      TfLiteIntArrayFree(t->dims);
      delete[] t->data.raw;
      delete t;
    }
  };
  // Double-buffered caches: odd steps read cache1_ and write cache2_.
  float* CurrentCache(int step) {
    return (step & 0x1) == 0x1 ? cache1_.data() : cache2_.data();
  }
  float* NextCache(int step) {
    return (step & 0x1) == 0x1 ? cache2_.data() : cache1_.data();
  }
  // Fixed test fixture: two beams (rows) of five classes per cache.
  void InitializeCache() {
    cache1_ = {/* 0: */ 0.6, 0.8, 0.3, 0.7, 0.2,
               /* 1: */ 0.5, 0.2, 0.1, 0.3, 0.4};
    cache2_ = {/* 0: */ 0.6, 0.9, 0.8, 0.2, 0.8,
               /* 1: */ 0.5, 0.8, 0.5, 0.7, 0.9};
  }
  std::unique_ptr<TfLiteTensor, DeleteTensor> decoder_output_;
  // 20 zero-initialized floats. The original `cache1_{20, 0.0}` was the
  // brace-init pitfall — a two-element vector {20.0f, 0.0f} — masked only
  // because InitializeCache() reassigns both caches in the constructor.
  std::vector<float> cache1_ = std::vector<float>(20, 0.0f);
  std::vector<float> cache2_ = std::vector<float>(20, 0.0f);
};
// Test-only peer that reaches BeamSearch's private top-K routines and
// candidate heap (declared a friend in beam_search.h).
class BeamSearchTestPeer {
 public:
  BeamSearchTestPeer(int beam_size, int num_classes, int sos_id, int eos_id,
                     bool use_logits = false, bool quantize = false)
      : beam_size_(beam_size),
        num_classes_(num_classes),
        sos_id_(sos_id),
        eos_id_(eos_id),
        use_logits_(use_logits),
        quantize_(quantize) {}
  // Runs a fresh BeamSearchImpl for num_steps and returns its beams.
  std::vector<std::vector<int32_t>> Process(int num_steps) {
    BeamSearchImpl bs(beam_size_, num_classes_, sos_id_, eos_id_, use_logits_,
                      quantize_);
    return bs.Process(num_steps);
  }
  // Invokes the (optionally optimized) quantized top-K-from-logits routine on
  // a fresh BeamSearchImpl and returns the first topk_k scores from its
  // internal heap, best first.
  std::vector<float> InvokeFindTopKQuantizedWithLogits(
      const TfLiteTensor& logits, const std::vector<bool>& mask,
      int valid_beams, int topk_k, bool optimized = false) {
    BeamSearchImpl bs(beam_size_, num_classes_, sos_id_, eos_id_, use_logits_,
                      quantize_);
    bs.SetMaskForLogits(mask);
    if (optimized) {
      bs.FindTopKQuantizedFromLogitsV1(logits, valid_beams, topk_k);
    } else {
      bs.FindTopKQuantizedFromLogits(logits, valid_beams, topk_k);
    }
    std::vector<float> result;
    for (int i = 0; i < topk_k; ++i) {
      result.push_back(bs.topk_heap_[i].first);
    }
    return result;
  }

 private:
  int beam_size_;
  int num_classes_;
  int sos_id_;
  int eos_id_;
  bool use_logits_;
  bool quantize_;
};
// End-to-end beam search over the float decoder path: beam size 2, 5 classes,
// SOS id 0, EOS id 2. Expected: best beam is immediately EOS; second best
// emits token 1 then EOS.
TEST(BeamSearch, BasicTest) {
  BeamSearchTestPeer bst(2, 5, 0, 2);
  auto beams = bst.Process(4);
  EXPECT_EQ(absl::StrJoin(beams[0], ","), "2");
  EXPECT_EQ(absl::StrJoin(beams[1], ","), "1,2");
}
// Same scenario as BasicTest but through the uint8-quantized decoder path;
// expects identical beams.
TEST(BeamSearch, BasicTestQuantized) {
  BeamSearchTestPeer bst(2, 5, 0, 2, /*use_logits*/ false, /*quantize=*/true);
  auto beams = bst.Process(4);
  EXPECT_EQ(absl::StrJoin(beams[0], ","), "2");
  EXPECT_EQ(absl::StrJoin(beams[1], ","), "1,2");
}
// Compares FindTopKQuantizedFromLogits (all classes unmasked, K covering the
// full beam*classes space) against a reference dequantize + LogSoftmax
// pipeline from the TFLite kernels.
TEST(BeamSearch, TestFindTopKQuantizedFromLogits) {
  int beam_size = 2;
  int num_classes = 5;
  BeamSearchImpl bs(beam_size, num_classes, 0, 2, /*use_logits=*/true,
                    /*quantize=*/true);
  std::vector<int32_t> selected_beams = {0, 1};
  std::vector<int32> input_indices(2, 0);
  // One Decode call produces a quantized logits tensor to rank.
  auto* logits_tensor = bs.Decode(1, selected_beams, input_indices);
  BeamSearchTestPeer bst(beam_size, num_classes, 0, 2, /*use_logits=*/true,
                         /*quantize=*/true);
  std::vector<bool> mask(num_classes, true);
  auto topk_output = bst.InvokeFindTopKQuantizedWithLogits(
      *logits_tensor, mask, beam_size, beam_size * num_classes);
  auto shape_common = ::tflite::RuntimeShape({beam_size, 1, num_classes});
  const int buffer_size = shape_common.FlatSize();
  std::vector<float> reference_dequant_data(buffer_size);
  std::vector<float> reference_output_float_data(buffer_size);
  // Reference path: dequantize then log-softmax in float.
  ::tflite::DequantizationParams dq_params;
  dq_params.zero_point = logits_tensor->params.zero_point;
  dq_params.scale = logits_tensor->params.scale;
  ::tflite::reference_ops::Dequantize(dq_params, shape_common,
                                      logits_tensor->data.uint8, shape_common,
                                      reference_dequant_data.data());
  ::tflite::SoftmaxParams sm_params;
  ::tflite::optimized_ops::LogSoftmax(
      sm_params, shape_common, reference_dequant_data.data(), shape_common,
      reference_output_float_data.data());
  // Sort descending to match the best-first order of the top-K output.
  std::sort(reference_output_float_data.begin(),
            reference_output_float_data.end(), std::greater<float>());
  CheckOutputData(topk_output.data(), reference_output_float_data.data(),
                  shape_common);
}
// Checks the optimized top-K-from-logits path (V1) against the reference
// implementation under a mask that disables two of the five classes.
TEST(BeamSearch, TestFindTopKQuantizedFromLogitsV1) {
  int beam_size = 2;
  int num_classes = 5;
  BeamSearchImpl bs(beam_size, num_classes, 0, 2, /*use_logits=*/true,
                    /*quantize=*/true);
  std::vector<int32_t> selected_beams = {0, 1};
  std::vector<int32> input_indices(2, 0);
  auto* logits_tensor = bs.Decode(1, selected_beams, input_indices);
  BeamSearchTestPeer bst(beam_size, num_classes, 0, 2, /*use_logits=*/true,
                         /*quantize=*/true);
  int topk_k = beam_size * 2;
  std::vector<bool> mask = {true, true, false, true, false};
  auto topk_output = bst.InvokeFindTopKQuantizedWithLogits(*logits_tensor, mask,
                                                           beam_size, topk_k);
  auto topk_output_v1 = bst.InvokeFindTopKQuantizedWithLogits(
      *logits_tensor, mask, beam_size, topk_k, /*optimized=*/true);
  // Compare the two paths' scores over a topk_k-element buffer.
  auto shape_common = ::tflite::RuntimeShape({beam_size, 1, 1});
  CheckOutputData(topk_output_v1.data(), topk_output.data(), shape_common);
}
} // namespace custom
} // namespace ops
} // namespace seq_flow_lite
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
#define THIRD_PARTY_TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
#include <memory>
#include "third_party/tensorflow/lite/c/common.h"
namespace seq_flow_lite {
namespace ops {
namespace custom {
namespace tflite_decoder_base {
// Base decoder op that can be derived to implement different decoding schemes.
// Owns two beam_size x feature_size buffers used as a double-buffered cache:
// odd steps read cache1_ / write cache2_, even steps the reverse.
template <typename T>
class BaseDecoderOp {
 public:
  explicit BaseDecoderOp(int feature_size, int beam_size)
      : feature_size_(feature_size),
        beam_size_(beam_size),
        cache1_(new T[feature_size * beam_size]),
        cache2_(new T[feature_size * beam_size]) {}
  virtual ~BaseDecoderOp() {}
  int BeamSize() const { return beam_size_; }
  int FeatureSize() const { return feature_size_; }
  // Zeroes cache1_ only. NOTE(review): cache2_ is left uninitialized —
  // presumably derived ops fully write the "next" cache before reading it;
  // confirm. memset also assumes T is trivially copyable.
  virtual void InitCache(TfLiteTensor* cache = nullptr) {
    memset(cache1_.get(), 0, beam_size_ * feature_size_ * sizeof(T));
  }
  // Cache to read from at `step` (cache1_ on odd steps).
  T* CurrentCache(int step) const {
    return (step & 0x1) == 0x1 ? cache1_.get() : cache2_.get();
  }
  // Cache to write to at `step` (cache2_ on odd steps).
  T* NextCache(int step) const {
    return (step & 0x1) == 0x1 ? cache2_.get() : cache1_.get();
  }

 private:
  const int feature_size_;
  const int beam_size_;
  const std::unique_ptr<T[]> cache1_;
  const std::unique_ptr<T[]> cache2_;
};
// DynamicCacheOp stores one cache per timestep. It supports reallocating the
// cache of a past timestep when beams are added dynamically.
template <typename T>
class DynamicCacheOp {
 public:
  explicit DynamicCacheOp(int feature_size) : feature_size_(feature_size) {}
  virtual ~DynamicCacheOp() {}
  int FeatureSize() const { return feature_size_; }
  // Drops all cached timesteps.
  virtual void InitCache(TfLiteTensor* cache = nullptr) { cache_list_.clear(); }
  // GetCache is called by the new step in UniformAttn. The caller wants to add
  // a new cache or dynamically append attn values to an existing cache.
  // Returns nullptr for an invalid step.
  std::vector<T>* GetCache(int step, int beam_size) {
    // A non-positive step, or a step more than one past the end, is invalid;
    // the caller should stop using the cache. The explicit `step <= 0` guard
    // replaces the original's reliance on (step - 1) wrapping to a huge
    // unsigned value in the mixed-sign comparison below.
    if (step <= 0 || static_cast<size_t>(step - 1) > cache_list_.size()) {
      return nullptr;
    } else if (static_cast<size_t>(step - 1) == cache_list_.size()) {
      // First request for this step: construct the cache in place
      // (emplace_back; the original std::move of a temporary was redundant).
      cache_list_.emplace_back(feature_size_ * beam_size);
    } else {
      // Allocates new memory in the existing cache to store uniform attention
      // for the added beams.
      cache_list_[step - 1].resize(cache_list_[step - 1].size() +
                                   beam_size * feature_size_);
    }
    return &cache_list_[step - 1];
  }
  // GetStaticCache returns the cached attention, which is read-only.
  std::vector<T>* GetStaticCache(int step) {
    // No previous cache for the initial step.
    if (step == 0) {
      return nullptr;
    } else {
      // Gets the previous cache.
      return &cache_list_[step - 1];
    }
  }

 private:
  const int feature_size_;
  std::vector<std::vector<T>> cache_list_;
};
} // namespace tflite_decoder_base
} // namespace custom
} // namespace ops
} // namespace seq_flow_lite
#endif // THIRD_PARTY_TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/tflite_decoder_handler.h"
#include <cstdint>
#include "third_party/flatbuffers/include/flatbuffers/flexbuffers.h"
#include "third_party/tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "third_party/tensorflow/lite/kernels/kernel_util.h"
#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/tflite_decoder_cache.h"
#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h"
namespace seq_flow_lite {
namespace ops {
namespace custom {
namespace {
// Flexbuffer attribute keys read by Init() from the custom-op options.
static constexpr const char kFeatureSizeStr[] = "feature_size";
static constexpr const char kBeamSizeStr[] = "beam_size";
// Input/output tensor indices for the UniformCausalAttention op.
constexpr int kInputFeaturesIndex = 0;
constexpr int kTimestepIndex = 1;
constexpr int kSelectedBeamsIndex = 2;
constexpr int kOutputFeaturesIndex = 0;
}  // namespace
namespace tflite_decoder_uniform {
// Evaluates uniform average decoding operations.
// Maintains per-beam accumulated feature sums in the base class's ping-pong
// caches (re-gathered each step via the selected beam indices) and emits
// their running average.
class UniformDecoderOp : public tflite_decoder_base::BaseDecoderOp<float> {
 public:
  explicit UniformDecoderOp(int feature_size, int beam_size)
      : BaseDecoderOp(feature_size, beam_size) {}
  // Float path: adds `update` onto the state of the beams chosen by
  // `selected_beams` and writes the per-element average (sum / step) to
  // `result`. Both arrays hold BeamSize() * FeatureSize() floats.
  void Eval(int32_t step, const std::vector<int32_t>& selected_beams,
            const float* update, float* result);
  // Quantized (uint8) path: same computation, dequantizing `input` and
  // re-quantizing the running average into `output`.
  void EvalQuantized(int32_t step, const std::vector<int32_t>& selected_beams,
                     const TfLiteTensor* input, TfLiteTensor* output);
};
void UniformDecoderOp::Eval(int32_t step,
const std::vector<int32_t>& selected_beams,
const float* update, float* result) {
const float normalizer = 1.0f / step;
const float* cur_cache = CurrentCache(step);
float* next_cache = NextCache(step);
for (int i = 0, index = 0; i < BeamSize(); ++i) {
const float* selected = cur_cache + (selected_beams[i] * FeatureSize());
for (int j = 0; j < FeatureSize(); ++j, index++) {
next_cache[index] = selected[j] + update[index];
result[index] = next_cache[index] * normalizer;
}
}
}
void UniformDecoderOp::EvalQuantized(int32_t step,
const std::vector<int32_t>& selected_beams,
const TfLiteTensor* input,
TfLiteTensor* output) {
uint8_t* result = ::tflite::GetTensorData<uint8_t>(output);
const float normalizer_and_inverse_scale =
1.0f / (output->params.scale * step);
const float* cur_cache = CurrentCache(step);
float* next_cache = NextCache(step);
for (int i = 0, index = 0; i < BeamSize(); ++i) {
const float* selected = cur_cache + (selected_beams[i] * FeatureSize());
for (int j = 0; j < FeatureSize(); ++j, index++) {
next_cache[index] =
selected[j] + ::seq_flow_lite::PodDequantize(*input, index);
result[index] = ::seq_flow_lite::PodQuantize(
next_cache[index], output->params.zero_point,
normalizer_and_inverse_scale);
}
}
}
// Creates a UniformDecoderOp from the flexbuffer-encoded custom-op options
// (keys kFeatureSizeStr and kBeamSizeStr). Ownership passes to Free().
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  const auto* data = reinterpret_cast<const uint8_t*>(buffer);
  const flexbuffers::Map& options = flexbuffers::GetRoot(data, length).AsMap();
  const int feature_size = options[kFeatureSizeStr].AsInt32();
  const int beam_size = options[kBeamSizeStr].AsInt32();
  return new UniformDecoderOp(feature_size, beam_size);
}
// Releases the op state allocated by Init().
void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<UniformDecoderOp*>(buffer);
}
// Prepare handler (installed as TfLiteRegistration::prepare): checks arity,
// requires matching input/output types, and resizes the output to mirror
// the input features tensor's shape.
TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, ::tflite::NumInputs(node), 3);
  TF_LITE_ENSURE_EQ(context, ::tflite::NumOutputs(node), 1);
  const TfLiteTensor* input =
      ::tflite::GetInput(context, node, kInputFeaturesIndex);
  TfLiteTensor* output =
      ::tflite::GetOutput(context, node, kOutputFeaturesIndex);
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
  return context->ResizeTensor(context, output,
                               TfLiteIntArrayCopy(input->dims));
}
// Invoke handler for the UniformCausalAttention op. Validates the three
// inputs (features, scalar timestep, selected beams), re-initializes the
// cache on the first timestep, and dispatches to the float or quantized
// evaluation path.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, ::tflite::NumInputs(node), 3);
  TF_LITE_ENSURE_EQ(context, ::tflite::NumOutputs(node), 1);
  auto* params = reinterpret_cast<UniformDecoderOp*>(node->user_data);
  const TfLiteTensor* input =
      ::tflite::GetInput(context, node, kInputFeaturesIndex);
  const TfLiteTensor* time_step =
      ::tflite::GetInput(context, node, kTimestepIndex);
  const TfLiteTensor* selected_beams =
      ::tflite::GetInput(context, node, kSelectedBeamsIndex);
  // Timestep is a scalar int32; selected beams is an int32 vector with one
  // entry per beam.
  TF_LITE_ENSURE_EQ(context, time_step->type, kTfLiteInt32);
  TF_LITE_ENSURE_EQ(context, time_step->dims->size, 0);
  TF_LITE_ENSURE_EQ(context, selected_beams->type, kTfLiteInt32);
  TF_LITE_ENSURE_EQ(context, selected_beams->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, selected_beams->dims->data[0], params->BeamSize());
  const int32_t time_step_value =
      ::tflite::GetTensorData<int32_t>(time_step)[0];
  const int32_t* selected_beams_ptr =
      ::tflite::GetTensorData<int32_t>(selected_beams);
  const std::vector<int32_t> selected_beams_value(
      selected_beams_ptr, selected_beams_ptr + params->BeamSize());
  // Every selected beam index must refer to an existing beam.
  for (auto value : selected_beams_value) {
    TF_LITE_ENSURE(context, value >= 0 && value < params->BeamSize());
  }
  TfLiteTensor* output =
      ::tflite::GetOutput(context, node, kOutputFeaturesIndex);
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
  // Time step is expected to be in [1, )
  TF_LITE_ENSURE(context, time_step_value >= 1);
  // The first timestep starts a new decode, so clear the cached state.
  if (time_step_value == 1) {
    params->InitCache();
  }
  if (input->type == kTfLiteFloat32) {
    params->Eval(time_step_value, selected_beams_value,
                 ::tflite::GetTensorData<float>(input),
                 ::tflite::GetTensorData<float>(output));
  } else if (input->type == kTfLiteUInt8) {
    params->EvalQuantized(time_step_value, selected_beams_value, input, output);
  } else {
    context->ReportError(context, "Op type must be Float32 or UInt8.");
    return kTfLiteError;
  }
  return kTfLiteOk;
}
} // namespace tflite_decoder_uniform
// Returns the registration for the UniformCausalAttention custom op.
// Resize serves as the prepare handler in the {init, free, prepare, invoke}
// registration struct.
TfLiteRegistration* Register_UNIFORM_CAUSAL_ATTENTION() {
  static TfLiteRegistration r = {
      tflite_decoder_uniform::Init, tflite_decoder_uniform::Free,
      tflite_decoder_uniform::Resize, tflite_decoder_uniform::Eval};
  return &r;
}
} // namespace custom
} // namespace ops
} // namespace seq_flow_lite
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
#define THIRD_PARTY_TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
#include "third_party/tensorflow/lite/kernels/register.h"
namespace seq_flow_lite {
namespace ops {
namespace custom {
TfLiteRegistration* Register_UNIFORM_CAUSAL_ATTENTION();
}
} // namespace ops
} // namespace seq_flow_lite
#endif // THIRD_PARTY_TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/tflite_decoder_handler.h"
#include <cstdint>
#include <cstdlib>
#include <vector>
#include "testing/base/public/gmock.h"
#include "testing/base/public/gunit.h"
#include "third_party/flatbuffers/include/flatbuffers/flexbuffers.h"
#include "third_party/tensorflow/lite/c/common.h"
#include "third_party/tensorflow/lite/kernels/test_util.h"
namespace {
constexpr char kUniformAverageAttention[] = "UniformAverageAttentionDecoder";
// Test harness wrapping the UniformCausalAttention custom op in a
// single-op TFLite model, in either float or uint8-quantized form.
class AverageAttentionDecoder : public tflite::SingleOpModel {
 public:
  // Builds the model with a features input of shape [beam, 1, feature], a
  // scalar int32 timestep, and a per-beam int32 selected-beams vector. When
  // `quantized` is true, features use uint8 tensors with range [0, 4].
  explicit AverageAttentionDecoder(int feature_size, int beam_size,
                                   bool quantized = false)
      : quantized_(quantized) {
    // Custom-op options are encoded as a flexbuffer map.
    flexbuffers::Builder fbb;
    fbb.Map([&] {
      fbb.Int("feature_size", feature_size);
      fbb.Int("beam_size", beam_size);
    });
    fbb.Finish();
    if (!quantized) {
      input_ =
          AddInput({tflite::TensorType_FLOAT32, {beam_size, 1, feature_size}});
      output_ =
          AddOutput({tflite::TensorType_FLOAT32, {beam_size, 1, feature_size}});
    } else {
      input_ = AddInput(
          {tflite::TensorType_UINT8, {beam_size, 1, feature_size}, 0.0f, 4.0f});
      output_ = AddOutput(
          {tflite::TensorType_UINT8, {beam_size, 1, feature_size}, 0.0f, 4.0f});
    }
    timestep_ = AddInput({tflite::TensorType_INT32, {}});
    beam_ = AddInput({tflite::TensorType_INT32, {beam_size}});
    SetCustomOp(
        kUniformAverageAttention, fbb.GetBuffer(),
        ::seq_flow_lite::ops::custom::Register_UNIFORM_CAUSAL_ATTENTION);
    BuildInterpreter({GetShape(input_), GetShape(timestep_), GetShape(beam_)});
    CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
        << "Cannot allocate tensors";
  }
  // Runs one decode step with the given timestep, beam selection, and
  // (float) feature values; returns the interpreter status.
  TfLiteStatus Invoke(int timestep, const std::vector<int32_t>& beams,
                      const std::vector<float>& input_val) {
    PopulateTensor<int32_t>(timestep_, {timestep});
    PopulateTensor<int32_t>(beam_, beams);
    if (!quantized_) {
      PopulateTensor<float>(input_, input_val);
    } else {
      QuantizeAndPopulate<uint8_t>(input_, input_val);
    }
    return SingleOpModel::Invoke();
  }
  // Returns the op output as floats, dequantizing in the uint8 case.
  std::vector<float> GetOutput() {
    if (!quantized_) {
      return ExtractVector<float>(output_);
    } else {
      return tflite::Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
                                         GetScale(output_),
                                         GetZeroPoint(output_));
    }
  }

 private:
  // Tensor indices within the single-op model.
  int input_;
  int output_;
  int timestep_;
  int beam_;
  // Whether the feature tensors are uint8-quantized.
  bool quantized_;
};
// Float path: step 1 output equals the input (average of one step); step 2
// re-gathers beams {2, 3, 1, 1} and averages over the two steps.
TEST(AverageAttentionDecoder, RegularInput) {
  AverageAttentionDecoder m(4, 4);
  auto status = m.Invoke(1, {0, 0, 0, 0},
                         {1.f, 1.f, 1.f, 1.f, //
                          2.f, 2.f, 2.f, 2.f, //
                          3.f, 3.f, 3.f, 3.f, //
                          4.f, 4.f, 4.f, 4.f});
  EXPECT_EQ(status, kTfLiteOk);
  EXPECT_THAT(m.GetOutput(), testing::ElementsAreArray({1.f, 1.f, 1.f, 1.f, //
                                                        2.f, 2.f, 2.f, 2.f, //
                                                        3.f, 3.f, 3.f, 3.f, //
                                                        4.f, 4.f, 4.f, 4.f}));
  // e.g. beam 0 gathers step-1 beam 2 (sum 3) plus its update 1 -> avg 2.
  status = m.Invoke(2, {2, 3, 1, 1},
                    {1.f, 1.f, 1.f, 1.f, //
                     2.f, 2.f, 2.f, 2.f, //
                     3.f, 3.f, 3.f, 3.f, //
                     4.f, 4.f, 4.f, 4.f});
  EXPECT_EQ(status, kTfLiteOk);
  EXPECT_THAT(m.GetOutput(),
              testing::ElementsAreArray({2.f, 2.f, 2.f, 2.f, //
                                         3.f, 3.f, 3.f, 3.f, //
                                         2.5f, 2.5f, 2.5f, 2.5f, //
                                         3.f, 3.f, 3.f, 3.f}));
}
// Quantized (uint8) path: same scenario as RegularInput, with results
// checked within quantization tolerance. (A duplicated status assertion in
// the original was removed.)
TEST(AverageAttentionDecoder, RegularInputQuantized) {
  AverageAttentionDecoder m(4, 4, true);
  // Step 1: all beams gather beam 0's (zero) state; output equals input.
  auto status = m.Invoke(1, {0, 0, 0, 0},
                         {1.f, 1.f, 1.f, 1.f, //
                          2.f, 2.f, 2.f, 2.f, //
                          3.f, 3.f, 3.f, 3.f, //
                          4.f, 4.f, 4.f, 4.f});
  EXPECT_EQ(status, kTfLiteOk);
  EXPECT_THAT(m.GetOutput(),
              ElementsAreArray(tflite::ArrayFloatNear({1.f, 1.f, 1.f, 1.f, //
                                                       2.f, 2.f, 2.f, 2.f, //
                                                       3.f, 3.f, 3.f, 3.f, //
                                                       4.f, 4.f, 4.f, 4.f},
                                                      1e-2)));
  // Step 2: beams re-gather {2, 3, 1, 1}; output is the two-step average.
  status = m.Invoke(2, {2, 3, 1, 1},
                    {1.f, 1.f, 1.f, 1.f, //
                     2.f, 2.f, 2.f, 2.f, //
                     3.f, 3.f, 3.f, 3.f, //
                     4.f, 4.f, 4.f, 4.f});
  EXPECT_EQ(status, kTfLiteOk);
  EXPECT_THAT(m.GetOutput(), ElementsAreArray(tflite::ArrayFloatNear(
                                 {2.f, 2.f, 2.f, 2.f, //
                                  3.f, 3.f, 3.f, 3.f, //
                                  2.5f, 2.5f, 2.5f, 2.5f, //
                                  3.f, 3.f, 3.f, 3.f},
                                 1e-2)));
}
// Float path with arbitrary values: step 1 echoes the input; step 2
// averages the re-gathered step-1 state with the same update.
TEST(AverageAttentionDecoder, RandomInput) {
  AverageAttentionDecoder m(4, 4);
  std::vector<float> input = {2.1, 3.1, -1.6, 11.3, //
                              22.6, 20.8, 32.2, -12.9, //
                              13.2, 3.3, -3.0, 33.3, //
                              24.3, 14.9, -4.9, 4.7};
  auto status = m.Invoke(1, {0, 0, 0, 0}, input);
  EXPECT_EQ(status, kTfLiteOk);
  EXPECT_THAT(m.GetOutput(), testing::ElementsAreArray(input));
  status = m.Invoke(2, {2, 3, 1, 1}, input);
  EXPECT_EQ(status, kTfLiteOk);
  EXPECT_THAT(m.GetOutput(), testing::ElementsAreArray(tflite::ArrayFloatNear(
                                 {7.65, 3.2, -2.3, 22.3, //
                                  23.45, 17.85, 13.65, -4.1, //
                                  17.9, 12.05, 14.6, 10.2, //
                                  23.45, 17.85, 13.65, -4.1},
                                 1e-2)));
}
// Error paths: out-of-range beam index and non-positive timestep must fail;
// a valid invocation afterwards succeeds. (Test name typo "Irrregular"
// fixed; nothing in the file references the old name.)
TEST(AverageAttentionDecoder, IrregularInput) {
  AverageAttentionDecoder m(4, 4, false);
  // Beam index 20 is out of range for beam_size 4.
  auto status = m.Invoke(1, {20, 3, 2, 0},
                         {1.f, 1.f, 1.f, 1.f, //
                          2.f, 2.f, 2.f, 2.f, //
                          3.f, 3.f, 3.f, 3.f, //
                          4.f, 4.f});
  EXPECT_EQ(status, kTfLiteError);
  // Timestep must be >= 1.
  status = m.Invoke(-10, {0, 3, 2, 0},
                    {1.f, 1.f, 1.f, 1.f, //
                     2.f, 2.f, 2.f, 2.f, //
                     3.f, 3.f, 3.f, 3.f, //
                     4.f, 4.f});
  EXPECT_EQ(status, kTfLiteError);
  // Valid timestep and beams succeed.
  // NOTE(review): the input holds 14 of 16 tensor values; presumably only
  // the populated prefix matters for this error-path test — confirm.
  status = m.Invoke(1, {0, 3, 2, 0},
                    {1.f, 1.f, 1.f, 1.f, //
                     2.f, 2.f, 2.f, 2.f, //
                     3.f, 3.f, 3.f, 3.f, //
                     4.f, 4.f});
  EXPECT_EQ(status, kTfLiteOk);
}
} // namespace
// Test entry point (this target does not link a default gtest main).
int main(int argc, char** argv) {
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/tflite_qrnn_pooling.h"
#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h"
namespace seq_flow_lite {
namespace {
// Value of the single-byte `direction` input that selects forward-in-time
// pooling; any other value pools backward over time.
const uint8_t kPoolingForward = 255;
// Prepare handler for QRNN pooling. Validates the three uint8 inputs
// (multiplier of shape [1, time, state], constant with the same shape, and
// a one-element direction vector) and resizes the outputs: output 0 mirrors
// the multiplier's shape; the optional output 1 (final state) becomes
// [1, state].
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, node->inputs->size, 3);
  // The op produces per-timestep outputs and, optionally, a final state.
  if (node->outputs->size < 1 || node->outputs->size > 2) {
    return kTfLiteError;
  }
  TfLiteTensor* multiplier = &context->tensors[node->inputs->data[0]];
  TfLiteTensor* constant = &context->tensors[node->inputs->data[1]];
  TfLiteTensor* direction = &context->tensors[node->inputs->data[2]];
  TF_LITE_ENSURE_EQ(context, multiplier->type, kTfLiteUInt8);
  TF_LITE_ENSURE_EQ(context, constant->type, kTfLiteUInt8);
  TF_LITE_ENSURE_EQ(context, direction->type, kTfLiteUInt8);
  TF_LITE_ENSURE_EQ(context, multiplier->dims->size, 3);
  TF_LITE_ENSURE_EQ(context, multiplier->dims->data[0], 1);
  const int time_steps = multiplier->dims->data[1];
  const int state_size = multiplier->dims->data[2];
  // `constant` must match the multiplier's [1, time, state] shape.
  TF_LITE_ENSURE_EQ(context, constant->dims->size, 3);
  TF_LITE_ENSURE_EQ(context, constant->dims->data[0], 1);
  TF_LITE_ENSURE_EQ(context, constant->dims->data[1], time_steps);
  TF_LITE_ENSURE_EQ(context, constant->dims->data[2], state_size);
  TF_LITE_ENSURE_EQ(context, direction->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, direction->dims->data[0], 1);
  TfLiteTensor* outputs = &context->tensors[node->outputs->data[0]];
  if (outputs) {
    TF_LITE_ENSURE_OK(
        context, context->ResizeTensor(context, outputs,
                                       TfLiteIntArrayCopy(multiplier->dims)));
  }
  if (node->outputs->size == 2) {
    TfLiteTensor* final_state = &context->tensors[node->outputs->data[1]];
    if (final_state) {
      TfLiteIntArray* final_state_dims = TfLiteIntArrayCreate(2);
      final_state_dims->data[0] = 1;
      final_state_dims->data[1] = state_size;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, final_state,
                                                       final_state_dims));
    }
  }
  return kTfLiteOk;
}
// Runs quantized QRNN pooling over the sequence:
//   state = state * multiplier[t] + constant[t]
// evaluated per state element, walking time forward or backward.
// `outputs`, when non-null, receives the quantized state at every timestep;
// `final_state`, when non-null, receives the quantized state after the last
// timestep.
TfLiteStatus QRNNPooling(TfLiteContext* context, TfLiteTensor* multiplier,
                         TfLiteTensor* constant, TfLiteTensor* outputs,
                         TfLiteTensor* final_state, bool forward) {
  const int time_steps = multiplier->dims->data[1];
  const int state_size = multiplier->dims->data[2];
  // make_unique's array form value-initializes, so the state starts at zero
  // (the original's extra memset was redundant).
  auto state = std::make_unique<float[]>(state_size);
  const int32_t out_zero_point = outputs ? outputs->params.zero_point : 0;
  const float out_inverse_scale = outputs ? 1.0f / outputs->params.scale : 1.0f;
  uint8_t* out_ptr = outputs ? outputs->data.uint8 : nullptr;
  for (int i = 0; i < time_steps; ++i) {
    // Walk time forward or backward depending on the pooling direction.
    const int time_index = forward ? i : time_steps - (i + 1);
    for (int j = 0; j < state_size; ++j) {
      const int index = time_index * state_size + j;
      const float multiplier_value = PodDequantize(*multiplier, index);
      const float constant_value = PodDequantize(*constant, index);
      state[j] = state[j] * multiplier_value + constant_value;
      if (outputs) {
        out_ptr[index] =
            PodQuantize(state[j], out_zero_point, out_inverse_scale);
      }
    }
  }
  if (final_state) {
    uint8_t* final_state_ptr = final_state->data.uint8;
    const int32_t zero_point = final_state->params.zero_point;
    const float inverse_scale = 1.0f / final_state->params.scale;
    for (int j = 0; j < state_size; ++j) {
      final_state_ptr[j] = PodQuantize(state[j], zero_point, inverse_scale);
    }
  }
  return kTfLiteOk;
}
// Invoke handler: re-validates arity, fetches the tensors, and delegates to
// QRNNPooling. The final-state output is optional.
TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, node->inputs->size, 3);
  if (node->outputs->size < 1 || node->outputs->size > 2) {
    return kTfLiteError;
  }
  TfLiteTensor* multiplier = &context->tensors[node->inputs->data[0]];
  TfLiteTensor* constant = &context->tensors[node->inputs->data[1]];
  TfLiteTensor* direction = &context->tensors[node->inputs->data[2]];
  TfLiteTensor* outputs = &context->tensors[node->outputs->data[0]];
  TfLiteTensor* final_state = (node->outputs->size == 2)
                                  ? &context->tensors[node->outputs->data[1]]
                                  : nullptr;
  // When pooling forward the direction parameter is expected to be
  // kPoolingForward.
  return QRNNPooling(context, multiplier, constant, outputs, final_state,
                     (direction->data.uint8[0] == kPoolingForward));
}
} // namespace
namespace custom {
// Custom-op name under which QRNN pooling is registered with a resolver.
const char kPoolingOp[] = "PoolingOp";
// Adds the QRNN pooling custom op to `resolver` under the name kPoolingOp.
void RegisterQRNNPooling(::tflite::ops::builtin::BuiltinOpResolver* resolver) {
  resolver->AddCustom(kPoolingOp, Register_QRNN_POOLING());
}
// Returns the QRNN pooling registration. The op keeps no per-node state, so
// init and free are null; Prepare validates/resizes and Invoke runs.
TfLiteRegistration* Register_QRNN_POOLING() {
  static TfLiteRegistration r = {nullptr, nullptr, Prepare, Invoke};
  return &r;
}
} // namespace custom
} // namespace seq_flow_lite
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
#define TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
#include "third_party/absl/base/macros.h"
#include "third_party/tensorflow/lite/kernels/register.h"
namespace seq_flow_lite {
namespace custom {
extern const char kPoolingOp[];
TfLiteRegistration* Register_QRNN_POOLING();
} // namespace custom
} // namespace seq_flow_lite
#endif // TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment