Unverified Commit f91b59c6 authored by thunderfyc, committed by GitHub

Initial checkin of sequence_projection (#9153)



* Initial checkin of sequence_projection

* Fix the path

* Fix paths and deps

* Fix path and deps
Co-authored-by: Learn2Compress <expander-robot@google.com>
parent 67efd3ab
# TFLite ops for sequence string projection.
load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_copts")
licenses(["notice"])
package(
default_visibility = [
"//:__subpackages__", # sequence projection
],
)
cc_library(
name = "sequence_string_projection",
srcs = ["sequence_string_projection.cc"],
hdrs = ["sequence_string_projection.h"],
copts = tflite_copts(),
deps = [
":quantization_util",
"@org_tensorflow//tensorflow/lite:string_util",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
"///tf_ops:projection_normalizer_util", # sequence projection
"///tf_ops:projection_util", # sequence projection
"@flatbuffers",
],
alwayslink = 1,
)
cc_test(
name = "sequence_string_projection_test",
size = "small",
srcs = ["sequence_string_projection_test.cc"],
deps = [
":sequence_string_projection",
":tf_tflite_diff_test_util",
"@org_tensorflow//tensorflow/lite:framework",
"@org_tensorflow//tensorflow/lite:string_util",
"@org_tensorflow//tensorflow/lite/core/api",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
"@org_tensorflow//tensorflow/lite/kernels:test_util",
"///tf_ops:sequence_string_projection_op", # sequence projection
"///tf_ops:sequence_string_projection_op_v2", # sequence projection
"@flatbuffers",
],
)
cc_library(
name = "tf_tflite_diff_test_util",
testonly = 1,
srcs = ["tf_tflite_diff_test_util.cc"],
hdrs = ["tf_tflite_diff_test_util.h"],
deps = [
"@org_tensorflow//tensorflow/core:framework",
"@org_tensorflow//tensorflow/core:lib",
"@org_tensorflow//tensorflow/core:test",
"@org_tensorflow//tensorflow/core/kernels:ops_testutil",
"@org_tensorflow//tensorflow/lite/kernels:test_util",
"@com_google_absl//absl/container:flat_hash_map",
"@flatbuffers",
],
)
cc_library(
name = "quantization_util",
hdrs = ["quantization_util.h"],
deps = ["@org_tensorflow//tensorflow/lite:context"],
)
cc_library(
name = "expected_value",
srcs = ["expected_value.cc"],
hdrs = ["expected_value.h"],
copts = tflite_copts(),
deps = [
":quantization_util",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
],
alwayslink = 1,
)
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tflite_ops/expected_value.h" // sequence_projection
#include <cmath>
#include "tflite_ops/quantization_util.h" // sequence_projection
namespace tflite {
namespace ops {
namespace custom {
namespace {
constexpr int kInputAttentionLogits = 0;
constexpr int kInputValues = 1;
constexpr int kOutputExpectedValue = 0;
class ExpectedValueParams {
public:
// Returns a precomputed exponential table for the quantization range of the
// tensor. The table is computed on the first lookup and reused until the
// TFLite interpreter is destroyed.
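// Because the logits share a single uint8 quantization (scale, zero_point),
// exp(logit - max_logit) in dequantized units depends only on the integer
// difference: ((q - zp) - (q_max - zp)) * scale = -(q_max - q) * scale. The
// zero point cancels, so a 256-entry table of expf(-i * scale) covers every
// possible lookup.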
float* GetPrecomputedTable(const TfLiteTensor& tensor) {
if (!initialized_) {
initialized_ = true;
const float scale = tensor.params.scale;
for (int i = 0;
i < sizeof(precomputed_table_) / sizeof(precomputed_table_[0]);
++i) {
precomputed_table_[i] = expf(-i * scale);
}
}
return precomputed_table_;
}
private:
bool initialized_ = false;
float precomputed_table_[256];
};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
return new ExpectedValueParams();
}
void Free(TfLiteContext* context, void* buffer) {
delete reinterpret_cast<ExpectedValueParams*>(buffer);
}
TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, node->inputs->size, 2);
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
TfLiteTensor* attention_logits =
&context->tensors[node->inputs->data[kInputAttentionLogits]];
TfLiteTensor* values = &context->tensors[node->inputs->data[kInputValues]];
// Currently only 8-bit input tensors are supported.
TF_LITE_ENSURE_EQ(context, attention_logits->type, kTfLiteUInt8);
TF_LITE_ENSURE_EQ(context, values->type, kTfLiteUInt8);
// Both input tensors are expected to be rank 3.
TF_LITE_ENSURE_EQ(context, attention_logits->dims->size, 3);
TF_LITE_ENSURE_EQ(context, attention_logits->dims->size, values->dims->size);
// Currently batch size is expected to be 1.
TF_LITE_ENSURE_EQ(context, attention_logits->dims->data[0], 1);
// The dimensions of both input tensors must match.
for (int i = 0; i < values->dims->size; ++i) {
TF_LITE_ENSURE_EQ(context, attention_logits->dims->data[i],
values->dims->data[i]);
}
TfLiteTensor* output =
&context->tensors[node->outputs->data[kOutputExpectedValue]];
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteUInt8);
TfLiteIntArray* output_size = TfLiteIntArrayCreate(2);
// The expectation is computed over the sequence dimension, leaving a rank 2
// output tensor with the first and last dimensions of the input.
output_size->data[0] = values->dims->data[0];
output_size->data[1] = values->dims->data[2];
return context->ResizeTensor(context, output, output_size);
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto logits_t = &context->tensors[node->inputs->data[kInputAttentionLogits]];
auto values_t = &context->tensors[node->inputs->data[kInputValues]];
auto output_t = &context->tensors[node->outputs->data[kOutputExpectedValue]];
const int out_channels = logits_t->dims->data[2];
const int sequence_length = logits_t->dims->data[1];
const float out_inverse_scale = 1.0f / output_t->params.scale;
const int32_t out_zero_point = output_t->params.zero_point;
uint8_t* output = output_t->data.uint8;
auto* params = reinterpret_cast<ExpectedValueParams*>(node->user_data);
const float* table = params->GetPrecomputedTable(*logits_t);
// The memory layout of the input tensor is row-major, hence the inner loops
// have a pitch of out_channels instead of 1. The inner loop runs over this
// array twice for logits and once for values. If out_channels grows beyond a
// reasonable value, the entire content of logits/values won't fit in the L1
// cache, which would make these loops very inefficient. If the last dimension
// increases, this handler should be rewritten to first do a transpose in a
// cache-efficient manner before performing the compute.
for (int i = 0; i < out_channels; ++i) {
// Find the max logit; it is subtracted from every logit to ensure numerical
// stability when computing softmax.
auto slogits = &logits_t->data.uint8[i];
auto elogits = slogits + (sequence_length * out_channels);
int32_t maxval = 0;
for (auto logits = slogits; logits < elogits; logits += out_channels) {
maxval = std::max(static_cast<int32_t>(*logits), maxval);
}
// Find normalizer to compute softmax (sum of exponential over logits).
// Compute the softmax output (attention), perform the elementwise
// multiplication and reduce by summing in a single loop. This results in
// the unnormalized expected value, which is normalized later.
float normalizer = 0.0f;
float unnormalized_expected_value = 0.0f;
auto values = &values_t->data.uint8[i];
for (auto logits = slogits; logits < elogits;
logits += out_channels, values += out_channels) {
const float unnormalized_attention = table[maxval - *logits];
normalizer += unnormalized_attention;
unnormalized_expected_value +=
unnormalized_attention * PodDequantizeValue(*values_t, *values);
}
const float expected_value = unnormalized_expected_value / normalizer;
// Quantize and set the expected value in the output buffer.
output[i] = PodQuantize(expected_value, out_zero_point, out_inverse_scale);
}
return kTfLiteOk;
}
} // namespace
// This fused TFLite op takes two input tensors (logits and values), which are
// expected to be rank 3 tensors of the form [batch size, sequence, channels].
// The op performs softmax over the sequence dimension of the logits input,
// multiplies the result element-wise with the values tensor, reduces the
// sequence dimension by summing, and returns a tensor of the form
// [batch size, channels]. Batch size is assumed to be 1 in the current
// implementation.
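// For example, with sequence length 3 and a single channel, dequantized logits
// (1.0, 2.0, 3.0) give softmax weights of roughly (0.09, 0.24, 0.67); applied
// to values (10, 20, 30), the expected value is about 25.8 before
// requantization.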
TfLiteRegistration* Register_EXPECTED_VALUE() {
static TfLiteRegistration r = {Init, Free, Resize, Eval};
return &r;
}
} // namespace custom
} // namespace ops
} // namespace tflite
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_EXPECTED_VALUE_H_
#define TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_EXPECTED_VALUE_H_
#include "tensorflow/lite/kernels/register.h"
namespace tflite {
namespace ops {
namespace custom {
TfLiteRegistration* Register_EXPECTED_VALUE();
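// A minimal usage sketch (the resolver variable is illustrative): register the
// op with a tflite::MutableOpResolver before building the interpreter, e.g.
//   resolver.AddCustom("EXPECTED_VALUE", Register_EXPECTED_VALUE());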
} // namespace custom
} // namespace ops
} // namespace tflite
#endif // TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_EXPECTED_VALUE_H_
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_QUANTIZATION_UTIL_H_
#define TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_QUANTIZATION_UTIL_H_
#include <algorithm>
#include <cmath>
#include "tensorflow/lite/context.h"
namespace tflite {
// Returns the original (dequantized) value of an 8-bit quantized value.
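// For example, with zero_point = 127 and scale = 2/255, the quantized value
// 255 dequantizes to (255 - 127) * 2/255, approximately 1.004.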
inline float PodDequantizeValue(const TfLiteTensor& tensor, uint8_t value) {
const int32_t zero_point = tensor.params.zero_point;
const float scale = tensor.params.scale;
return (static_cast<int32_t>(value) - zero_point) * scale;
}
// Returns the original (dequantized) value of the 'index'-th element of
// 'tensor'.
inline float PodDequantize(const TfLiteTensor& tensor, int index) {
return PodDequantizeValue(tensor, tensor.data.uint8[index]);
}
// Quantizes 'value' to 8 bits, given the quantization bias (zero_point) and
// factor (inverse_scale).
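// For example, with zero_point = 127 and inverse_scale = 127.5 (scale =
// 2/255), the inputs -1.0f, 0.0f, and 1.0f quantize to 0, 127, and 255,
// respectively.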
inline uint8_t PodQuantize(float value, int32_t zero_point,
float inverse_scale) {
const float integer_value_in_float = value * inverse_scale;
const float offset = (integer_value_in_float >= 0.0) ? 0.5f : -0.5f;
// NOTE(sfeuz): This assumes value * inverse_scale is within [INT_MIN,
// INT_MAX].
int32_t integer_value =
static_cast<int32_t>(integer_value_in_float + offset) + zero_point;
return static_cast<uint8_t>(std::max(std::min(255, integer_value), 0));
}
} // namespace tflite
#endif // TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_QUANTIZATION_UTIL_H_
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
/**
* Sequence String projection op used in PRADO.
*/
#include "tflite_ops/sequence_string_projection.h" // sequence_projection
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <unordered_map>
#include "flatbuffers/flexbuffers.h" // flatbuffer
#include "tensorflow/lite/string_util.h"
#include "tf_ops/projection_normalizer_util.h" // sequence_projection
#include "tf_ops/projection_util.h" // sequence_projection
#include "tflite_ops/quantization_util.h" // sequence_projection
namespace tflite {
namespace ops {
namespace custom {
namespace sequence_string_projection {
/**
 * This op, referred to as the Ternary Sequence String Projection op (TSP),
 * tokenizes input text either on spaces or on unicode boundaries. A
 * fingerprint for each token is computed using murmur hash, and bit features
 * are extracted from the fingerprint by mapping every 2 bits to the ternary
 * output {-1, 0, 1}. This effectively turns a text input into a ternary rank 3
 * tensor (in 8bit/float format) of shape
 * [1, max token length, requested number of features].
 *
 * Input:
 *   tensor[0]: Input message, string[num_batch]
 *   attribute[0]: feature size
 *   attribute[1]: vocabulary, a set of allowed characters in utf8 format.
 *   attribute[2]: split_on_space, a boolean specifying the tokenization method.
 *   attribute[3]: max_splits, maximum number of splits allowed during
 *                 tokenization. When max_splits is set to -1, no limit on the
 *                 number of tokens is imposed; when it is set to a positive
 *                 integer, the token list is truncated at that integer. An end
 *                 of input token is always added after tokenization, hence the
 *                 number of tokens is one more than the true number of tokens.
 *                 As a result, the number of tokens returned by this op is not
 *                 the same as absl::StrSplit.
 *   attribute[4]: word_novelty_bits, when set to a positive value less than 8,
 *                 generates a word-specific novelty feature in the last
 *                 feature index.
 *   attribute[5]: doc_size_levels, when set to a positive value less than 17,
 *                 generates a feature proportional to the logarithm of the
 *                 number of tokens in the second-to-last feature index.
 *   attribute[6]: add_eos_tag, adds an end-of-sequence tag to the output when
 *                 true. Defaults to true.
 *   attribute[7]: add_bos_tag, adds a begin-of-sequence tag to the output when
 *                 true. Defaults to false.
 * Output:
 *   tensor[0]: computed projections.
 *              float32[number of tokens + 1][feature size], where the extra
 *              row accounts for the end-of-sequence token.
 */
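// For example, with split_on_space = true, add_eos_tag = true and
// feature_size = 4, the input "hello world" produces two tokens plus an EOS
// tag, so the output tensor has shape [1, 3, 4].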
namespace {
constexpr char kBeginToken[] = "<BOS>";
constexpr char kEndToken[] = "<EOS>";
constexpr int kInputMessage = 0;
constexpr int kOutputLabel = 0;
enum class BosTag { kGenerate, kNone };
enum class EosTag { kGenerate, kNone };
class ProjectionParams {
public:
ProjectionParams(int feature_size, const std::string& vocabulary,
int max_splits, bool split_on_space, int word_novelty_bits,
int doc_size_levels, BosTag add_bos_tag, EosTag add_eos_tag,
bool exclude_nonalphaspace_unicodes,
const std::string& token_separators,
bool normalize_repetition)
: feature_size_(feature_size),
unicode_handler_(vocabulary, exclude_nonalphaspace_unicodes),
hasher_(feature_size),
max_splits_(max_splits),
split_on_space_(split_on_space),
word_novelty_bits_(word_novelty_bits),
doc_size_levels_(doc_size_levels),
add_bos_tag_(add_bos_tag == BosTag::kGenerate),
add_eos_tag_(add_eos_tag == EosTag::kGenerate) {
assert(max_splits_ == -1 || max_splits_ > 0);
assert(word_novelty_bits >= 0 && word_novelty_bits <= 7);
if (word_novelty_bits_ != 0) {
assert(feature_size_ >= 1);
}
assert(doc_size_levels >= 0 && doc_size_levels <= 16);
if (doc_size_levels_ != 0) {
assert(feature_size_ >= 2);
}
word_novelty_offset_ = 2.0f / (1 << word_novelty_bits_);
if (!token_separators.empty() || normalize_repetition) {
projection_normalizer_.reset(
new ProjectionNormalizer(token_separators, normalize_repetition));
}
}
virtual ~ProjectionParams() {}
int FeatureSize() const { return feature_size_; }
bool WordNoveltyEnabled() const { return word_novelty_bits_ != 0; }
void WordNoveltyFeature(float* data, int word_count) const {
*data = std::min((word_count * word_novelty_offset_) - 1.0f, 1.0f);
}
void WordNoveltyFeature(uint8_t* data, int word_count) const {
float word_novelty_feature;
WordNoveltyFeature(&word_novelty_feature, word_count);
*data = PodQuantize(word_novelty_feature, 127.0f, 127);
}
bool DocSizeFeatureEnabled() const { return (doc_size_levels_ != 0); }
int BosToken() const { return add_bos_tag_ ? 1 : 0; }
int EosToken() const { return add_eos_tag_ ? 1 : 0; }
void DocSizeFeature(float* data, int num_tokens) {
float doc_size_feature =
(doc_size_levels_ != 0)
? std::log2(static_cast<float>(num_tokens)) / doc_size_levels_
: 0.0f;
*data = std::min(doc_size_feature, 1.0f) * 2.0f - 1.0f;
}
void DocSizeFeature(uint8_t* data, int num_tokens) {
float doc_size_feature;
DocSizeFeature(&doc_size_feature, num_tokens);
*data = PodQuantize(doc_size_feature, 127.0f, 127);
}
void Hash(const std::string& word, std::vector<uint64_t>* hash_codes) {
hasher_.GetHashCodes(word, hash_codes);
}
// Lowercases the input text and eliminates all unsupported unicodes from it
// if a vocabulary is provided.
std::string LowerCaseUTF8WithSupportedUnicodes(
std::pair<const char*, size_t> source) const {
return unicode_handler_.LowerCaseUTF8WithSupportedUnicodes(source);
}
// Splits the input text into a set of tokens. Uses space as the delimiter
// when split_on_space is true and unicode boundaries as the delimiter
// otherwise. When max_splits is set to -1, no limit on the number of tokens
// is imposed; when it is set to a positive integer, the token list is
// truncated at that integer. An end of input token is always added after
// tokenization, hence the number of tokens is one more than the true number
// of tokens.
virtual TfLiteStatus PreprocessInput(TfLiteTensor* input_t,
TfLiteContext* context) {
if (input_t->bytes == 0) {
context->ReportError(context, "Empty input not supported.");
return kTfLiteError;
}
tflite::StringRef inputref = tflite::GetString(input_t, /*string_index=*/0);
if (projection_normalizer_ == nullptr) {
tokens_ = unicode_handler_.Tokenize(inputref.str, inputref.len,
split_on_space_, max_splits_);
} else {
normalized_input_ = projection_normalizer_->Normalize(
inputref.str, inputref.len, SIZE_MAX);
tokens_ = unicode_handler_.Tokenize(normalized_input_, split_on_space_,
max_splits_);
}
if (GetNumTokens() == 0 && !add_bos_tag_ && !add_eos_tag_) {
context->ReportError(context, "No tokens found.");
return kTfLiteError;
}
return kTfLiteOk;
}
int GetNumTokens() const { return tokens_.size(); }
const std::vector<std::pair<const char*, size_t>>& GetTokens() const {
return tokens_;
}
virtual std::string PreprocessToken(const std::string& word) { return word; }
private:
int feature_size_;
ProjectionUnicodeHandler unicode_handler_;
Hasher hasher_;
int max_splits_;
bool split_on_space_;
int word_novelty_bits_;
int doc_size_levels_;
bool add_bos_tag_;
bool add_eos_tag_;
float word_novelty_offset_;
std::string normalized_input_;
protected:
std::unique_ptr<ProjectionNormalizer> projection_normalizer_;
std::vector<std::pair<const char*, size_t>> tokens_;
};
class ProjectionParamsV2 : public ProjectionParams {
public:
ProjectionParamsV2(int feature_size, const std::string& vocabulary,
BosTag add_bos_tag, EosTag add_eos_tag,
bool normalize_repetition)
: ProjectionParams(feature_size, vocabulary, /*max_splits = */ -1,
/* split_on_space = */ true,
/*word_novelty_bits = */ 0, /*doc_size_levels = */ 0,
add_bos_tag, add_eos_tag,
/*exclude_nonalphaspace_unicodes = */ false,
/*token_separators = */ "", normalize_repetition) {}
~ProjectionParamsV2() override {}
TfLiteStatus PreprocessInput(TfLiteTensor* input_t,
TfLiteContext* context) override {
const TfLiteIntArray* const dims = input_t->dims;
const int num_tokens = tflite::GetStringCount(input_t);
if (num_tokens == 0) {
context->ReportError(context, "Empty input not supported.");
return kTfLiteError;
}
if (dims->size != 2) {
context->ReportError(
context, "Input tensor is expected to be rank 2, got rank %d.",
dims->size);
return kTfLiteError;
} else if (dims->data[0] != 1) {
context->ReportError(context,
"Input tensor batch size should be 1, got %d.",
dims->data[0]);
return kTfLiteError;
} else if (num_tokens != dims->data[1]) {
context->ReportError(context,
"Inconsistent number of input tokens %d != %d.",
num_tokens, dims->data[1]);
return kTfLiteError;
}
for (int i = 0; i < num_tokens; ++i) {
const tflite::StringRef strref = tflite::GetString(input_t, i);
tokens_.push_back(std::pair<const char*, size_t>(strref.str, strref.len));
}
return kTfLiteOk;
}
std::string PreprocessToken(const std::string& word) override {
return projection_normalizer_ ? projection_normalizer_->Normalize(
word.data(), word.length(), SIZE_MAX)
: word;
}
};
inline void SetTensorToDynamic(TfLiteTensor* tensor) {
if (tensor->allocation_type != kTfLiteDynamic) {
tensor->allocation_type = kTfLiteDynamic;
tensor->data.raw = nullptr;
}
}
// Determines whether tensor is dynamic. Note that a tensor can be non-const and
// not dynamic. This function specifically checks for a dynamic tensor.
inline bool IsDynamicTensor(const TfLiteTensor* tensor) {
return tensor->allocation_type == kTfLiteDynamic;
}
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
const int word_novelty_bits =
m["word_novelty_bits"].IsNull() ? 0 : m["word_novelty_bits"].AsInt32();
const int doc_size_levels =
m["doc_size_levels"].IsNull() ? 0 : m["doc_size_levels"].AsInt32();
const bool add_bos_tag =
m["add_bos_tag"].IsNull() ? false : m["add_bos_tag"].AsBool();
const bool add_eos_tag =
m["add_eos_tag"].IsNull() ? true : m["add_eos_tag"].AsBool();
// Old models that use the op may not have this attribute set; for those
// models the default value of false is used.
const bool exclude_nonalphaspace_unicodes =
m["exclude_nonalphaspace_unicodes"].IsNull()
? false
: m["exclude_nonalphaspace_unicodes"].AsBool();
const std::string token_separators =
m["token_separators"].IsNull() ? "" : m["token_separators"].ToString();
const bool normalize_repetition = m["normalize_repetition"].AsBool();
return new ProjectionParams(
m["feature_size"].AsInt32(), m["vocabulary"].AsString().str(),
m["max_splits"].AsInt32(), m["split_on_space"].AsBool(),
word_novelty_bits, doc_size_levels,
add_bos_tag ? BosTag::kGenerate : BosTag::kNone,
add_eos_tag ? EosTag::kGenerate : EosTag::kNone,
exclude_nonalphaspace_unicodes, token_separators, normalize_repetition);
}
void* InitV2(TfLiteContext* context, const char* buffer, size_t length) {
const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
return new ProjectionParamsV2(
m["feature_size"].AsInt32(), m["vocabulary"].AsString().str(),
m["add_bos_tag"].AsBool() ? BosTag::kGenerate : BosTag::kNone,
m["add_eos_tag"].AsBool() ? EosTag::kGenerate : EosTag::kNone,
m["normalize_repetition"].AsBool());
}
void Free(TfLiteContext* context, void* buffer) {
delete reinterpret_cast<ProjectionParams*>(buffer);
}
TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputLabel]];
SetTensorToDynamic(output);
return kTfLiteOk;
}
constexpr int kHashCodeBits = 64;
constexpr int kMapBits = 2;
constexpr int kIncrement = kHashCodeBits / kMapBits;
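// Each 64-bit hash code is consumed kMapBits (2) bits at a time, giving
// kIncrement (32) mapping-table lookups per code, so a feature size of F needs
// ceil(F / 32) hash codes. The two-bit values 0..3 index a table realizing the
// ternary outputs {0, +1, -1, 0} in float, or their uint8-quantized
// equivalents {127, 255, 0, 127}.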
template <typename T>
void TypedEval(const T* mapping_table, ProjectionParams* params, T* data) {
auto tokens = params->GetTokens();
std::vector<uint64_t> hash_codes;
std::unordered_map<uint64_t, int> word_counter;
T doc_size_feature = T{0};
if (params->DocSizeFeatureEnabled()) {
params->DocSizeFeature(&doc_size_feature, tokens.size());
}
const int num_tokens = tokens.size() + params->EosToken();
for (int j = -params->BosToken(), offset0 = 0; j < num_tokens; ++j) {
std::string word;
if (j < 0) {
word = kBeginToken;
} else if (j < tokens.size()) {
word = params->LowerCaseUTF8WithSupportedUnicodes(tokens[j]);
word = params->PreprocessToken(word);
} else {
word = kEndToken;
}
params->Hash(word, &hash_codes);
for (int hindex = 0, k = 0; hindex < hash_codes.size(); hindex++) {
auto hash = hash_codes[hindex];
for (int kmax = std::min(k + kIncrement, params->FeatureSize());
k < kmax;) {
data[offset0 + k++] = mapping_table[hash & ((1 << kMapBits) - 1)];
hash >>= kMapBits;
}
}
offset0 += params->FeatureSize();
if (params->WordNoveltyEnabled() && !hash_codes.empty()) {
params->WordNoveltyFeature(&data[offset0 - 1],
word_counter[hash_codes[0]]++);
}
if (params->DocSizeFeatureEnabled()) {
data[offset0 - 2] = doc_size_feature;
}
}
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<ProjectionParams*>(node->user_data);
TF_LITE_ENSURE_OK(
context,
params->PreprocessInput(
&context->tensors[node->inputs->data[kInputMessage]], context));
TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputLabel]];
if (IsDynamicTensor(output)) {
TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
output_size->data[0] = 1;
output_size->data[1] =
params->BosToken() + params->GetNumTokens() + params->EosToken();
output_size->data[2] = params->FeatureSize();
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, output, output_size));
} else {
context->ReportError(context, "Output must by dynamic.");
return kTfLiteError;
}
if (output->type == kTfLiteUInt8) {
const uint8_t kMappingTable[1 << kMapBits] = {127, 255, 0, 127};
TypedEval(kMappingTable, params, output->data.uint8);
} else if (output->type == kTfLiteFloat32) {
const float kMappingTable[1 << kMapBits] = {0.0, 1.0, -1.0, 0.0};
TypedEval(kMappingTable, params, output->data.f);
} else {
context->ReportError(context, "Output type must be UInt8 or Float32.");
return kTfLiteError;
}
return kTfLiteOk;
}
} // namespace
} // namespace sequence_string_projection
const char kSequenceStringProjection[] = "SEQUENCE_STRING_PROJECTION";
// This op converts a list of strings to a sequence of features using hashing.
TfLiteRegistration* Register_SEQUENCE_STRING_PROJECTION() {
static TfLiteRegistration r = {
sequence_string_projection::Init, sequence_string_projection::Free,
sequence_string_projection::Resize, sequence_string_projection::Eval};
return &r;
}
const char kSequenceStringProjectionV2[] = "SEQUENCE_STRING_PROJECTION_V2";
// This op converts a sequence of tokens to a sequence of projected features
// using hashing.
TfLiteRegistration* Register_SEQUENCE_STRING_PROJECTION_V2() {
static TfLiteRegistration r = {
sequence_string_projection::InitV2, sequence_string_projection::Free,
sequence_string_projection::Resize, sequence_string_projection::Eval};
return &r;
}
} // namespace custom
} // namespace ops
} // namespace tflite
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
#define TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
#include "tensorflow/lite/kernels/register.h"
namespace tflite {
namespace ops {
namespace custom {
extern const char kSequenceStringProjection[];
TfLiteRegistration* Register_SEQUENCE_STRING_PROJECTION();
extern const char kSequenceStringProjectionV2[];
TfLiteRegistration* Register_SEQUENCE_STRING_PROJECTION_V2();
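// A minimal usage sketch (the resolver variable is illustrative): add the ops
// to a tflite::MutableOpResolver before building the interpreter, e.g.
//   resolver.AddCustom(kSequenceStringProjection,
//                      Register_SEQUENCE_STRING_PROJECTION());
//   resolver.AddCustom(kSequenceStringProjectionV2,
//                      Register_SEQUENCE_STRING_PROJECTION_V2());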
} // namespace custom
} // namespace ops
} // namespace tflite
#endif // TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tflite_ops/sequence_string_projection.h" // sequence_projection
#include <vector>
#include "tflite_ops/tf_tflite_diff_test_util.h" // sequence_projection
#include "flatbuffers/flexbuffers.h" // flatbuffer
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/string_util.h"
namespace tflite {
namespace ops {
namespace custom {
namespace {
using ::testing::ElementsAreArray;
using ::tflite::testing::AttrValue;
using ::tflite::testing::FloatTensor;
using ::tflite::testing::IntTensor;
using ::tflite::testing::OpEquivTestCase;
using ::tflite::testing::StringTensor;
using ::tflite::testing::TensorflowTfLiteOpTest;
class SequenceStringProjectionModel : public SingleOpModel {
public:
explicit SequenceStringProjectionModel(
bool split_on_space, int max_splits, int word_novelty_bits,
int doc_size_levels, bool add_eos_tag, TensorType output_type,
const std::string& token_separators = "",
bool normalize_repetition = false) {
flexbuffers::Builder fbb;
fbb.Map([&] {
fbb.Int("feature_size", 4);
fbb.String("vocabulary", "abcdefghijklmnopqrstuvwxyz");
fbb.Int("word_novelty_bits", word_novelty_bits);
fbb.Int("doc_size_levels", doc_size_levels);
fbb.Int("max_splits", max_splits);
fbb.Bool("split_on_space", split_on_space);
fbb.Bool("add_eos_tag", add_eos_tag);
fbb.String("token_separators", token_separators);
fbb.Bool("normalize_repetition", normalize_repetition);
});
fbb.Finish();
output_ = AddOutput({output_type, {}});
SetCustomOp(kSequenceStringProjection, fbb.GetBuffer(),
Register_SEQUENCE_STRING_PROJECTION);
BuildInterpreter({GetShape(input_)});
}
void Invoke(const std::string& input) {
PopulateStringTensor(input_, {input});
CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
<< "Cannot allocate tensors";
SingleOpModel::Invoke();
}
TfLiteStatus InvokeFailable(const std::string& input) {
PopulateStringTensor(input_, {input});
CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
<< "Cannot allocate tensors";
return interpreter_->Invoke();
}
template <typename T>
std::vector<T> GetOutput() {
return ExtractVector<T>(output_);
}
void CheckOutputTensorShape(const std::vector<int>& expected_shape) {
EXPECT_EQ(GetTensorShape(output_), expected_shape);
}
private:
int input_ = AddInput(TensorType_STRING);
int output_;
};
TEST(SequenceStringProjectionTest, RegularInputUint8) {
std::vector<std::pair<std::string, std::vector<uint8_t>>> testcase = {
{"hello", {127, 255, 255, 127, 127, 255, 127, 127}},
{"world", {127, 255, 127, 127, 127, 255, 127, 127}},
};
for (const auto& test : testcase) {
SequenceStringProjectionModel m(true, -1, 0, 0, true, TensorType_UINT8);
m.Invoke(test.first);
EXPECT_THAT(m.GetOutput<uint8_t>(), ElementsAreArray(test.second));
}
}
TEST(SequenceStringProjectionTest, RegularInputUint8NoEOSTag) {
std::vector<std::pair<std::string, std::vector<uint8_t>>> testcase = {
{"hello", {127, 255, 255, 127}},
{"world", {127, 255, 127, 127}},
};
for (const auto& test : testcase) {
SequenceStringProjectionModel m(true, -1, 0, 0, false, TensorType_UINT8);
m.Invoke(test.first);
EXPECT_THAT(m.GetOutput<uint8_t>(), ElementsAreArray(test.second));
}
}
TEST(SequenceStringProjectionTest, RegularInputUint8DocSize) {
std::vector<std::pair<std::string, std::vector<uint8_t>>> testcase = {
{"hello", {127, 255, 0, 127, 127, 255, 0, 127}},
{"world", {127, 255, 0, 127, 127, 255, 0, 127}},
};
for (const auto& test : testcase) {
SequenceStringProjectionModel m(true, -1, 0, 8, true, TensorType_UINT8);
m.Invoke(test.first);
EXPECT_THAT(m.GetOutput<uint8_t>(), ElementsAreArray(test.second));
}
}
TEST(SequenceStringProjectionTest, RegularInputUint8DocSizeWordNovelty) {
std::vector<std::pair<std::string, std::vector<uint8_t>>> testcase = {
{"hello", {127, 255, 0, 0, 127, 255, 0, 0}},
{"world", {127, 255, 0, 0, 127, 255, 0, 0}},
};
for (const auto& test : testcase) {
SequenceStringProjectionModel m(true, -1, 4, 8, true, TensorType_UINT8);
m.Invoke(test.first);
EXPECT_THAT(m.GetOutput<uint8_t>(), ElementsAreArray(test.second));
}
}
TEST(SequenceStringProjectionTest, RegularInputUint8WordNovelty) {
std::vector<std::pair<std::string, std::vector<uint8_t>>> testcase = {
{"hello", {127, 255, 255, 0, 127, 255, 127, 0}},
{"world", {127, 255, 127, 0, 127, 255, 127, 0}},
};
for (const auto& test : testcase) {
SequenceStringProjectionModel m(true, -1, 3, 0, true, TensorType_UINT8);
m.Invoke(test.first);
EXPECT_THAT(m.GetOutput<uint8_t>(), ElementsAreArray(test.second));
}
}
TEST(SequenceStringProjectionTest, RegularInputFloat) {
std::vector<std::pair<std::string, std::vector<float>>> testcase = {
{"hello", {0, 1, 1, 0, 0, 1, 0, 0}},
{"world", {0, 1, 0, 0, 0, 1, 0, 0}},
};
for (const auto& test : testcase) {
SequenceStringProjectionModel m(true, -1, 0, 0, true, TensorType_FLOAT32);
m.Invoke(test.first);
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(test.second));
}
}
TEST(SequenceStringProjectionTest, RegularInputFloatNoEOSTag) {
std::vector<std::pair<std::string, std::vector<float>>> testcase = {
{"hello", {0, 1, 1, 0}},
{"world", {0, 1, 0, 0}},
};
for (const auto& test : testcase) {
SequenceStringProjectionModel m(true, -1, 0, 0, false, TensorType_FLOAT32);
m.Invoke(test.first);
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(test.second));
}
}
TEST(SequenceStringProjectionTest, RegularInputWithoutSplitOnSpace) {
std::vector<std::pair<std::string, std::vector<uint8_t>>> testcase = {
{"h", {127, 127, 255, 127, 127, 255, 127, 127}},
{"w", {255, 127, 255, 127, 127, 255, 127, 127}},
};
for (const auto& test : testcase) {
SequenceStringProjectionModel m(false, -1, 0, 0, true, TensorType_UINT8);
m.Invoke(test.first);
EXPECT_THAT(m.GetOutput<uint8_t>(), ElementsAreArray(test.second));
}
}
TEST(SequenceStringProjectionTest, CheckSequenceLimit) {
std::string input;
for (int i = 0; i < 600; ++i) {
input += "hello world ";
}
SequenceStringProjectionModel m(true, 511, 0, 0, true, TensorType_UINT8);
m.Invoke(input);
const std::vector<int> expected_shape = {1, 512, 4};
m.CheckOutputTensorShape(expected_shape);
}
TEST(SequenceStringProjectionTest, CheckSequenceLimitBoundary) {
std::vector<std::pair<std::string, std::vector<int>>> testcase = {
{"hello", {1, 2, 4}},
{"hello ", {1, 2, 4}},
{"hello world", {1, 3, 4}},
{"hellow world ", {1, 3, 4}},
};
for (const auto& test : testcase) {
SequenceStringProjectionModel m(true, 2, 0, 0, true, TensorType_FLOAT32);
m.Invoke(test.first);
m.CheckOutputTensorShape(test.second);
}
}
TEST(SequenceStringProjectionTest, CheckSequenceLimitBoundaryWithoutSpace) {
std::vector<std::pair<std::string, std::vector<int>>> testcase = {
{"h", {1, 2, 4}},
{"he", {1, 3, 4}},
{"hel", {1, 3, 4}},
{"hello ", {1, 3, 4}},
};
for (const auto& test : testcase) {
SequenceStringProjectionModel m(false, 2, 0, 0, true, TensorType_UINT8);
m.Invoke(test.first);
m.CheckOutputTensorShape(test.second);
}
}
TEST(SequenceStringProjectionTest,
CheckSequenceLimitBoundaryWithoutSpaceNoEOS) {
std::vector<std::pair<std::string, std::vector<int>>> testcase = {
{"h", {1, 1, 4}},
{"he", {1, 2, 4}},
{"hel", {1, 2, 4}},
{"hello ", {1, 2, 4}},
};
for (const auto& test : testcase) {
SequenceStringProjectionModel m(false, 2, 0, 0, false, TensorType_UINT8);
m.Invoke(test.first);
m.CheckOutputTensorShape(test.second);
}
}
TEST(SequenceStringProjectionTest, TokenSeparators) {
// Separate the input using "!".
SequenceStringProjectionModel m1(true, -1, 0, 0, true, TensorType_UINT8, "!",
false);
m1.Invoke("great!!!");
auto output1 = m1.GetOutput<uint8_t>();
SequenceStringProjectionModel m2(true, -1, 0, 0, true, TensorType_UINT8, "!",
false);
m2.Invoke("great ! ! !");
auto output2 = m2.GetOutput<uint8_t>();
EXPECT_THAT(output1, ElementsAreArray(output2));
}
TEST(SequenceStringProjectionTest, EmptyInput) {
// Tokenize the input using space as the separator.
SequenceStringProjectionModel no_eos(true, -1, 0, 0, false, TensorType_UINT8,
" ", false);
EXPECT_EQ(no_eos.InvokeFailable(" "), kTfLiteError);
EXPECT_EQ(no_eos.InvokeFailable(" "), kTfLiteError);
EXPECT_EQ(no_eos.InvokeFailable(""), kTfLiteError);
EXPECT_EQ(no_eos.InvokeFailable("hello"), kTfLiteOk);
SequenceStringProjectionModel with_eos(true, -1, 0, 0, true, TensorType_UINT8,
" ", false);
EXPECT_EQ(with_eos.InvokeFailable(" "), kTfLiteOk);
EXPECT_EQ(with_eos.InvokeFailable(" "), kTfLiteOk);
EXPECT_EQ(with_eos.InvokeFailable(""), kTfLiteOk);
EXPECT_EQ(with_eos.InvokeFailable("hello"), kTfLiteOk);
}
TEST(SequenceStringProjectionTest, NormalizeRepetition) {
// Normalizes repeated special tokens. Used for the emotion models.
SequenceStringProjectionModel m1(true, -1, 0, 0, true, TensorType_UINT8, "",
true);
m1.Invoke("hello..");
auto output1 = m1.GetOutput<uint8_t>();
SequenceStringProjectionModel m2(true, -1, 0, 0, true, TensorType_UINT8, "",
true);
m2.Invoke("hello.....");
auto output2 = m2.GetOutput<uint8_t>();
EXPECT_THAT(output1, ElementsAreArray(output2));
}
class SequenceStringProjectionTest : public TensorflowTfLiteOpTest {
std::function<TfLiteRegistration*()> TfLiteOpRegistration() override {
return ops::custom::Register_SEQUENCE_STRING_PROJECTION;
}
std::string TensorflowOpName() override { return "SequenceStringProjection"; }
};
TEST_P(SequenceStringProjectionTest, TensorflowTfLiteSame) {
RunTensorflowOp();
RunTfLiteOp();
CompareOpOutput();
}
std::vector<OpEquivTestCase> SequenceStringProjectionTestCases() {
std::vector<OpEquivTestCase> test_cases;
constexpr float kScale = 2.0 / 255;
constexpr int kZero = 127;
{
OpEquivTestCase test_case;
test_case.test_name = "CheckEqualityNoBoSNoEoS";
test_case.attributes["vocabulary"] = AttrValue("");
test_case.attributes["split_on_space"] = AttrValue(true);
test_case.attributes["feature_size"] = AttrValue(8);
test_case.attributes["add_eos_tag"] = AttrValue(false);
test_case.attributes["add_bos_tag"] = AttrValue(false);
test_case.input_tensors.push_back(
StringTensor({1}, {"Hello World 7153845&^$&^$&"}));
test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
test_cases.push_back(test_case);
}
{
OpEquivTestCase test_case;
test_case.test_name = "CheckEqualityNoBoS";
test_case.attributes["vocabulary"] = AttrValue("");
test_case.attributes["split_on_space"] = AttrValue(true);
test_case.attributes["feature_size"] = AttrValue(8);
test_case.attributes["add_eos_tag"] = AttrValue(true);
test_case.attributes["add_bos_tag"] = AttrValue(false);
test_case.input_tensors.push_back(
StringTensor({1}, {"Hello World 7153845&^$&^$&"}));
test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
test_cases.push_back(test_case);
}
{
OpEquivTestCase test_case;
test_case.test_name = "CheckEqualityNoEoS";
test_case.attributes["vocabulary"] = AttrValue("");
test_case.attributes["split_on_space"] = AttrValue(true);
test_case.attributes["feature_size"] = AttrValue(8);
test_case.attributes["add_eos_tag"] = AttrValue(false);
test_case.attributes["add_bos_tag"] = AttrValue(true);
test_case.input_tensors.push_back(
StringTensor({1}, {"Hello World 7153845&^$&^$&"}));
test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
test_cases.push_back(test_case);
}
{
OpEquivTestCase test_case;
test_case.test_name = "CheckEquality";
test_case.attributes["vocabulary"] = AttrValue("");
test_case.attributes["split_on_space"] = AttrValue(true);
test_case.attributes["feature_size"] = AttrValue(8);
test_case.attributes["add_eos_tag"] = AttrValue(true);
test_case.attributes["add_bos_tag"] = AttrValue(true);
test_case.input_tensors.push_back(
StringTensor({1}, {"Hello World 7153845&^$&^$&"}));
test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
test_cases.push_back(test_case);
}
{
OpEquivTestCase test_case;
test_case.test_name = "SplitOnSpace";
test_case.attributes["vocabulary"] =
AttrValue("abcdefghijklmnopqrstuvwxyz");
test_case.attributes["split_on_space"] = AttrValue(false);
test_case.attributes["max_splits"] = AttrValue(-1);
test_case.attributes["word_novelty_bits"] = AttrValue(0);
test_case.attributes["doc_size_levels"] = AttrValue(0);
test_case.attributes["feature_size"] = AttrValue(8);
test_case.attributes["exclude_nonalphaspace_unicodes"] = AttrValue(false);
test_case.input_tensors.push_back(
StringTensor({1}, {"Hello World hello world"}));
test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
test_cases.push_back(test_case);
}
{
OpEquivTestCase test_case;
test_case.test_name = "NoSplitOnSpace";
test_case.attributes["vocabulary"] =
AttrValue("abcdefghijklmnopqrstuvwxyz");
test_case.attributes["split_on_space"] = AttrValue(true);
test_case.attributes["max_splits"] = AttrValue(-1);
test_case.attributes["word_novelty_bits"] = AttrValue(0);
test_case.attributes["doc_size_levels"] = AttrValue(0);
test_case.attributes["feature_size"] = AttrValue(8);
test_case.attributes["exclude_nonalphaspace_unicodes"] = AttrValue(false);
test_case.input_tensors.push_back(
StringTensor({1}, {"Hello World hello world"}));
test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
test_cases.push_back(test_case);
}
{
OpEquivTestCase test_case;
test_case.test_name = "SplitOnSpaceWithMax";
test_case.attributes["vocabulary"] =
AttrValue("abcdefghijklmnopqrstuvwxyz");
test_case.attributes["split_on_space"] = AttrValue(true);
test_case.attributes["max_splits"] = AttrValue(2);
test_case.attributes["word_novelty_bits"] = AttrValue(0);
test_case.attributes["doc_size_levels"] = AttrValue(0);
test_case.attributes["feature_size"] = AttrValue(8);
test_case.attributes["exclude_nonalphaspace_unicodes"] = AttrValue(false);
test_case.input_tensors.push_back(
StringTensor({1}, {"Hello World hello world"}));
test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
test_cases.push_back(test_case);
}
{
OpEquivTestCase test_case;
test_case.test_name = "NoSplitOnSpaceWithMax";
test_case.attributes["vocabulary"] =
AttrValue("abcdefghijklmnopqrstuvwxyz");
test_case.attributes["split_on_space"] = AttrValue(false);
test_case.attributes["max_splits"] = AttrValue(4);
test_case.attributes["word_novelty_bits"] = AttrValue(0);
test_case.attributes["doc_size_levels"] = AttrValue(0);
test_case.attributes["feature_size"] = AttrValue(8);
test_case.attributes["exclude_nonalphaspace_unicodes"] = AttrValue(false);
test_case.input_tensors.push_back(StringTensor({1}, {"Hello World"}));
test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
test_cases.push_back(test_case);
}
{
OpEquivTestCase test_case;
test_case.test_name = "NoSplitOnSpaceWithDocSize";
test_case.attributes["vocabulary"] =
AttrValue("abcdefghijklmnopqrstuvwxyz");
test_case.attributes["split_on_space"] = AttrValue(false);
test_case.attributes["max_splits"] = AttrValue(-1);
test_case.attributes["word_novelty_bits"] = AttrValue(0);
test_case.attributes["doc_size_levels"] = AttrValue(6);
test_case.attributes["feature_size"] = AttrValue(8);
test_case.attributes["exclude_nonalphaspace_unicodes"] = AttrValue(false);
test_case.input_tensors.push_back(
StringTensor({1}, {"Hello World hello world"}));
test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
test_cases.push_back(test_case);
}
{
OpEquivTestCase test_case;
test_case.test_name = "SplitOnSpaceWithDocSize";
test_case.attributes["vocabulary"] =
AttrValue("abcdefghijklmnopqrstuvwxyz");
test_case.attributes["split_on_space"] = AttrValue(true);
test_case.attributes["max_splits"] = AttrValue(-1);
test_case.attributes["word_novelty_bits"] = AttrValue(0);
test_case.attributes["doc_size_levels"] = AttrValue(7);
test_case.attributes["feature_size"] = AttrValue(8);
test_case.attributes["exclude_nonalphaspace_unicodes"] = AttrValue(false);
test_case.input_tensors.push_back(
StringTensor({1}, {"Hello World hello world"}));
test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
test_cases.push_back(test_case);
}
{
OpEquivTestCase test_case;
test_case.test_name = "SplitOnSpaceWithMaxSplitsAndDocSize";
test_case.attributes["vocabulary"] =
AttrValue("abcdefghijklmnopqrstuvwxyz");
test_case.attributes["split_on_space"] = AttrValue(true);
test_case.attributes["max_splits"] = AttrValue(2);
test_case.attributes["word_novelty_bits"] = AttrValue(0);
test_case.attributes["doc_size_levels"] = AttrValue(8);
test_case.attributes["feature_size"] = AttrValue(8);
test_case.attributes["exclude_nonalphaspace_unicodes"] = AttrValue(false);
test_case.input_tensors.push_back(
StringTensor({1}, {"Hello World hello world"}));
test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
test_cases.push_back(test_case);
}
{
OpEquivTestCase test_case;
test_case.test_name = "NoSplitOnSpaceWithMaxSplitsAndDocSize";
test_case.attributes["vocabulary"] =
AttrValue("abcdefghijklmnopqrstuvwxyz");
test_case.attributes["split_on_space"] = AttrValue(false);
test_case.attributes["max_splits"] = AttrValue(4);
test_case.attributes["word_novelty_bits"] = AttrValue(0);
test_case.attributes["doc_size_levels"] = AttrValue(4);
test_case.attributes["feature_size"] = AttrValue(8);
test_case.attributes["exclude_nonalphaspace_unicodes"] = AttrValue(false);
test_case.input_tensors.push_back(
StringTensor({1}, {"Hello World hello world"}));
test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
test_cases.push_back(test_case);
}
{
OpEquivTestCase test_case;
test_case.test_name = "NoSplitOnSpaceWithWordNovelty";
test_case.attributes["vocabulary"] =
AttrValue("abcdefghijklmnopqrstuvwxyz");
test_case.attributes["split_on_space"] = AttrValue(false);
test_case.attributes["max_splits"] = AttrValue(-1);
test_case.attributes["word_novelty_bits"] = AttrValue(2);
test_case.attributes["doc_size_levels"] = AttrValue(0);
test_case.attributes["feature_size"] = AttrValue(8);
test_case.attributes["exclude_nonalphaspace_unicodes"] = AttrValue(false);
test_case.input_tensors.push_back(
StringTensor({1}, {"Hello World hello world"}));
test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
test_cases.push_back(test_case);
}
{
OpEquivTestCase test_case;
test_case.test_name = "SplitOnSpaceWithWordNovelty";
test_case.attributes["vocabulary"] =
AttrValue("abcdefghijklmnopqrstuvwxyz");
test_case.attributes["split_on_space"] = AttrValue(true);
test_case.attributes["max_splits"] = AttrValue(-1);
test_case.attributes["word_novelty_bits"] = AttrValue(3);
test_case.attributes["doc_size_levels"] = AttrValue(0);
test_case.attributes["feature_size"] = AttrValue(8);
test_case.attributes["exclude_nonalphaspace_unicodes"] = AttrValue(false);
test_case.input_tensors.push_back(
StringTensor({1}, {"Hello World hello world"}));
test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
test_cases.push_back(test_case);
}
{
OpEquivTestCase test_case;
test_case.test_name = "SplitOnSpaceWithMaxSplitsAndWordNovelty";
test_case.attributes["vocabulary"] =
AttrValue("abcdefghijklmnopqrstuvwxyz");
test_case.attributes["split_on_space"] = AttrValue(true);
test_case.attributes["max_splits"] = AttrValue(2);
test_case.attributes["word_novelty_bits"] = AttrValue(4);
test_case.attributes["doc_size_levels"] = AttrValue(0);
test_case.attributes["feature_size"] = AttrValue(8);
test_case.attributes["exclude_nonalphaspace_unicodes"] = AttrValue(false);
test_case.input_tensors.push_back(
StringTensor({1}, {"Hello World hello world"}));
test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
test_cases.push_back(test_case);
}
{
OpEquivTestCase test_case;
test_case.test_name = "NoSplitOnSpaceWithMaxSplitsAndWordNovelty";
test_case.attributes["vocabulary"] =
AttrValue("abcdefghijklmnopqrstuvwxyz");
test_case.attributes["split_on_space"] = AttrValue(false);
test_case.attributes["max_splits"] = AttrValue(4);
test_case.attributes["word_novelty_bits"] = AttrValue(5);
test_case.attributes["doc_size_levels"] = AttrValue(0);
test_case.attributes["feature_size"] = AttrValue(8);
test_case.attributes["exclude_nonalphaspace_unicodes"] = AttrValue(false);
test_case.input_tensors.push_back(StringTensor({1}, {"Hello World"}));
test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
test_cases.push_back(test_case);
}
{
OpEquivTestCase test_case;
test_case.test_name = "NoSplitOnSpaceWithWordNoveltyAndDocSize";
test_case.attributes["vocabulary"] =
AttrValue("abcdefghijklmnopqrstuvwxyz");
test_case.attributes["split_on_space"] = AttrValue(false);
test_case.attributes["max_splits"] = AttrValue(-1);
test_case.attributes["word_novelty_bits"] = AttrValue(2);
test_case.attributes["doc_size_levels"] = AttrValue(8);
test_case.attributes["feature_size"] = AttrValue(8);
test_case.attributes["exclude_nonalphaspace_unicodes"] = AttrValue(false);
test_case.input_tensors.push_back(
StringTensor({1}, {"Hello World hello world"}));
test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
test_cases.push_back(test_case);
}
{
OpEquivTestCase test_case;
test_case.test_name = "SplitOnSpaceWithWordNoveltyAndDocSize";
test_case.attributes["vocabulary"] =
AttrValue("abcdefghijklmnopqrstuvwxyz");
test_case.attributes["split_on_space"] = AttrValue(true);
test_case.attributes["max_splits"] = AttrValue(-1);
test_case.attributes["word_novelty_bits"] = AttrValue(3);
test_case.attributes["doc_size_levels"] = AttrValue(6);
test_case.attributes["feature_size"] = AttrValue(8);
test_case.attributes["exclude_nonalphaspace_unicodes"] = AttrValue(false);
test_case.input_tensors.push_back(
StringTensor({1}, {"Hello World hello world"}));
test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
test_cases.push_back(test_case);
}
{
OpEquivTestCase test_case;
test_case.test_name = "SplitOnSpaceWithEverything";
test_case.attributes["vocabulary"] =
AttrValue("abcdefghijklmnopqrstuvwxyz");
test_case.attributes["split_on_space"] = AttrValue(true);
test_case.attributes["max_splits"] = AttrValue(2);
test_case.attributes["word_novelty_bits"] = AttrValue(5);
test_case.attributes["doc_size_levels"] = AttrValue(8);
test_case.attributes["feature_size"] = AttrValue(8);
test_case.attributes["exclude_nonalphaspace_unicodes"] = AttrValue(false);
test_case.input_tensors.push_back(
StringTensor({1}, {"Hello World hello world"}));
test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
test_cases.push_back(test_case);
}
{
OpEquivTestCase test_case;
test_case.test_name = "NoSplitOnSpaceWithEverything";
test_case.attributes["vocabulary"] =
AttrValue("abcdefghijklmnopqrstuvwxyz");
test_case.attributes["split_on_space"] = AttrValue(false);
test_case.attributes["max_splits"] = AttrValue(4);
test_case.attributes["word_novelty_bits"] = AttrValue(7);
test_case.attributes["doc_size_levels"] = AttrValue(9);
test_case.attributes["feature_size"] = AttrValue(8);
test_case.attributes["exclude_nonalphaspace_unicodes"] = AttrValue(false);
test_case.input_tensors.push_back(StringTensor({1}, {"Hello World"}));
test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
test_cases.push_back(test_case);
}
{
OpEquivTestCase test_case;
test_case.test_name = "SplitOnSpaceWithEverythingAndExclude";
test_case.attributes["vocabulary"] = AttrValue("");
test_case.attributes["split_on_space"] = AttrValue(true);
test_case.attributes["max_splits"] = AttrValue(2);
test_case.attributes["word_novelty_bits"] = AttrValue(5);
test_case.attributes["doc_size_levels"] = AttrValue(8);
test_case.attributes["feature_size"] = AttrValue(8);
test_case.attributes["exclude_nonalphaspace_unicodes"] = AttrValue(true);
test_case.input_tensors.push_back(
StringTensor({1}, {"Hello World 7153845&^$&^$&"}));
test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
test_cases.push_back(test_case);
}
{
OpEquivTestCase test_case;
test_case.test_name = "NoSplitOnSpaceWithEverythingAndExclude";
test_case.attributes["vocabulary"] = AttrValue("");
test_case.attributes["split_on_space"] = AttrValue(false);
test_case.attributes["max_splits"] = AttrValue(2);
test_case.attributes["word_novelty_bits"] = AttrValue(5);
test_case.attributes["doc_size_levels"] = AttrValue(8);
test_case.attributes["feature_size"] = AttrValue(8);
test_case.attributes["exclude_nonalphaspace_unicodes"] = AttrValue(true);
test_case.input_tensors.push_back(
StringTensor({1}, {"Hello World 7153845&^$&^$&"}));
test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
test_cases.push_back(test_case);
}
{
OpEquivTestCase test_case;
test_case.test_name = "NormalizeRepetition";
test_case.attributes["vocabulary"] = AttrValue("");
test_case.attributes["split_on_space"] = AttrValue(true);
test_case.attributes["feature_size"] = AttrValue(8);
test_case.attributes["add_eos_tag"] = AttrValue(false);
test_case.attributes["add_bos_tag"] = AttrValue(false);
test_case.attributes["normalize_repetition"] = AttrValue(true);
test_case.input_tensors.push_back(StringTensor({1}, {"Hello World ..."}));
test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
test_cases.push_back(test_case);
}
{
OpEquivTestCase test_case;
test_case.test_name = "TokenSeparator";
test_case.attributes["vocabulary"] = AttrValue("");
test_case.attributes["split_on_space"] = AttrValue(true);
test_case.attributes["feature_size"] = AttrValue(8);
test_case.attributes["add_eos_tag"] = AttrValue(false);
test_case.attributes["add_bos_tag"] = AttrValue(false);
test_case.attributes["token_separators"] = AttrValue("-");
test_case.input_tensors.push_back(StringTensor({1}, {"Hello-World"}));
test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
test_cases.push_back(test_case);
}
return test_cases;
}
INSTANTIATE_TEST_SUITE_P(
SequenceStringProjectionTests, SequenceStringProjectionTest,
::testing::ValuesIn(SequenceStringProjectionTestCases()));
class SequenceStringProjectionV2Model : public SingleOpModel {
public:
explicit SequenceStringProjectionV2Model(
std::vector<std::vector<int>> input_shapes) {
flexbuffers::Builder fbb;
fbb.Map([&] { fbb.Int("feature_size", 4); });
fbb.Finish();
input_ = AddInput(TensorType_STRING);
output_ = AddOutput({TensorType_UINT8, {}});
SetCustomOp(kSequenceStringProjectionV2, fbb.GetBuffer(),
Register_SEQUENCE_STRING_PROJECTION_V2);
BuildInterpreter(input_shapes);
}
void Invoke(const std::vector<std::string>& input, TfLiteStatus expected) {
PopulateStringTensor(input_, input);
CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
<< "Cannot allocate tensors";
ASSERT_EQ(interpreter_->Invoke(), expected);
}
private:
int input_;
int output_;
};
TEST(SequenceStringProjectionV2Test, RegularInputUint8EmptyNotSupported) {
// The TFLite test infrastructure currently does not allow the error message
// to be extracted on failure. As a result, just the returned error code is
// tested, as in all other TFLite op handler tests. The error message each
// test triggers is captured in a comment, though.
// ERROR: Empty input not supported.
SequenceStringProjectionV2Model m({{1, 0}});
m.Invoke({}, kTfLiteError);
}
TEST(SequenceStringProjectionV2Test, RegularInputUint8BatchNotSupported) {
// The TFLite test infrastructure currently does not allow the error message
// to be extracted on failure. As a result, just the returned error code is
// tested, as in all other TFLite op handler tests. The error message each
// test triggers is captured in a comment, though.
// ERROR: Input tensor batch size should be 1, got 2.
SequenceStringProjectionV2Model m({{2, 1}});
m.Invoke({"hello", "world"}, kTfLiteError);
}
TEST(SequenceStringProjectionV2Test, RegularInputUint8RankNot2NotSupported) {
  // The TFLite test infrastructure currently does not allow the error
  // message to be extracted on failure, so only the returned error code is
  // tested, as in all other TFLite op handler tests. The error message each
  // test triggers is captured in a comment below.
// ERROR: Input tensor is expected to be rank 2, got rank 3.
SequenceStringProjectionV2Model m({{2, 1, 1}});
m.Invoke({"hello", "world"}, kTfLiteError);
}
TEST(SequenceStringProjectionV2Test, RegularInputUint8InconsistentInput) {
  // The TFLite test infrastructure currently does not allow the error
  // message to be extracted on failure, so only the returned error code is
  // tested, as in all other TFLite op handler tests. The error message each
  // test triggers is captured in a comment below.
// ERROR: Inconsistent number of input tokens 3 != 2.
SequenceStringProjectionV2Model m({{1, 2}});
m.Invoke({"hello", "world", "goodbye"}, kTfLiteError);
}
TEST(SequenceStringProjectionV2Test, RegularInputUint8) {
// OK
SequenceStringProjectionV2Model m({{1, 2}});
m.Invoke({"hello", "world"}, kTfLiteOk);
}
class SequenceStringProjectionV2Test : public TensorflowTfLiteOpTest {
std::function<TfLiteRegistration*()> TfLiteOpRegistration() override {
return ops::custom::Register_SEQUENCE_STRING_PROJECTION_V2;
}
std::string TensorflowOpName() override {
return "SequenceStringProjectionV2";
}
};
TEST_P(SequenceStringProjectionV2Test, TensorflowTfLiteSame) {
RunTensorflowOp();
RunTfLiteOp();
CompareOpOutput();
}
std::vector<OpEquivTestCase> SequenceStringProjectionV2TestCases() {
std::vector<OpEquivTestCase> test_cases;
constexpr float kScale = 2.0 / 255;
constexpr int kZero = 127;
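  // These parameters dequantize uint8 outputs as (q - 127) * 2 / 255, i.e. to
  // floats in roughly [-1, 1].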
{
OpEquivTestCase test_case;
test_case.test_name = "CheckEqualityNoBoSNoEoS";
test_case.attributes["vocabulary"] = AttrValue("");
test_case.attributes["feature_size"] = AttrValue(8);
test_case.attributes["add_eos_tag"] = AttrValue(false);
test_case.attributes["add_bos_tag"] = AttrValue(false);
test_case.input_tensors.push_back(
StringTensor({1, 5}, {"Hello", "World", "7153845", "&^$&", "^$&"}));
test_case.input_tensors.push_back(IntTensor({1}, {5}));
test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
test_cases.push_back(test_case);
}
{
OpEquivTestCase test_case;
test_case.test_name = "CheckEqualityNoBoS";
test_case.attributes["vocabulary"] = AttrValue("");
test_case.attributes["feature_size"] = AttrValue(8);
test_case.attributes["add_eos_tag"] = AttrValue(true);
test_case.attributes["add_bos_tag"] = AttrValue(false);
test_case.input_tensors.push_back(
StringTensor({1, 4}, {"Hello", "World", "7153845", "&^$&^$&"}));
test_case.input_tensors.push_back(IntTensor({1}, {4}));
test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
test_cases.push_back(test_case);
}
{
OpEquivTestCase test_case;
test_case.test_name = "CheckEqualityNoEoS";
test_case.attributes["vocabulary"] = AttrValue("");
test_case.attributes["feature_size"] = AttrValue(8);
test_case.attributes["add_eos_tag"] = AttrValue(false);
test_case.attributes["add_bos_tag"] = AttrValue(true);
test_case.input_tensors.push_back(
StringTensor({1, 3}, {"Hello", "World", "7153845&^$&^$&"}));
test_case.input_tensors.push_back(IntTensor({1}, {3}));
test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
test_cases.push_back(test_case);
}
{
OpEquivTestCase test_case;
test_case.test_name = "CheckEquality";
test_case.attributes["vocabulary"] = AttrValue("");
test_case.attributes["feature_size"] = AttrValue(8);
test_case.attributes["add_eos_tag"] = AttrValue(true);
test_case.attributes["add_bos_tag"] = AttrValue(true);
test_case.input_tensors.push_back(
StringTensor({1, 3}, {"Hello", "Worldddd", "7153845&^$&^$&"}));
test_case.input_tensors.push_back(IntTensor({1}, {3}));
test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
test_cases.push_back(test_case);
}
{
OpEquivTestCase test_case;
test_case.test_name = "NormalizeRepetition";
test_case.attributes["vocabulary"] = AttrValue("");
test_case.attributes["feature_size"] = AttrValue(8);
test_case.attributes["add_eos_tag"] = AttrValue(false);
test_case.attributes["add_bos_tag"] = AttrValue(false);
test_case.attributes["normalize_repetition"] = AttrValue(true);
test_case.input_tensors.push_back(
StringTensor({1, 6}, {"Hello", "World", "...", "..", ".", "...."}));
test_case.input_tensors.push_back(IntTensor({1}, {6}));
test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
test_cases.push_back(test_case);
}
return test_cases;
}
INSTANTIATE_TEST_SUITE_P(
SequenceStringProjectionV2Tests, SequenceStringProjectionV2Test,
::testing::ValuesIn(SequenceStringProjectionV2TestCases()));
} // namespace
} // namespace custom
} // namespace ops
} // namespace tflite
int main(int argc, char** argv) {
// On Linux, add: absl::SetFlag(&FLAGS_logtostderr, true);
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tflite_ops/tf_tflite_diff_test_util.h" // sequence_projection
#include "flatbuffers/flexbuffers.h" // flatbuffer
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace tflite {
namespace testing {
using ::tensorflow::TensorProto;
using ::testing::FloatNear;
::tflite::TensorType TfTypeToTfLiteType(::tensorflow::DataType dtype) {
switch (dtype) {
case ::tensorflow::DT_FLOAT:
return TensorType_FLOAT32;
case ::tensorflow::DT_INT32:
return TensorType_INT32;
case ::tensorflow::DT_STRING:
return TensorType_STRING;
case ::tensorflow::DT_BOOL:
return TensorType_BOOL;
default:
LOG(FATAL) << "Unrecognized dtype: " << dtype;
}
}
void SetTensorProtoShape(const std::vector<int>& shape, TensorProto* tensor) {
auto* tensor_shape = tensor->mutable_tensor_shape();
for (int dim : shape) {
tensor_shape->add_dim()->set_size(dim);
}
}
TensorProto BoolTensor(const std::vector<int>& shape,
const std::vector<bool>& values) {
TensorProto tensor;
SetTensorProtoShape(shape, &tensor);
tensor.set_dtype(::tensorflow::DT_BOOL);
for (bool b : values) {
tensor.add_bool_val(b);
}
return tensor;
}
TensorProto IntTensor(const std::vector<int>& shape,
const std::vector<int>& values) {
TensorProto tensor;
tensor.set_dtype(::tensorflow::DT_INT32);
SetTensorProtoShape(shape, &tensor);
for (int i : values) {
tensor.add_int_val(i);
}
return tensor;
}
TensorProto FloatTensor(const std::vector<int>& shape,
const std::vector<float>& values) {
TensorProto tensor;
tensor.set_dtype(::tensorflow::DT_FLOAT);
SetTensorProtoShape(shape, &tensor);
for (float f : values) {
tensor.add_float_val(f);
}
return tensor;
}
TensorProto StringTensor(const std::vector<int>& shape,
const std::vector<std::string>& values) {
TensorProto tensor;
tensor.set_dtype(::tensorflow::DT_STRING);
SetTensorProtoShape(shape, &tensor);
for (const std::string& s : values) {
tensor.add_string_val(s);
}
return tensor;
}
void TensorflowTfLiteOpTest::SetUp() {
ConstructTensorflowOp();
ConstructTfLiteOp();
}
void TensorflowTfLiteOpTest::ConstructTensorflowOp() {
::tensorflow::NodeDefBuilder builder("test_op", TensorflowOpName());
for (const auto& attribute : GetParam().attributes) {
builder.Attr(attribute.first, attribute.second);
}
int index = 0;
for (const auto& input_tensor : GetParam().input_tensors) {
builder.Input("input", index, input_tensor.dtype());
index++;
}
TF_ASSERT_OK(builder.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
}
void TensorflowTfLiteOpTest::RunTensorflowOp() {
for (const auto& input_tensor : GetParam().input_tensors) {
switch (input_tensor.dtype()) {
case ::tensorflow::DT_FLOAT:
AddInput<float>(input_tensor.tensor_shape(),
[&input_tensor](int x) -> float {
return input_tensor.float_val(x);
});
break;
case ::tensorflow::DT_INT32:
AddInput<int>(
input_tensor.tensor_shape(),
[&input_tensor](int x) -> int { return input_tensor.int_val(x); });
break;
case ::tensorflow::DT_STRING:
AddInput<::tensorflow::tstring>(
input_tensor.tensor_shape(),
[&input_tensor](int x) -> ::tensorflow::tstring {
return input_tensor.string_val(x);
});
break;
case ::tensorflow::DT_BOOL:
AddInput<bool>(input_tensor.tensor_shape(),
[&input_tensor](int x) -> bool {
return input_tensor.bool_val(x);
});
break;
default:
LOG(FATAL) << "Unrecognized dtype: " << input_tensor.DebugString();
}
}
TF_ASSERT_OK(RunOpKernel());
}
std::vector<uint8_t> ConstructTfLiteCustomOptions(
absl::flat_hash_map<std::string, ::tensorflow::AttrValue> attributes,
const std::string& tensorflow_op) {
// Get the default attributes of the Tensorflow op.
const ::tensorflow::OpDef* tf_op_def;
TF_CHECK_OK(::tensorflow::OpRegistry::Global()->LookUpOpDef(tensorflow_op,
&tf_op_def));
for (const auto& tf_attribute : tf_op_def->attr()) {
if (tf_attribute.has_default_value() &&
!attributes.contains(tf_attribute.name())) {
attributes[tf_attribute.name()] = tf_attribute.default_value();
}
}
::flexbuffers::Builder fbb;
size_t map_start = fbb.StartMap();
for (const auto& attribute : attributes) {
switch (attribute.second.value_case()) {
case ::tensorflow::AttrValue::kS:
fbb.String(attribute.first.c_str(), attribute.second.s());
break;
case ::tensorflow::AttrValue::kI:
fbb.Int(attribute.first.c_str(), attribute.second.i());
break;
case ::tensorflow::AttrValue::kF:
fbb.Float(attribute.first.c_str(), attribute.second.f());
break;
case ::tensorflow::AttrValue::kB:
fbb.Bool(attribute.first.c_str(), attribute.second.b());
break;
case ::tensorflow::AttrValue::kList: {
int start = fbb.StartVector(attribute.first.c_str());
if (attribute.second.list().s_size() > 0) {
for (const std::string& s : attribute.second.list().s()) {
fbb.String(s);
}
} else if (attribute.second.list().i_size() > 0) {
for (int i : attribute.second.list().i()) {
fbb.Int(i);
}
} else if (attribute.second.list().f_size() > 0) {
for (float f : attribute.second.list().f()) {
fbb.Float(f);
}
} else if (attribute.second.list().b_size() > 0) {
for (bool b : attribute.second.list().b()) {
fbb.Bool(b);
}
}
fbb.EndVector(start, /*typed=*/true, /*fixed=*/false);
break;
}
default:
LOG(FATAL) << "Unrecognized AttrValue type: "
<< attribute.second.DebugString();
}
}
fbb.EndMap(map_start);
fbb.Finish();
return std::vector<uint8_t>(fbb.GetBuffer());
}
void TensorflowTfLiteOpTest::ConstructTfLiteOp() {
std::vector<std::vector<int>> input_shapes;
for (const auto& input_tensor : GetParam().input_tensors) {
std::vector<int> shape;
for (const auto& dim : input_tensor.tensor_shape().dim()) {
shape.push_back(dim.size());
}
input_shapes.push_back(shape);
tflite_inputs_.push_back(
tflite_op_.AddInput(TfTypeToTfLiteType(input_tensor.dtype())));
}
for (const auto& output_tensor : GetParam().output_tensors) {
std::vector<int> shape;
for (const auto& dim : output_tensor.tensor.tensor_shape().dim()) {
shape.push_back(dim.size());
}
if (output_tensor.quantization_params.scale != 0.0) {
ASSERT_EQ(output_tensor.tensor.dtype(), ::tensorflow::DT_FLOAT)
<< "Quantization attempted on non-float tensor: "
<< output_tensor.tensor.DebugString();
// We can safely use as zero min and max, as they'll be ignored and
// the scale and zero_point will be used instead.
tflite_outputs_.push_back(tflite_op_.AddOutput(
{TensorType_UINT8, shape, /*min=*/0.0, /*max=*/0.0,
output_tensor.quantization_params.scale,
output_tensor.quantization_params.zero_point}));
} else {
tflite_outputs_.push_back(tflite_op_.AddOutput(
{TfTypeToTfLiteType(output_tensor.tensor.dtype()), shape}));
}
}
tflite_op_.SetCustomOp(
TfLiteOpName(),
ConstructTfLiteCustomOptions(GetParam().attributes, TensorflowOpName()),
TfLiteOpRegistration());
tflite_op_.BuildInterpreter(input_shapes);
}
void TensorflowTfLiteOpTest::RunTfLiteOp() {
int input_index = 0;
for (const auto& input_tensor : GetParam().input_tensors) {
switch (input_tensor.dtype()) {
case ::tensorflow::DT_FLOAT: {
std::vector<float> float_val(input_tensor.float_val().begin(),
input_tensor.float_val().end());
tflite_op_.PopulateTensor<float>(tflite_inputs_[input_index],
float_val);
break;
}
case ::tensorflow::DT_INT32: {
std::vector<int> int_val(input_tensor.int_val().begin(),
input_tensor.int_val().end());
tflite_op_.PopulateTensor<int>(tflite_inputs_[input_index], int_val);
break;
}
case ::tensorflow::DT_STRING: {
std::vector<std::string> string_val(input_tensor.string_val().begin(),
input_tensor.string_val().end());
tflite_op_.PopulateStringTensor(tflite_inputs_[input_index],
string_val);
break;
}
case ::tensorflow::DT_BOOL: {
std::vector<bool> bool_val(input_tensor.bool_val().begin(),
input_tensor.bool_val().end());
tflite_op_.PopulateTensor<bool>(tflite_inputs_[input_index], bool_val);
break;
}
default:
LOG(FATAL) << "Unrecognized dtype: " << input_tensor.DebugString();
}
input_index++;
}
tflite_op_.Invoke();
}
void TensorflowTfLiteOpTest::CompareOpOutput() {
for (int i = 0; i < tflite_outputs_.size(); i++) {
const ::tensorflow::Tensor& tf_output = *GetOutput(i);
std::vector<int> tflite_output_shape =
tflite_op_.GetTensorShape(tflite_outputs_[i]);
auto tf_output_shape = tf_output.shape();
EXPECT_EQ(tf_output_shape.dims(), tflite_output_shape.size());
for (int j = 0; j < tf_output_shape.dims(); j++) {
EXPECT_EQ(tf_output_shape.dim_size(j), tflite_output_shape[j]);
}
switch (tf_output.dtype()) {
case ::tensorflow::DT_FLOAT: {
auto tf_output_values = tf_output.flat<float>();
const auto& quantization_params =
GetParam().output_tensors[i].quantization_params;
if (quantization_params.scale != 0.0) {
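          // Dequantize maps each uint8 value q to scale * (q - zero_point);
          // comparing with a tolerance of one scale step absorbs quantization
          // rounding.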
auto tflite_output_values = Dequantize(
tflite_op_.ExtractVector<uint8_t>(tflite_outputs_[i]),
quantization_params.scale, quantization_params.zero_point);
          for (int j = 0; j < tf_output_values.size(); j++) {
            EXPECT_THAT(
                tf_output_values(j),
                FloatNear(tflite_output_values[j], quantization_params.scale));
          }
} else {
auto tflite_output_values =
tflite_op_.ExtractVector<float>(tflite_outputs_[i]);
          for (int j = 0; j < tf_output_values.size(); j++) {
            EXPECT_EQ(tf_output_values(j), tflite_output_values[j]);
          }
}
break;
}
case ::tensorflow::DT_INT32: {
auto tf_output_values = tf_output.flat<int>();
auto tflite_output_values =
tflite_op_.ExtractVector<int>(tflite_outputs_[i]);
        for (int j = 0; j < tf_output_values.size(); j++) {
          EXPECT_EQ(tf_output_values(j), tflite_output_values[j]);
        }
break;
}
case ::tensorflow::DT_BOOL: {
auto tf_output_values = tf_output.flat<bool>();
auto tflite_output_values =
tflite_op_.ExtractVector<bool>(tflite_outputs_[i]);
        for (int j = 0; j < tf_output_values.size(); j++) {
          EXPECT_EQ(tf_output_values(j), tflite_output_values[j]);
        }
break;
}
case ::tensorflow::DT_STRING: {
auto tf_output_values = tf_output.flat<::tensorflow::tstring>();
auto tflite_output_values =
tflite_op_.ExtractVector<std::string>(tflite_outputs_[i]);
        for (int j = 0; j < tf_output_values.size(); j++) {
          EXPECT_EQ(tf_output_values(j), tflite_output_values[j]);
        }
break;
}
default:
LOG(FATAL) << "Unrecognized dtype: " << tf_output.dtype();
}
}
}
} // namespace testing
} // namespace tflite
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Tests equivalence between TF and TFLite versions of an op.
#ifndef TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
#define TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
#include <string>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/lite/kernels/test_util.h"
namespace tflite {
namespace testing {
// Convenience constructors.
template <typename T>
::tensorflow::AttrValue AttrValue(T value) {
::tensorflow::AttrValue attr_value;
::tensorflow::SetAttrValue(value, &attr_value);
return attr_value;
}
::tensorflow::TensorProto BoolTensor(const std::vector<int>& shape,
const std::vector<bool>& values);
::tensorflow::TensorProto IntTensor(const std::vector<int>& shape,
const std::vector<int>& values);
::tensorflow::TensorProto FloatTensor(const std::vector<int>& shape,
const std::vector<float>& values);
::tensorflow::TensorProto StringTensor(const std::vector<int>& shape,
const std::vector<std::string>& values);
struct OutputTensor {
explicit OutputTensor(const ::tensorflow::TensorProto& tensor)
: tensor(tensor) {
quantization_params.scale = 0.0;
}
OutputTensor(const ::tensorflow::TensorProto& tensor, float scale,
int zero_point)
: tensor(tensor) {
quantization_params.scale = scale;
quantization_params.zero_point = zero_point;
}
::tensorflow::TensorProto tensor;
TfLiteQuantizationParams quantization_params;
};
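// For example, a quantized expected output for a diff test can be declared as
//   OutputTensor(FloatTensor({}, {}), /*scale=*/2.0 / 255, /*zero_point=*/127);
// while a plain float output needs only OutputTensor(FloatTensor(shape, values)).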
struct OpEquivTestCase {
std::string test_name;
absl::flat_hash_map<std::string, ::tensorflow::AttrValue> attributes;
std::vector<::tensorflow::TensorProto> input_tensors;
std::vector<OutputTensor> output_tensors;
};
// Convert Tensorflow attributes into an equivalent TFLite flatbuffer. Adds the
// default attribute values from `tensorflow_op`, if they are not set in
// `attributes`.
std::vector<uint8_t> ConstructTfLiteCustomOptions(
absl::flat_hash_map<std::string, ::tensorflow::AttrValue> attributes,
const std::string& tensorflow_op);
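// For instance (the attribute value below is illustrative):
//   absl::flat_hash_map<std::string, ::tensorflow::AttrValue> attrs;
//   attrs["feature_size"] = AttrValue(8);
//   std::vector<uint8_t> options =
//       ConstructTfLiteCustomOptions(attrs, "SequenceStringProjectionV2");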
// A test class that can be used to compare that a Tensorflow op and a
// TFLite op are producing the same output.
//
// To use:
// 1) Sub-class TensorflowTfLiteOpTest.
// Define TfLiteOpRegistration() and TensorflowOpName().
//
// class NewOpEquivTest : public TensorflowTfLiteOpTest {
// std::function<TfLiteRegistration*()> TfLiteOpRegistration() override {
// return ::tflite::custom::Register_NEW_OP;
// }
// std::string TensorflowOpName() override { return "NewOp"; }
// };
//
// 2) Declare a TEST_P (parameterized test) to perform the comparison.
//
// TEST_P(NewOpEquivTest, Compare) {
// RunTensorflowOp();
// RunTfLiteOp();
// CompareOpOutput();
// }
//
// 3) Define your test cases.
//
// std::vector<OpEquivTestCase> NewEquivOpTestCases() {
// std::vector<OpEquivTestCase> test_cases;
// {
// OpEquivTestCase test_case;
// test_case.test_name = "Simple";
// test_case.attributes["int_attr"] = AttrValue(1);
// test_case.attributes["bool_attr"] = AttrValue(true);
// test_case.input_tensor.push_back(StringTensor({1, 2}, {"a", "b"}));
// test_case.output_tensors.emplace_back(FloatTensor({}, {}));
// test_cases.push_back(test_case);
// }
// return test_cases;
// }
//
// 4) Instantiate your tests.
//
// INSTANTIATE_TEST_SUITE_P(
// NewOpEquivTest,
// NewOpEquivTest,
// ::testing::ValuesIn(NewOpEquivTestCases()),
// ::expander::GetTestName());
class TensorflowTfLiteOpTest
: public ::tensorflow::OpsTestBase,
public ::testing::WithParamInterface<OpEquivTestCase> {
protected:
void SetUp() override;
virtual void ConstructTensorflowOp();
virtual void RunTensorflowOp();
virtual void ConstructTfLiteOp();
virtual void RunTfLiteOp();
virtual void CompareOpOutput();
virtual std::function<TfLiteRegistration*()> TfLiteOpRegistration() = 0;
virtual std::string TfLiteOpName() { return "TestOp"; }
virtual std::string TensorflowOpName() = 0;
private:
::tflite::SingleOpModel tflite_op_;
std::vector<int> tflite_inputs_;
std::vector<int> tflite_outputs_;
};
} // namespace testing
} // namespace tflite
#endif // TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
# Placeholder to make bazel treat it as a package.
"""Set up configurable Android SDK and NDK dependencies."""
def android_workspace():
# String for replacement in Bazel template.
# These will either be replaced by android_sdk_repository if various ENV
# variables are set when `local_config_android` repo_rule is run, or they
# will be replaced by noops otherwise.
MAYBE_ANDROID_SDK_REPOSITORY
MAYBE_ANDROID_NDK_REPOSITORY
"""Repository rule for Android SDK and NDK autoconfiguration.
`android_configure` depends on the following environment variables:
* `ANDROID_NDK_HOME`: Location of Android NDK root.
* `ANDROID_SDK_HOME`: Location of Android SDK root.
* `ANDROID_SDK_API_LEVEL`: Desired Android SDK API version.
* `ANDROID_NDK_API_LEVEL`: Desired Android NDK API version.
* `ANDROID_BUILD_TOOLS_VERSION`: Desired Android build tools version.
Writes Android SDK and NDK rules.
Add the following to your WORKSPACE file:
```python
android_configure(name = "local_config_android")
```
Args:
name: A unique name for this workspace rule.
"""
_ANDROID_NDK_HOME = "ANDROID_NDK_HOME"
_ANDROID_SDK_HOME = "ANDROID_SDK_HOME"
_ANDROID_NDK_API_VERSION = "ANDROID_NDK_API_LEVEL"
_ANDROID_SDK_API_VERSION = "ANDROID_SDK_API_LEVEL"
_ANDROID_BUILD_TOOLS_VERSION = "ANDROID_BUILD_TOOLS_VERSION"
_ANDROID_SDK_REPO_TEMPLATE = """
native.android_sdk_repository(
name="androidsdk",
path="%s",
api_level=%s,
build_tools_version="%s",
)
"""
_ANDROID_NDK_REPO_TEMPLATE = """
native.android_ndk_repository(
name="androidndk",
path="%s",
api_level=%s,
)
"""
def _android_autoconf_impl(repository_ctx):
"""Implementation of the android_autoconf repository rule."""
sdk_home = repository_ctx.os.environ.get(_ANDROID_SDK_HOME)
sdk_api_level = repository_ctx.os.environ.get(_ANDROID_SDK_API_VERSION)
build_tools_version = repository_ctx.os.environ.get(
_ANDROID_BUILD_TOOLS_VERSION,
)
ndk_home = repository_ctx.os.environ.get(_ANDROID_NDK_HOME)
ndk_api_level = repository_ctx.os.environ.get(_ANDROID_NDK_API_VERSION)
sdk_rule = ""
if all([sdk_home, sdk_api_level, build_tools_version]):
sdk_rule = _ANDROID_SDK_REPO_TEMPLATE % (
sdk_home,
sdk_api_level,
build_tools_version,
)
ndk_rule = ""
if all([ndk_home, ndk_api_level]):
ndk_rule = _ANDROID_NDK_REPO_TEMPLATE % (ndk_home, ndk_api_level)
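    # If neither the SDK nor the NDK is configured, substitute a no-op so the
    # generated android.bzl remains syntactically valid.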
if ndk_rule == "" and sdk_rule == "":
sdk_rule = "pass"
# TODO(xunkai): Add interactive configure script.
repository_ctx.template(
"BUILD",
Label("//third_party/android:android_configure.BUILD.tpl"),
)
repository_ctx.template(
"android.bzl",
Label("//third_party/android:android.bzl.tpl"),
substitutions = {
"MAYBE_ANDROID_SDK_REPOSITORY": sdk_rule,
"MAYBE_ANDROID_NDK_REPOSITORY": ndk_rule,
},
)
android_configure = repository_rule(
implementation = _android_autoconf_impl,
environ = [
_ANDROID_SDK_API_VERSION,
_ANDROID_NDK_API_VERSION,
_ANDROID_BUILD_TOOLS_VERSION,
_ANDROID_NDK_HOME,
_ANDROID_SDK_HOME,
],
)
# Description:
# Eigen is a C++ template library for linear algebra: vectors,
# matrices, and related algorithms.
licenses([
# Note: Eigen is an MPL2 library that includes GPL v3 and LGPL v2.1+ code.
# We've taken special care to not reference any restricted code.
"reciprocal", # MPL2
"notice", # Portions BSD
])
exports_files(["COPYING.MPL2"])
EIGEN_FILES = [
"Eigen/**",
"unsupported/Eigen/CXX11/**",
"unsupported/Eigen/FFT",
"unsupported/Eigen/KroneckerProduct",
"unsupported/Eigen/src/FFT/**",
"unsupported/Eigen/src/KroneckerProduct/**",
"unsupported/Eigen/MatrixFunctions",
"unsupported/Eigen/SpecialFunctions",
"unsupported/Eigen/src/MatrixFunctions/**",
"unsupported/Eigen/src/SpecialFunctions/**",
]
# Files known to be under MPL2 license.
EIGEN_MPL2_HEADER_FILES = glob(
EIGEN_FILES,
exclude = [
# Guarantees that any non-MPL2 file added to the list above will fail to
# compile.
"Eigen/src/Core/util/NonMPL2.h",
"Eigen/**/CMakeLists.txt",
],
)
cc_library(
name = "eigen",
hdrs = EIGEN_MPL2_HEADER_FILES,
defines = [
# This define (mostly) guarantees we don't link any problematic
# code. We use it, but we do not rely on it, as evidenced above.
"EIGEN_MPL2_ONLY",
"EIGEN_MAX_ALIGN_BYTES=64",
"EIGEN_HAS_TYPE_TRAITS=0",
],
includes = ["."],
visibility = ["//visibility:public"],
)
filegroup(
name = "eigen_header_files",
srcs = EIGEN_MPL2_HEADER_FILES,
visibility = ["//visibility:public"],
)
licenses(["notice"]) # MIT
exports_files(["COPYING"])
config_setting(
name = "windows",
values = {
"cpu": "x64_windows",
},
)
cc_library(
name = "farmhash",
srcs = ["src/farmhash.cc"],
hdrs = ["src/farmhash.h"],
# Disable __builtin_expect support on Windows
copts = select({
":windows": ["/DFARMHASH_OPTIONAL_BUILTIN_EXPECT"],
"//conditions:default": [],
}),
includes = ["src/."],
visibility = ["//visibility:public"],
)
# This empty BUILD file is required to make Bazel treat this directory as a package.
load("@build_bazel_rules_android//android:rules.bzl", "android_library")
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE.txt"])
licenses(["notice"])
config_setting(
name = "freebsd",
values = {"cpu": "freebsd"},
)
config_setting(
name = "windows",
values = {"cpu": "x64_windows"},
)
load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library")
# Public flatc library to compile flatbuffer files at runtime.
cc_library(
name = "flatbuffers",
hdrs = ["//:public_headers"],
linkstatic = 1,
strip_include_prefix = "/include",
visibility = ["//visibility:public"],
deps = ["//src:flatbuffers"],
)
# Public C++ headers for the Flatbuffers library.
filegroup(
name = "public_headers",
srcs = [
"include/flatbuffers/base.h",
"include/flatbuffers/code_generators.h",
"include/flatbuffers/flatbuffers.h",
"include/flatbuffers/flexbuffers.h",
"include/flatbuffers/hash.h",
"include/flatbuffers/idl.h",
"include/flatbuffers/minireflect.h",
"include/flatbuffers/reflection.h",
"include/flatbuffers/reflection_generated.h",
"include/flatbuffers/registry.h",
"include/flatbuffers/stl_emulation.h",
"include/flatbuffers/util.h",
],
visibility = ["//:__subpackages__"],
)
# Public flatc compiler library.
cc_library(
name = "flatc_library",
linkstatic = 1,
visibility = ["//visibility:public"],
deps = [
"@flatbuffers//src:flatc_library",
],
)
# Public flatc compiler.
cc_binary(
name = "flatc",
linkopts = select({
":freebsd": [
"-lm",
],
":windows": [],
"//conditions:default": [
"-lm",
"-ldl",
],
}),
visibility = ["//visibility:public"],
deps = [
"@flatbuffers//src:flatc",
],
)
filegroup(
name = "flatc_headers",
srcs = [
"include/flatbuffers/flatc.h",
],
visibility = ["//:__subpackages__"],
)
# Library used by flatbuffer_cc_library rules.
cc_library(
name = "runtime_cc",
hdrs = [
"include/flatbuffers/base.h",
"include/flatbuffers/flatbuffers.h",
"include/flatbuffers/flexbuffers.h",
"include/flatbuffers/stl_emulation.h",
"include/flatbuffers/util.h",
],
linkstatic = 1,
strip_include_prefix = "/include",
visibility = ["//visibility:public"],
)
filegroup(
name = "runtime_py_srcs",
srcs = [
"python/flatbuffers/__init__.py",
"python/flatbuffers/builder.py",
"python/flatbuffers/compat.py",
"python/flatbuffers/encode.py",
"python/flatbuffers/number_types.py",
"python/flatbuffers/packer.py",
"python/flatbuffers/table.py",
"python/flatbuffers/util.py",
],
)
py_library(
name = "runtime_py",
srcs = [":runtime_py_srcs"],
visibility = ["//visibility:public"],
)
filegroup(
name = "runtime_java_srcs",
srcs = glob(["java/com/google/flatbuffers/**/*.java"]),
)
java_library(
name = "runtime_java",
srcs = [":runtime_java_srcs"],
visibility = ["//visibility:public"],
)
android_library(
name = "runtime_android",
srcs = [":runtime_java_srcs"],
visibility = ["//visibility:public"],
)
"""BUILD rules for generating flatbuffer files."""
load("@build_bazel_rules_android//android:rules.bzl", "android_library")
flatc_path = "@flatbuffers//:flatc"
zip_files = "@org_tflite_support//tensorflow_lite_support/tools:zip_files"
DEFAULT_INCLUDE_PATHS = [
"./",
"$(GENDIR)",
"$(BINDIR)",
]
DEFAULT_FLATC_ARGS = [
"--no-union-value-namespacing",
"--gen-object-api",
]
def flatbuffer_library_public(
name,
srcs,
outs,
language_flag,
out_prefix = "",
includes = [],
include_paths = [],
flatc_args = DEFAULT_FLATC_ARGS,
reflection_name = "",
reflection_visibility = None,
output_to_bindir = False):
"""Generates code files for reading/writing the given flatbuffers in the requested language using the public compiler.
Outs:
filegroup(name): all generated source files.
Fileset([reflection_name]): (Optional) all generated reflection binaries.
Args:
name: Rule name.
srcs: Source .fbs files. Sent in order to the compiler.
outs: Output files from flatc.
language_flag: Target language flag. One of [-c, -j, -js].
out_prefix: Prepend this path to the front of all generated files except on
single source targets. Usually is a directory name.
includes: Optional, list of filegroups of schemas that the srcs depend on.
      include_paths: Optional, list of paths where included schemas can be found.
flatc_args: Optional, list of additional arguments to pass to flatc.
reflection_name: Optional, if set this will generate the flatbuffer
reflection binaries for the schemas.
reflection_visibility: The visibility of the generated reflection Fileset.
output_to_bindir: Passed to genrule for output to bin directory.
"""
include_paths_cmd = ["-I %s" % (s) for s in include_paths]
# '$(@D)' when given a single source target will give the appropriate
# directory. Appending 'out_prefix' is only necessary when given a build
# target with multiple sources.
output_directory = (
("-o $(@D)/%s" % (out_prefix)) if len(srcs) > 1 else ("-o $(@D)")
)
genrule_cmd = " ".join([
"for f in $(SRCS); do",
"$(location %s)" % (flatc_path),
" ".join(flatc_args),
" ".join(include_paths_cmd),
language_flag,
output_directory,
"$$f;",
"done",
])
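    # As a rough sketch, for C++ output with the default flags the assembled
    # command is equivalent to:
    #   for f in $(SRCS); do
    #     flatc --no-union-value-namespacing --gen-object-api -c -o $(@D) $$f;
    #   done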
native.genrule(
name = name,
srcs = srcs,
outs = outs,
output_to_bindir = output_to_bindir,
tools = includes + [flatc_path],
cmd = genrule_cmd,
message = "Generating flatbuffer files for %s:" % (name),
)
if reflection_name:
reflection_genrule_cmd = " ".join([
"for f in $(SRCS); do",
"$(location %s)" % (flatc_path),
"-b --schema",
" ".join(flatc_args),
" ".join(include_paths_cmd),
language_flag,
output_directory,
"$$f;",
"done",
])
reflection_outs = [
(out_prefix + "%s.bfbs") % (s.replace(".fbs", "").split("/")[-1])
for s in srcs
]
native.genrule(
name = "%s_srcs" % reflection_name,
srcs = srcs,
outs = reflection_outs,
output_to_bindir = output_to_bindir,
tools = includes + [flatc_path],
cmd = reflection_genrule_cmd,
message = "Generating flatbuffer reflection binary for %s:" % (name),
)
# TODO(b/114456773): Make bazel rules proper and supported by flatbuffer
    # This is commented out because FilesetEntry is not supported in Bazel
    # Starlark.
# native.Fileset(
# name = reflection_name,
# out = "%s_out" % reflection_name,
# entries = [
# native.FilesetEntry(files = reflection_outs),
# ],
# visibility = reflection_visibility,
# )
def flatbuffer_cc_library(
name,
srcs,
srcs_filegroup_name = "",
out_prefix = "",
includes = [],
include_paths = [],
flatc_args = DEFAULT_FLATC_ARGS,
visibility = None,
srcs_filegroup_visibility = None,
gen_reflections = False):
    """A cc_library with the generated reader/writers for the given flatbuffer definitions.
Outs:
filegroup([name]_srcs): all generated .h files.
filegroup(srcs_filegroup_name if specified, or [name]_includes if not):
Other flatbuffer_cc_library's can pass this in for their `includes`
parameter, if they depend on the schemas in this library.
Fileset([name]_reflection): (Optional) all generated reflection binaries.
cc_library([name]): library with sources and flatbuffers deps.
Remarks:
** Because the genrule used to call flatc does not have any trivial way of
computing the output list of files transitively generated by includes and
--gen-includes (the default) being defined for flatc, the --gen-includes
flag will not work as expected. The way around this is to add a dependency
to the flatbuffer_cc_library defined alongside the flatc included Fileset.
For example you might define:
flatbuffer_cc_library(
name = "my_fbs",
srcs = [ "schemas/foo.fbs" ],
includes = [ "//third_party/bazz:bazz_fbs_includes" ],
)
In which foo.fbs includes a few files from the Fileset defined at
//third_party/bazz:bazz_fbs_includes. When compiling the library that
includes foo_generated.h, and therefore has my_fbs as a dependency, it
will fail to find any of the bazz *_generated.h files unless you also
add bazz's flatbuffer_cc_library to your own dependency list, e.g.:
cc_library(
name = "my_lib",
deps = [
":my_fbs",
"//third_party/bazz:bazz_fbs"
],
)
Happy dependent Flatbuffering!
Args:
name: Rule name.
srcs: Source .fbs files. Sent in order to the compiler.
srcs_filegroup_name: Name of the output filegroup that holds srcs. Pass this
filegroup into the `includes` parameter of any other
flatbuffer_cc_library that depends on this one's schemas.
out_prefix: Prepend this path to the front of all generated files. Usually
is a directory name.
includes: Optional, list of filegroups of schemas that the srcs depend on.
** SEE REMARKS BELOW **
      include_paths: Optional, list of paths where included schemas can be found.
flatc_args: Optional list of additional arguments to pass to flatc
(e.g. --gen-mutable).
visibility: The visibility of the generated cc_library. By default, use the
default visibility of the project.
srcs_filegroup_visibility: The visibility of the generated srcs filegroup.
By default, use the value of the visibility parameter above.
gen_reflections: Optional, if true this will generate the flatbuffer
reflection binaries for the schemas.
    """
output_headers = [
(out_prefix + "%s_generated.h") % (s.replace(".fbs", "").split("/")[-1])
for s in srcs
]
reflection_name = "%s_reflection" % name if gen_reflections else ""
flatbuffer_library_public(
name = "%s_srcs" % (name),
srcs = srcs,
outs = output_headers,
language_flag = "-c",
out_prefix = out_prefix,
includes = includes,
include_paths = include_paths,
flatc_args = flatc_args,
reflection_name = reflection_name,
reflection_visibility = visibility,
)
native.cc_library(
name = name,
hdrs = output_headers,
srcs = output_headers,
features = [
"-parse_headers",
],
deps = [
"@flatbuffers//:runtime_cc",
],
includes = ["."],
linkstatic = 1,
visibility = visibility,
)
# A filegroup for the `srcs`. That is, all the schema files for this
# Flatbuffer set.
native.filegroup(
name = srcs_filegroup_name if srcs_filegroup_name else "%s_includes" % (name),
srcs = srcs,
visibility = srcs_filegroup_visibility if srcs_filegroup_visibility != None else visibility,
)
# Custom provider to track dependencies transitively.
FlatbufferInfo = provider(
fields = {
"transitive_srcs": "flatbuffer schema definitions.",
},
)
def _flatbuffer_schemas_aspect_impl(target, ctx):
_ignore = [target]
transitive_srcs = depset()
if hasattr(ctx.rule.attr, "deps"):
for dep in ctx.rule.attr.deps:
if FlatbufferInfo in dep:
transitive_srcs = depset(dep[FlatbufferInfo].transitive_srcs, transitive = [transitive_srcs])
if hasattr(ctx.rule.attr, "srcs"):
for src in ctx.rule.attr.srcs:
if FlatbufferInfo in src:
transitive_srcs = depset(src[FlatbufferInfo].transitive_srcs, transitive = [transitive_srcs])
for f in src.files:
if f.extension == "fbs":
transitive_srcs = depset([f], transitive = [transitive_srcs])
return [FlatbufferInfo(transitive_srcs = transitive_srcs)]
# An aspect that runs over all dependencies and transitively collects
# flatbuffer schema files.
_flatbuffer_schemas_aspect = aspect(
attr_aspects = [
"deps",
"srcs",
],
implementation = _flatbuffer_schemas_aspect_impl,
)
# Rule to invoke the flatbuffer compiler.
def _gen_flatbuffer_srcs_impl(ctx):
outputs = ctx.attr.outputs
include_paths = ctx.attr.include_paths
if ctx.attr.no_includes:
no_includes_statement = ["--no-includes"]
else:
no_includes_statement = []
# Need to generate all files in a directory.
if not outputs:
outputs = [ctx.actions.declare_directory("{}_all".format(ctx.attr.name))]
output_directory = outputs[0].path
else:
outputs = [ctx.actions.declare_file(output) for output in outputs]
output_directory = outputs[0].dirname
deps = depset(ctx.files.srcs + ctx.files.deps, transitive = [
dep[FlatbufferInfo].transitive_srcs
for dep in ctx.attr.deps
if FlatbufferInfo in dep
])
include_paths_cmd_line = []
for s in include_paths:
include_paths_cmd_line.extend(["-I", s])
for src in ctx.files.srcs:
ctx.actions.run(
inputs = deps,
outputs = outputs,
executable = ctx.executable._flatc,
arguments = [
ctx.attr.language_flag,
"-o",
output_directory,
# Allow for absolute imports and referencing of generated files.
"-I",
"./",
"-I",
ctx.genfiles_dir.path,
"-I",
ctx.bin_dir.path,
] + no_includes_statement +
include_paths_cmd_line + [
"--no-union-value-namespacing",
"--gen-object-api",
src.path,
],
progress_message = "Generating flatbuffer files for {}:".format(src),
)
return [
DefaultInfo(files = depset(outputs)),
]
_gen_flatbuffer_srcs = rule(
_gen_flatbuffer_srcs_impl,
attrs = {
"srcs": attr.label_list(
allow_files = [".fbs"],
mandatory = True,
),
"outputs": attr.string_list(
default = [],
mandatory = False,
),
"deps": attr.label_list(
default = [],
mandatory = False,
aspects = [_flatbuffer_schemas_aspect],
),
"include_paths": attr.string_list(
default = [],
mandatory = False,
),
"language_flag": attr.string(
mandatory = True,
),
"no_includes": attr.bool(
default = False,
mandatory = False,
),
"_flatc": attr.label(
default = Label("@flatbuffers//:flatc"),
executable = True,
cfg = "host",
),
},
output_to_genfiles = True,
)
def _concat_flatbuffer_py_srcs_impl(ctx):
# Merge all generated python files. The files are concatenated and the
# import statements are removed. Finally we import the flatbuffer runtime
# library.
command = "find '%s' -name '*.py' -exec cat {} + | sed '/import flatbuffers/d'"
command += " | sed '1s/^/import flatbuffers\\'$'\\n/' > %s"
ctx.actions.run_shell(
inputs = ctx.attr.deps[0].files,
outputs = [ctx.outputs.out],
command = command % (
ctx.attr.deps[0].files.to_list()[0].path,
ctx.outputs.out.path,
),
)
_concat_flatbuffer_py_srcs = rule(
_concat_flatbuffer_py_srcs_impl,
attrs = {
"deps": attr.label_list(mandatory = True),
},
output_to_genfiles = True,
outputs = {"out": "%{name}.py"},
)
def flatbuffer_py_library(
name,
srcs,
deps = [],
include_paths = []):
"""A py_library with the generated reader/writers for the given schema.
    This rule assumes that the schema files define non-conflicting names, so
    that they can be merged into a single file. This is the case, e.g., when
    only a single namespace is used.
    The rule calls the flatbuffer compiler for all schema files and merges the
    generated Python files into a single file that is wrapped in a py_library.
Args:
name: Rule name. (required)
srcs: List of source .fbs files. (required)
deps: List of dependencies.
      include_paths: Optional, list of paths where included schemas can be found.
"""
all_srcs = "{}_srcs".format(name)
_gen_flatbuffer_srcs(
name = all_srcs,
srcs = srcs,
language_flag = "--python",
deps = deps,
include_paths = include_paths,
)
all_srcs_no_include = "{}_srcs_no_include".format(name)
_gen_flatbuffer_srcs(
name = all_srcs_no_include,
srcs = srcs,
language_flag = "--python",
deps = deps,
no_includes = True,
include_paths = include_paths,
)
concat_py_srcs = "{}_generated".format(name)
_concat_flatbuffer_py_srcs(
name = concat_py_srcs,
deps = [
":{}".format(all_srcs_no_include),
],
)
native.py_library(
name = name,
srcs = [
":{}".format(concat_py_srcs),
],
srcs_version = "PY2AND3",
deps = deps,
)
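# A minimal invocation might look like (the schema name is hypothetical):
#   flatbuffer_py_library(
#       name = "foo_fbs_py",
#       srcs = ["foo.fbs"],
#   )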
def flatbuffer_java_library(
name,
srcs,
custom_package = "",
package_prefix = "",
include_paths = DEFAULT_INCLUDE_PATHS,
flatc_args = DEFAULT_FLATC_ARGS,
visibility = None):
"""A java library with the generated reader/writers for the given flatbuffer definitions.
Args:
name: Rule name. (required)
srcs: List of source .fbs files including all includes. (required)
      custom_package: Package name of generated Java files. If not specified,
        the namespace in the schema files will be used. (optional)
      package_prefix: Like custom_package, but prepended to the existing
        namespace. (optional)
      include_paths: List of paths where included files can be found. (optional)
flatc_args: List of additional arguments to pass to flatc. (optional)
visibility: Visibility setting for the java_library rule. (optional)
"""
out_srcjar = "java_%s_all.srcjar" % name
flatbuffer_java_srcjar(
name = "%s_srcjar" % name,
srcs = srcs,
out = out_srcjar,
custom_package = custom_package,
flatc_args = flatc_args,
include_paths = include_paths,
package_prefix = package_prefix,
)
native.filegroup(
name = "%s.srcjar" % name,
srcs = [out_srcjar],
)
native.java_library(
name = name,
srcs = [out_srcjar],
javacopts = ["-source 7 -target 7"],
deps = [
"@flatbuffers//:runtime_java",
],
visibility = visibility,
)
def flatbuffer_java_srcjar(
name,
srcs,
out,
custom_package = "",
package_prefix = "",
include_paths = DEFAULT_INCLUDE_PATHS,
flatc_args = DEFAULT_FLATC_ARGS):
"""Generate flatbuffer Java source files.
Args:
name: Rule name. (required)
srcs: List of source .fbs files including all includes. (required)
out: Output file name. (required)
      custom_package: Package name of generated Java files. If not specified,
        the namespace in the schema files will be used. (optional)
      package_prefix: Like custom_package, but prepended to the existing
        namespace. (optional)
      include_paths: List of paths where included files can be found. (optional)
flatc_args: List of additional arguments to pass to flatc. (optional)
"""
command_fmt = """set -e
tmpdir=$(@D)
schemas=$$tmpdir/schemas
java_root=$$tmpdir/java
rm -rf $$schemas
rm -rf $$java_root
mkdir -p $$schemas
mkdir -p $$java_root
for src in $(SRCS); do
dest=$$schemas/$$src
rm -rf $$(dirname $$dest)
mkdir -p $$(dirname $$dest)
if [ -z "{custom_package}" ] && [ -z "{package_prefix}" ]; then
cp -f $$src $$dest
else
if [ -z "{package_prefix}" ]; then
sed -e "s/namespace\\s.*/namespace {custom_package};/" $$src > $$dest
else
sed -e "s/namespace \\([^;]\\+\\);/namespace {package_prefix}.\\1;/" $$src > $$dest
fi
fi
done
flatc_arg_I="-I $$tmpdir/schemas"
for include_path in {include_paths}; do
flatc_arg_I="$$flatc_arg_I -I $$schemas/$$include_path"
done
flatc_additional_args=
for arg in {flatc_args}; do
flatc_additional_args="$$flatc_additional_args $$arg"
done
for src in $(SRCS); do
$(location {flatc_path}) $$flatc_arg_I --java $$flatc_additional_args -o $$java_root $$schemas/$$src
done
$(location {zip_files}) -export_zip_path=$@ -file_directory=$$java_root
"""
genrule_cmd = command_fmt.format(
package_name = native.package_name(),
custom_package = custom_package,
package_prefix = package_prefix,
flatc_path = flatc_path,
zip_files = zip_files,
include_paths = " ".join(include_paths),
flatc_args = " ".join(flatc_args),
)
native.genrule(
name = name,
srcs = srcs,
outs = [out],
tools = [flatc_path, zip_files],
cmd = genrule_cmd,
)
def flatbuffer_android_library(
name,
srcs,
custom_package = "",
package_prefix = "",
include_paths = DEFAULT_INCLUDE_PATHS,
flatc_args = DEFAULT_FLATC_ARGS,
visibility = None):
"""An android_library with the generated reader/writers for the given flatbuffer definitions.
Args:
name: Rule name. (required)
srcs: List of source .fbs files including all includes. (required)
      custom_package: Package name of generated Java files. If not specified,
        the namespace in the schema files will be used. (optional)
      package_prefix: Like custom_package, but prepended to the existing
        namespace. (optional)
      include_paths: List of paths where included files can be found. (optional)
flatc_args: List of additional arguments to pass to flatc. (optional)
visibility: Visibility setting for the android_library rule. (optional)
"""
out_srcjar = "android_%s_all.srcjar" % name
flatbuffer_java_srcjar(
name = "%s_srcjar" % name,
srcs = srcs,
out = out_srcjar,
custom_package = custom_package,
flatc_args = flatc_args,
include_paths = include_paths,
package_prefix = package_prefix,
)
native.filegroup(
name = "%s.srcjar" % name,
srcs = [out_srcjar],
)
# To support org.checkerframework.dataflow.qual.Pure.
checkerframework_annotations = [
"@org_checkerframework_qual",
] if "--java-checkerframework" in flatc_args else []
android_library(
name = name,
srcs = [out_srcjar],
javacopts = ["-source 7 -target 7"],
visibility = visibility,
deps = [
"@flatbuffers//:runtime_android",
] + checkerframework_annotations,
)
"""Loads the Flatbuffers library, used by TF Lite."""
load("//third_party:repo.bzl", "third_party_http_archive")
def repo():
third_party_http_archive(
name = "flatbuffers",
strip_prefix = "flatbuffers-1.12.0",
sha256 = "62f2223fb9181d1d6338451375628975775f7522185266cd5296571ac152bc45",
urls = [
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/flatbuffers/archive/v1.12.0.tar.gz",
"https://github.com/google/flatbuffers/archive/v1.12.0.tar.gz",
],
build_file = "//third_party/flatbuffers:BUILD.bazel",
delete = ["build_defs.bzl"],
link_files = {
"//third_party/flatbuffers:build_defs.bzl": "build_defs.bzl",
},
)
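# Typical usage from a WORKSPACE file (the load path below assumes this file
# is checked in as //third_party/flatbuffers:workspace.bzl):
#   load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")
#   flatbuffers()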