Unverified Commit 20cc2190 authored by pyoung2778's avatar pyoung2778 Committed by GitHub
Browse files

Check in seq_flow_lite (#10750)

parent fdecf385
...@@ -18,8 +18,8 @@ ...@@ -18,8 +18,8 @@
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from layers import base_layers from layers import base_layers # import seq_flow_lite module
from layers import transformer_layers from layers import transformer_layers # import seq_flow_lite module
class Model(tf.keras.layers.Layer): class Model(tf.keras.layers.Layer):
......
...@@ -20,12 +20,12 @@ from absl import logging ...@@ -20,12 +20,12 @@ from absl import logging
from tensor2tensor.utils import beam_search from tensor2tensor.utils import beam_search
import tensorflow as tf import tensorflow as tf
from layers import base_layers from layers import base_layers # import seq_flow_lite module
from layers import dense_layers from layers import dense_layers # import seq_flow_lite module
from layers import embedding_layers from layers import embedding_layers # import seq_flow_lite module
from layers import normalization_layers from layers import normalization_layers # import seq_flow_lite module
from layers import quantization_layers from layers import quantization_layers # import seq_flow_lite module
from layers import transformer_layers from layers import transformer_layers # import seq_flow_lite module
class TransformerUniformAttnDecoder(base_layers.BaseLayer): class TransformerUniformAttnDecoder(base_layers.BaseLayer):
......
...@@ -11,20 +11,23 @@ package( ...@@ -11,20 +11,23 @@ package(
) )
cc_library( cc_library(
name = "sequence_string_projection_op", name = "projection_normalizer_util",
srcs = [ srcs = ["projection_normalizer_util.cc"],
"sequence_string_projection.cc", hdrs = ["projection_normalizer_util.h"],
deps = [
":projection_util",
"@utf_archive//:utf",
], ],
)
cc_library(
name = "projection_tokenizer_util",
srcs = ["projection_tokenizer_util.cc"],
hdrs = ["projection_tokenizer_util.h"],
deps = [ deps = [
":projection_normalizer_util",
":projection_tokenizer_util",
":projection_util", ":projection_util",
":text_distorter", "@utf_archive//:utf",
"@com_google_absl//absl/container:flat_hash_map",
"@tensorflow_includes//:includes",
"@tensorflow_solib//:framework_lib",
], ],
alwayslink = 1,
) )
cc_library( cc_library(
...@@ -37,22 +40,46 @@ cc_library( ...@@ -37,22 +40,46 @@ cc_library(
) )
cc_library( cc_library(
name = "projection_tokenizer_util", name = "skipgram_finder",
srcs = ["projection_tokenizer_util.cc"], srcs = ["skipgram_finder.cc"],
hdrs = ["projection_tokenizer_util.h"], hdrs = ["skipgram_finder.h"],
deps = [ deps = [
":projection_util", "@com_google_absl//absl/container:flat_hash_map",
"@utf_archive//:utf", "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@icu4c//:icu4c",
],
)
cc_test(
name = "skipgram_finder_test",
srcs = ["skipgram_finder_test.cc"],
deps = [
":skipgram_finder",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
"@icu4c//:icu4c",
], ],
) )
cc_library( cc_library(
name = "projection_normalizer_util", name = "subsequence_finder",
srcs = ["projection_normalizer_util.cc"], srcs = ["subsequence_finder.cc"],
hdrs = ["projection_normalizer_util.h"], hdrs = ["subsequence_finder.h"],
deps = [ deps = [
":projection_util", "@com_google_absl//absl/container:flat_hash_map",
"@utf_archive//:utf", "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@icu4c//:icu4c",
],
)
cc_test(
name = "subsequence_finder_test",
srcs = ["subsequence_finder_test.cc"],
deps = [
":subsequence_finder",
"@com_google_googletest//:gtest_main",
], ],
) )
...@@ -67,6 +94,55 @@ cc_library( ...@@ -67,6 +94,55 @@ cc_library(
], ],
) )
# Kernel library built from denylist_op.cc; depends on the skipgram and
# subsequence finders. alwayslink = 1 asks the linker to keep every object
# file of this library even when no symbol is referenced directly.
cc_library(
name = "denylist_op",
srcs = ["denylist_op.cc"],
deps = [
":skipgram_finder",
":subsequence_finder",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@tensorflow_includes//:includes",
"@tensorflow_solib//:framework_lib",
],
alwayslink = 1,
)
# Generates denylist_op.py from the :denylist_op kernel library.
gen_op_wrapper_py(
name = "denylist_op_py",
out = "denylist_op.py",
kernel_lib = ":denylist_op",
)
# Python 3 test (denylist_op_test.py) that depends on the generated
# :denylist_op_py wrapper.
py_test(
name = "denylist_op_py_test",
srcs = ["denylist_op_test.py"],
main = "denylist_op_test.py",
python_version = "PY3",
srcs_version = "PY3",
deps = [
":denylist_op_py",
],
)
# Library built from sequence_string_projection.cc. alwayslink = 1 keeps all
# of its object files in the final link.
# NOTE(review): presumably needed because the op registers itself through
# static initializers -- confirm against sequence_string_projection.cc.
cc_library(
name = "sequence_string_projection_op",
srcs = [
"sequence_string_projection.cc",
],
deps = [
":projection_normalizer_util",
":projection_tokenizer_util",
":projection_util",
":text_distorter",
"@com_google_absl//absl/container:flat_hash_map",
"@tensorflow_includes//:includes",
"@tensorflow_solib//:framework_lib",
],
alwayslink = 1,
)
cc_test( cc_test(
name = "sequence_string_projection_test", name = "sequence_string_projection_test",
size = "small", size = "small",
...@@ -78,6 +154,12 @@ cc_test( ...@@ -78,6 +154,12 @@ cc_test(
], ],
) )
# Generates sequence_string_projection_op.py from the
# :sequence_string_projection_op kernel library.
gen_op_wrapper_py(
name = "sequence_string_projection_op_py",
out = "sequence_string_projection_op.py",
kernel_lib = ":sequence_string_projection_op",
)
cc_library( cc_library(
name = "sequence_string_projection_op_v2", name = "sequence_string_projection_op_v2",
srcs = [ srcs = [
...@@ -111,12 +193,6 @@ gen_op_wrapper_py( ...@@ -111,12 +193,6 @@ gen_op_wrapper_py(
kernel_lib = ":sequence_string_projection_op_v2", kernel_lib = ":sequence_string_projection_op_v2",
) )
gen_op_wrapper_py(
name = "sequence_string_projection_op_py",
out = "sequence_string_projection_op.py",
kernel_lib = ":sequence_string_projection_op",
)
cc_library( cc_library(
name = "tf_custom_ops", name = "tf_custom_ops",
srcs = ["tf_custom_ops.cc"], srcs = ["tf_custom_ops.cc"],
......
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>
#include "absl/cleanup/cleanup.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tf_ops/skipgram_finder.h" // seq_flow_lite
#include "tf_ops/subsequence_finder.h" // seq_flow_lite
namespace seq_flow_lite {
using ::tensorflow::OpKernel;
using ::tensorflow::OpKernelConstruction;
using ::tensorflow::OpKernelContext;
using ::tensorflow::Status;
using ::tensorflow::Tensor;
using ::tensorflow::TensorShape;
using ::tensorflow::errors::InvalidArgument;
using ::tensorflow::shape_inference::InferenceContext;
using ::tensorflow::shape_inference::ShapeHandle;
// Description of the outputs and attributes for the Denylist ops.
// This text is appended to the .Doc() string of every Denylist op registered
// in this file, so each "name: ..." entry must match an output or attribute
// declared by those REGISTER_OP calls.
const char kDescription[] = R"(
output: A floating point tensor that contains a prediction vector for each
input string. The vector will either be:
* [1, 1, ..., 0, 0, ...] if no denylisted skipgrams are found.
(All negative categories are 1.0 and all positive categories are 0.0.)
* an indicator vector if any denylisted skipgrams are found.
(0.0 if no skipgrams belonging to the category were found and 1.0 otherwise)
max_skip_size: The maximum number of tokens that can be skipped when generating
skipgrams.
denylist: A string vector containing denylisted skipgrams.
denylist_category: An int32 vector containing the category of the corresponding
skipgram in the denylist.
categories: An int32 scalar. This is the total number of categories.
All categories in denylist_category must be in [0, categories).
negative_categories: An int32 scalar. The total number of categories that
should be set if no entries in the denylist are triggered. These
negative categories are assumed to be [0, negative_categories).
)";
// The base class for all Denylist ops. It does two things:
//   1) It reads and validates the attributes that specify the denylist and
//      that control how denylist categories are converted into output
//      vectors.
//   2) It defines Compute(), which allocates and fills the output tensor.
//      Subclasses are responsible for processing the input: they supply a
//      per-call context object and report the categories found for each
//      input string through the private virtual hooks below.
class DenylistOpBase : public OpKernel {
 public:
  // Reads and validates the op attributes. On any failure the error is
  // recorded on `context` and the constructor returns early.
  explicit DenylistOpBase(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("categories", &categories_));
    OP_REQUIRES_OK(context, context->GetAttr("negative_categories",
                                             &negative_categories_));

    OP_REQUIRES(context, categories_ > 0,
                InvalidArgument("Number of categories (", categories_,
                                ") must be positive."));
    OP_REQUIRES(
        context, negative_categories_ >= 0,
        InvalidArgument("Number of negative_categories (", negative_categories_,
                        ") must be non-negative."));
    OP_REQUIRES(context, negative_categories_ < categories_,
                InvalidArgument("Number of categories (", categories_,
                                ") must be greater than the "
                                "number of negative_categories (",
                                negative_categories_, ")."));

    OP_REQUIRES_OK(context, context->GetAttr("max_skip_size", &max_skip_size_));
    OP_REQUIRES_OK(context, context->GetAttr("denylist", &denylist_));
    OP_REQUIRES_OK(context,
                   context->GetAttr("denylist_category", &denylist_category_));

    OP_REQUIRES(context, denylist_.size() == denylist_category_.size(),
                InvalidArgument("denylist length (", denylist_.size(),
                                ") != denylist_category length (",
                                denylist_category_.size(), ")"));

    // For an empty range the min/max algorithms return the end iterator,
    // which must not be dereferenced. An empty denylist has no categories to
    // range-check, so the checks are skipped in that case.
    if (!denylist_category_.empty()) {
      const auto min_max = std::minmax_element(denylist_category_.begin(),
                                               denylist_category_.end());
      const int max = *min_max.second;
      OP_REQUIRES(context, max < categories_,
                  InvalidArgument("max element of denylist_category (", max,
                                  ") >= categories (", categories_, ")"));
      const int min = *min_max.first;
      OP_REQUIRES(
          context, min >= 0,
          InvalidArgument("min element of denylist_category (", min, ") < 0"));
    }
  }

  // Produces one float vector of length `categories` per input string.
  void Compute(OpKernelContext* context) override {
    auto compute_context = InitializeComputeContext(context);
    if (compute_context == nullptr) {
      // The subclass failed to build its context; nothing to compute.
      return;
    }
    // Guarantees the context is released on every exit path, including the
    // early return hidden inside OP_REQUIRES_OK below.
    auto context_cleaner = absl::MakeCleanup([this, compute_context] {
      this->FinalizeComputeContext(compute_context);
    });

    Tensor* output_tensor;
    TensorShape output_shape = InputStringsShape(compute_context);
    output_shape.AddDim(categories_);
    OP_REQUIRES_OK(context, context->allocate_output("output", output_shape,
                                                     &output_tensor));
    auto output_values = output_tensor->flat<float>();

    const int num_strings = NumInputStrings(compute_context);
    for (int i = 0; i < num_strings; i++) {
      const auto category = GetCategories(i, compute_context);
      // Widen before multiplying so the flat offset cannot overflow int.
      const int64_t base_index = static_cast<int64_t>(i) * categories_;
      if (category.empty()) {
        // Nothing matched: set exactly the negative categories.
        for (int j = 0; j < categories_; j++) {
          output_values(base_index + j) = j < negative_categories_ ? 1.0 : 0.0;
        }
      } else {
        // Something matched: indicator vector of the matched categories.
        for (int j = 0; j < categories_; j++) {
          output_values(base_index + j) = category.contains(j) ? 1.0 : 0.0;
        }
      }
    }
  }

 protected:
  // Read-only accessors for the denylist attributes.
  int max_skip_size() { return max_skip_size_; }
  int denylist_size() { return denylist_.size(); }
  const std::string& denylist(int i) { return denylist_[i]; }
  int32_t denylist_category(int i) { return denylist_category_[i]; }

 private:
  // Called at the beginning of Compute().  This function should process
  // the input and return a context object that can be used to identify
  // the denylist categories of each input string. Returning nullptr makes
  // Compute() return without producing output.
  virtual void* InitializeComputeContext(OpKernelContext* context) = 0;

  // Called at the end of Compute().  Frees the context object.
  virtual void FinalizeComputeContext(void* context) = 0;

  // Returns the shape of the input tensor, if it only consisted of strings.
  // If the input tensor is strings, this is the shape of the input tensor.
  // If the input tensor is tokens, this is the shape of the input tensor,
  // minus the innermost dimension.
  virtual TensorShape InputStringsShape(void* context) = 0;

  // Returns the number of strings in the input tensor.
  virtual int NumInputStrings(void* context) = 0;

  // Returns the denylist categories of the index-th string.
  virtual absl::flat_hash_set<int> GetCategories(int index, void* context) = 0;

  // Zero-initialized so the accessors never expose indeterminate values when
  // the constructor returns early on a validation failure.
  int32_t categories_ = 0;
  int32_t negative_categories_ = 0;
  int max_skip_size_ = 0;
  std::vector<std::string> denylist_;
  std::vector<int32_t> denylist_category_;
};
// A base class for Denylist ops that expect a string tensor input.
class StringDenylistOp : public DenylistOpBase {
public:
explicit StringDenylistOp(OpKernelConstruction* context)
: DenylistOpBase(context) {}
private:
void* InitializeComputeContext(OpKernelContext* context) override {
const Tensor* input_tensor;
auto status = context->input("input", &input_tensor);
if (!status.ok()) {
context->CtxFailureWithWarning(__FILE__, __LINE__, status);
return nullptr;
}
return new ComputeContext(input_tensor);
}
void FinalizeComputeContext(void* context) override {
delete static_cast<ComputeContext*>(context);
}
TensorShape InputStringsShape(void* context) override {
return static_cast<ComputeContext*>(context)->input_tensor->shape();
}
int NumInputStrings(void* context) override {
return static_cast<ComputeContext*>(context)->input_tensor_values.size();
}
absl::flat_hash_set<int> GetCategories(int index, void* context) override {
return FindTerms(
static_cast<ComputeContext*>(context)->input_tensor_values(index));
}
struct ComputeContext {
ComputeContext(const Tensor* input_tensor)
: input_tensor(input_tensor),
input_tensor_values(input_tensor->flat<::tensorflow::tstring>()) {}
const Tensor* input_tensor;
::tensorflow::TTypes<::tensorflow::tstring>::ConstFlat input_tensor_values;
};
// Returns the set of denylist categories for the input string.
virtual absl::flat_hash_set<int> FindTerms(const std::string& input) = 0;
};
// String-input denylist op backed by a SkipgramFinder. Every denylist entry
// is added to the finder once, at kernel construction time.
class SkipgramDenylistOp : public StringDenylistOp {
 public:
  explicit SkipgramDenylistOp(OpKernelConstruction* context)
      : StringDenylistOp(context),
        skipgram_finder_(std::make_unique<SkipgramFinder>(max_skip_size())) {
    for (int entry = 0; entry < denylist_size(); ++entry) {
      skipgram_finder_->AddSkipgram(denylist(entry), denylist_category(entry));
    }
  }

 private:
  // Delegates the lookup for one input string to the finder.
  absl::flat_hash_set<int> FindTerms(const std::string& input) override {
    return skipgram_finder_->FindSkipgrams(input);
  }

  std::unique_ptr<SkipgramFinder> skipgram_finder_;
};
// Registers SkipgramDenylistOp as the CPU kernel for the "SkipgramDenylist"
// op.
REGISTER_KERNEL_BUILDER(
Name("SkipgramDenylist").Device(::tensorflow::DEVICE_CPU),
SkipgramDenylistOp);
// Shape inference for Denylist ops with a string input: the output shape is
// the shape of input 0 with one extra trailing dimension of size
// `categories`.
Status StringDenylistShapeFn(InferenceContext* context) {
  int32_t num_categories;
  TF_RETURN_IF_ERROR(context->GetAttr("categories", &num_categories));

  const ShapeHandle category_dim = context->MakeShape({num_categories});
  ShapeHandle result;
  TF_RETURN_IF_ERROR(
      context->Concatenate(context->input(0), category_dim, &result));

  context->set_output(0, result);
  return Status::OK();
}
// Declares the "SkipgramDenylist" op: one string input, one float output,
// and the denylist attributes. The doc string is the op summary, the input
// description, and the shared kDescription text.
REGISTER_OP("SkipgramDenylist")
.Input("input: string")
.Output("output: float")
.Attr("max_skip_size: int")
.Attr("denylist: list(string)")
.Attr("denylist_category: list(int)")
.Attr("categories: int")
.Attr("negative_categories: int")
.SetShapeFn(StringDenylistShapeFn)
.Doc(absl::StrCat("Generates dense prediction vectors for input strings "
"using a skipgram denylist.",
"\n\n", "input: A string tensor.", "\n\n", kDescription));
// String-input denylist op backed by a SubsequenceFinder. Every denylist
// entry is added to the finder once, at kernel construction time.
class SubsequenceDenylistOp : public StringDenylistOp {
 public:
  explicit SubsequenceDenylistOp(OpKernelConstruction* context)
      : StringDenylistOp(context),
        subsequence_finder_(
            std::make_unique<SubsequenceFinder>(max_skip_size())) {
    for (int entry = 0; entry < denylist_size(); ++entry) {
      subsequence_finder_->AddSubsequence(denylist(entry),
                                          denylist_category(entry));
    }
  }

 private:
  // Delegates the lookup for one input string to the finder.
  absl::flat_hash_set<int> FindTerms(const std::string& input) override {
    return subsequence_finder_->FindSubsequences(input);
  }

  std::unique_ptr<SubsequenceFinder> subsequence_finder_;
};
// Registers SubsequenceDenylistOp as the CPU kernel for the
// "SubsequenceDenylist" op.
REGISTER_KERNEL_BUILDER(
Name("SubsequenceDenylist").Device(::tensorflow::DEVICE_CPU),
SubsequenceDenylistOp);
// Declares the "SubsequenceDenylist" op. Its signature and shape function
// are the same as those of "SkipgramDenylist"; only the summary line of the
// doc string differs.
REGISTER_OP("SubsequenceDenylist")
.Input("input: string")
.Output("output: float")
.Attr("max_skip_size: int")
.Attr("denylist: list(string)")
.Attr("denylist_category: list(int)")
.Attr("categories: int")
.Attr("negative_categories: int")
.SetShapeFn(StringDenylistShapeFn)
.Doc(absl::StrCat("Generates dense prediction vectors for inputs using a "
"subsequence denylist.",
"\n\n", "input: A string tensor.", "\n\n", kDescription));
// A denylist op that uses the SkipgramFinder on tokenized string inputs.
// The inputs are a pair of tensors: a token tensor of type string and
// a token count tensor of type T.
template <typename T>
class TokenizedDenylistOp : public DenylistOpBase {
public:
explicit TokenizedDenylistOp(OpKernelConstruction* context)
: DenylistOpBase(context) {
skipgram_finder_ = std::make_unique<SkipgramFinder>(max_skip_size());
for (int i = 0; i < denylist_size(); i++) {
skipgram_finder_->AddSkipgram(denylist(i), denylist_category(i));
}
}
private:
void* InitializeComputeContext(OpKernelContext* context) override {
const Tensor* input_tensor;
{
auto status = context->input("input", &input_tensor);
if (!status.ok()) {
context->CtxFailureWithWarning(__FILE__, __LINE__, status);
return nullptr;
}
}
const Tensor* token_count_tensor;
{
auto status = context->input("token_count", &token_count_tensor);
if (!status.ok()) {
context->CtxFailureWithWarning(__FILE__, __LINE__, status);
return nullptr;
}
}
return new ComputeContext(input_tensor, token_count_tensor);
}
void FinalizeComputeContext(void* context) override {
delete static_cast<ComputeContext*>(context);
}
TensorShape InputStringsShape(void* context) override {
return static_cast<ComputeContext*>(context)->shape;
}
int NumInputStrings(void* context) override {
return static_cast<ComputeContext*>(context)->size;
}
absl::flat_hash_set<int> GetCategories(int index, void* x) override {
ComputeContext* context = static_cast<ComputeContext*>(x);
int64_t num_tokens = context->token_count_flat(index);
std::vector<absl::string_view> tokens;
tokens.reserve(num_tokens);
int64_t start = index * context->max_tokens;
for (int64_t i = start; i < start + num_tokens; i++) {
tokens.emplace_back(context->token_flat(i).data(),
context->token_flat(i).size());
}
return skipgram_finder_->FindSkipgrams(tokens);
}
struct ComputeContext {
ComputeContext(const Tensor* token_tensor, const Tensor* token_count_tensor)
: token_flat(token_tensor->flat<::tensorflow::tstring>()),
token_count_flat(token_count_tensor->flat<T>()) {
shape = token_tensor->shape();
max_tokens = shape.dim_size(shape.dims() - 1);
shape.RemoveLastDims(1);
size = 1;
for (int64_t i = 0; i < shape.dims(); i++) {
size = size * shape.dim_size(i);
}
}
const typename ::tensorflow::TTypes<::tensorflow::tstring>::ConstFlat
token_flat;
const typename ::tensorflow::TTypes<T>::ConstFlat token_count_flat;
TensorShape shape;
int64_t size;
int64_t max_tokens;
};
std::unique_ptr<SkipgramFinder> skipgram_finder_;
};
// Registers the CPU kernels for "TokenizedDenylist", one per supported
// token_count type (int32 and int64), selected by the "Ttoken_count" attr.
REGISTER_KERNEL_BUILDER(Name("TokenizedDenylist")
.Device(::tensorflow::DEVICE_CPU)
.TypeConstraint<int32_t>("Ttoken_count"),
TokenizedDenylistOp<int32_t>);
REGISTER_KERNEL_BUILDER(Name("TokenizedDenylist")
.Device(::tensorflow::DEVICE_CPU)
.TypeConstraint<int64_t>("Ttoken_count"),
TokenizedDenylistOp<int64_t>);
// Shape inference for Denylist ops with tokenized string inputs: the output
// shape is the shape of input 0 with its last dimension replaced by a
// dimension of size `categories`.
Status TokenizedDenylistShapeFn(InferenceContext* context) {
  int32_t num_categories;
  TF_RETURN_IF_ERROR(context->GetAttr("categories", &num_categories));

  // Everything except the innermost (token) dimension of the input.
  ShapeHandle leading_dims;
  TF_RETURN_IF_ERROR(
      context->Subshape(context->input(0), 0, -1, &leading_dims));

  const ShapeHandle category_dim = context->MakeShape({num_categories});
  ShapeHandle result;
  TF_RETURN_IF_ERROR(
      context->Concatenate(leading_dims, category_dim, &result));

  context->set_output(0, result);
  return Status::OK();
}
// Declares the "TokenizedDenylist" op: a string tensor of tokens plus a
// per-string token count (int32 or int64), one float output, and the
// denylist attributes. The doc string describes both inputs and then appends
// the shared kDescription text.
REGISTER_OP("TokenizedDenylist")
    .Input("input: string")
    .Input("token_count: Ttoken_count")
    .Output("output: float")
    .Attr("max_skip_size: int")
    .Attr("denylist: list(string)")
    .Attr("denylist_category: list(int)")
    .Attr("categories: int")
    .Attr("negative_categories: int")
    .Attr("Ttoken_count: {int32, int64}")
    .SetShapeFn(TokenizedDenylistShapeFn)
    .Doc(absl::StrCat("Generates dense prediction vectors for tokens using a "
                      "skipgram denylist.",
                      "\n\n", "input: A string tensor of tokens.", "\n\n",
                      "token_count: An integer tensor containing the number "
                      "of tokens to use from each innermost row of input.",
                      "\n\n", kDescription));
} // namespace seq_flow_lite
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.proto.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace seq_flow_lite {
namespace {
using ::tensorflow::DT_FLOAT;
using ::tensorflow::DT_INT32;
using ::tensorflow::DT_INT64;
using ::tensorflow::DT_STRING;
using ::tensorflow::NodeDefBuilder;
using ::tensorflow::OpsTestBase;
using ::tensorflow::Tensor;
using ::tensorflow::TensorShape;
using ::tensorflow::errors::InvalidArgument;
using ::tensorflow::test::ExpectTensorEqual;
using ::tensorflow::test::FillValues;
// Fixture for the "SkipgramDenylist" op tests; adds nothing to OpsTestBase.
class SkipgramDenylistOpTest : public OpsTestBase {};
// Runs the op on two strings with denylist entry "a b c" (category 1) and
// max_skip_size 1. The first string is expected to be flagged with category
// 1; the second is expected to get the negative category 0.
TEST_F(SkipgramDenylistOpTest, Correct) {
  NodeDefBuilder builder("test_op", "SkipgramDenylist");
  builder.Input({"input", 0, DT_STRING})
      .Attr("max_skip_size", 1)
      .Attr("denylist", {"a b c"})
      .Attr("denylist_category", {1})
      .Attr("categories", 2)
      .Attr("negative_categories", 1);
  TF_ASSERT_OK(builder.Finalize(node_def()));
  TF_ASSERT_OK(InitOp());

  AddInputFromArray<::tensorflow::tstring>(TensorShape({2}),
                                           {"q a q b q c q", "q a b q q c"});
  TF_ASSERT_OK(RunOpKernel());

  Tensor want(allocator(), DT_FLOAT, TensorShape({2, 2}));
  FillValues<float>(&want, {0.0, 1.0, 1.0, 0.0});
  ExpectTensorEqual<float>(want, *GetOutput(0));
}
// Same as Correct, but the denylist entry is "a b.* c" and the inputs use the
// token "bq" in place of "b". The expected output is unchanged: the first
// string is flagged with category 1 and the second gets category 0.
TEST_F(SkipgramDenylistOpTest, Prefix) {
  NodeDefBuilder builder("test_op", "SkipgramDenylist");
  builder.Input({"input", 0, DT_STRING})
      .Attr("max_skip_size", 1)
      .Attr("denylist", {"a b.* c"})
      .Attr("denylist_category", {1})
      .Attr("categories", 2)
      .Attr("negative_categories", 1);
  TF_ASSERT_OK(builder.Finalize(node_def()));
  TF_ASSERT_OK(InitOp());

  AddInputFromArray<::tensorflow::tstring>(TensorShape({2}),
                                           {"q a q bq q c q", "q a bq q q c"});
  TF_ASSERT_OK(RunOpKernel());

  Tensor want(allocator(), DT_FLOAT, TensorShape({2, 2}));
  FillValues<float>(&want, {0.0, 1.0, 1.0, 0.0});
  ExpectTensorEqual<float>(want, *GetOutput(0));
}
// categories = 0 must be rejected when the kernel is constructed.
TEST_F(SkipgramDenylistOpTest, ZeroCategories) {
  NodeDefBuilder builder("test_op", "SkipgramDenylist");
  builder.Input({"input", 0, DT_STRING})
      .Attr("max_skip_size", 1)
      .Attr("denylist", {"a b c"})
      .Attr("denylist_category", {1})
      .Attr("categories", 0)
      .Attr("negative_categories", 0);
  TF_ASSERT_OK(builder.Finalize(node_def()));

  const auto want_error =
      InvalidArgument("Number of categories (0) must be positive.");
  EXPECT_EQ(InitOp(), want_error);
}
// negative_categories = -1 must be rejected when the kernel is constructed.
TEST_F(SkipgramDenylistOpTest, NegativeCategoriesLessThanZero) {
  NodeDefBuilder builder("test_op", "SkipgramDenylist");
  builder.Input({"input", 0, DT_STRING})
      .Attr("max_skip_size", 1)
      .Attr("denylist", {"a b c"})
      .Attr("denylist_category", {1})
      .Attr("categories", 1)
      .Attr("negative_categories", -1);
  TF_ASSERT_OK(builder.Finalize(node_def()));

  const auto want_error = InvalidArgument(
      "Number of negative_categories (-1) must be non-negative.");
  EXPECT_EQ(InitOp(), want_error);
}
// categories == negative_categories must be rejected when the kernel is
// constructed.
TEST_F(SkipgramDenylistOpTest, CategoriesEqualNegativeCategories) {
  NodeDefBuilder builder("test_op", "SkipgramDenylist");
  builder.Input({"input", 0, DT_STRING})
      .Attr("max_skip_size", 1)
      .Attr("denylist", {"a b c"})
      .Attr("denylist_category", {1})
      .Attr("categories", 1)
      .Attr("negative_categories", 1);
  TF_ASSERT_OK(builder.Finalize(node_def()));

  const auto want_error =
      InvalidArgument("Number of categories (1) must be greater than the "
                      "number of negative_categories (1).");
  EXPECT_EQ(InitOp(), want_error);
}
// Fixture for the "SubsequenceDenylist" op tests; adds nothing to
// OpsTestBase.
class SubsequenceDenylistOpTest : public OpsTestBase {};
// Runs the op on two unspaced strings with denylist entry "a b c"
// (category 1) and max_skip_size 1. The first string is expected to be
// flagged with category 1; the second is expected to get category 0.
TEST_F(SubsequenceDenylistOpTest, Correct) {
  NodeDefBuilder builder("test_op", "SubsequenceDenylist");
  builder.Input({"input", 0, DT_STRING})
      .Attr("max_skip_size", 1)
      .Attr("denylist", {"a b c"})
      .Attr("denylist_category", {1})
      .Attr("categories", 2)
      .Attr("negative_categories", 1);
  TF_ASSERT_OK(builder.Finalize(node_def()));
  TF_ASSERT_OK(InitOp());

  AddInputFromArray<::tensorflow::tstring>(TensorShape({2}),
                                           {"qaqbqcq", "qabqqc"});
  TF_ASSERT_OK(RunOpKernel());

  Tensor want(allocator(), DT_FLOAT, TensorShape({2, 2}));
  FillValues<float>(&want, {0.0, 1.0, 1.0, 0.0});
  ExpectTensorEqual<float>(want, *GetOutput(0));
}
// categories = 0 must be rejected when the kernel is constructed.
TEST_F(SubsequenceDenylistOpTest, ZeroCategories) {
  NodeDefBuilder builder("test_op", "SubsequenceDenylist");
  builder.Input({"input", 0, DT_STRING})
      .Attr("max_skip_size", 1)
      .Attr("denylist", {"a b c"})
      .Attr("denylist_category", {1})
      .Attr("categories", 0)
      .Attr("negative_categories", 0);
  TF_ASSERT_OK(builder.Finalize(node_def()));

  const auto want_error =
      InvalidArgument("Number of categories (0) must be positive.");
  EXPECT_EQ(InitOp(), want_error);
}
// negative_categories = -1 must be rejected when the kernel is constructed.
TEST_F(SubsequenceDenylistOpTest, NegativeCategoriesLessThanZero) {
  NodeDefBuilder builder("test_op", "SubsequenceDenylist");
  builder.Input({"input", 0, DT_STRING})
      .Attr("max_skip_size", 1)
      .Attr("denylist", {"a b c"})
      .Attr("denylist_category", {1})
      .Attr("categories", 1)
      .Attr("negative_categories", -1);
  TF_ASSERT_OK(builder.Finalize(node_def()));

  const auto want_error = InvalidArgument(
      "Number of negative_categories (-1) must be non-negative.");
  EXPECT_EQ(InitOp(), want_error);
}
// categories == negative_categories must be rejected when the kernel is
// constructed.
TEST_F(SubsequenceDenylistOpTest, CategoriesEqualNegativeCategories) {
  NodeDefBuilder builder("test_op", "SubsequenceDenylist");
  builder.Input({"input", 0, DT_STRING})
      .Attr("max_skip_size", 1)
      .Attr("denylist", {"a b c"})
      .Attr("denylist_category", {1})
      .Attr("categories", 1)
      .Attr("negative_categories", 1);
  TF_ASSERT_OK(builder.Finalize(node_def()));

  const auto want_error =
      InvalidArgument("Number of categories (1) must be greater than the "
                      "number of negative_categories (1).");
  EXPECT_EQ(InitOp(), want_error);
}
// Fixture for the "TokenizedDenylist" op tests; adds nothing to OpsTestBase.
class TokenizedDenylistOpTest : public OpsTestBase {};
// Runs the op on a [2, 7] token tensor with int64 token counts {7, 6}. The
// second row ends with an empty token that its count excludes. The first row
// is expected to be flagged with category 1; the second gets category 0.
TEST_F(TokenizedDenylistOpTest, CorrectInt64TokenCount) {
  NodeDefBuilder builder("test_op", "TokenizedDenylist");
  builder.Input({"input", 0, DT_STRING})
      .Input({"token_count", 0, DT_INT64})
      .Attr("max_skip_size", 1)
      .Attr("denylist", {"a b c"})
      .Attr("denylist_category", {1})
      .Attr("categories", 2)
      .Attr("negative_categories", 1);
  TF_ASSERT_OK(builder.Finalize(node_def()));
  TF_ASSERT_OK(InitOp());

  AddInputFromArray<::tensorflow::tstring>(
      TensorShape({2, 7}), {"q", "a", "q", "b", "q", "c", "q",  //
                            "q", "a", "b", "q", "q", "c", ""});
  AddInputFromArray<int64_t>(TensorShape({2}), {7, 6});
  TF_ASSERT_OK(RunOpKernel());

  Tensor want(allocator(), DT_FLOAT, TensorShape({2, 2}));
  FillValues<float>(&want, {0.0, 1.0, 1.0, 0.0});
  ExpectTensorEqual<float>(want, *GetOutput(0));
}
// Same scenario as CorrectInt64TokenCount, but the token counts are supplied
// as int32, which exercises the int32 kernel registration.
TEST_F(TokenizedDenylistOpTest, CorrectInt32TokenCount) {
  NodeDefBuilder builder("test_op", "TokenizedDenylist");
  builder.Input({"input", 0, DT_STRING})
      .Input({"token_count", 0, DT_INT32})
      .Attr("max_skip_size", 1)
      .Attr("denylist", {"a b c"})
      .Attr("denylist_category", {1})
      .Attr("categories", 2)
      .Attr("negative_categories", 1);
  TF_ASSERT_OK(builder.Finalize(node_def()));
  TF_ASSERT_OK(InitOp());

  AddInputFromArray<::tensorflow::tstring>(
      TensorShape({2, 7}), {"q", "a", "q", "b", "q", "c", "q",  //
                            "q", "a", "b", "q", "q", "c", ""});
  AddInputFromArray<int32_t>(TensorShape({2}), {7, 6});
  TF_ASSERT_OK(RunOpKernel());

  Tensor want(allocator(), DT_FLOAT, TensorShape({2, 2}));
  FillValues<float>(&want, {0.0, 1.0, 1.0, 0.0});
  ExpectTensorEqual<float>(want, *GetOutput(0));
}
// categories = 0 must be rejected when the kernel is constructed.
TEST_F(TokenizedDenylistOpTest, ZeroCategories) {
  NodeDefBuilder builder("test_op", "TokenizedDenylist");
  builder.Input({"input", 0, DT_STRING})
      .Input({"token_count", 0, DT_INT64})
      .Attr("max_skip_size", 1)
      .Attr("denylist", {"a b c"})
      .Attr("denylist_category", {1})
      .Attr("categories", 0)
      .Attr("negative_categories", 0);
  TF_ASSERT_OK(builder.Finalize(node_def()));

  const auto want_error =
      InvalidArgument("Number of categories (0) must be positive.");
  EXPECT_EQ(InitOp(), want_error);
}
// negative_categories = -1 must be rejected when the kernel is constructed.
TEST_F(TokenizedDenylistOpTest, NegativeCategoriesLessThanZero) {
  NodeDefBuilder builder("test_op", "TokenizedDenylist");
  builder.Input({"input", 0, DT_STRING})
      .Input({"token_count", 0, DT_INT64})
      .Attr("max_skip_size", 1)
      .Attr("denylist", {"a b c"})
      .Attr("denylist_category", {1})
      .Attr("categories", 1)
      .Attr("negative_categories", -1);
  TF_ASSERT_OK(builder.Finalize(node_def()));

  const auto want_error = InvalidArgument(
      "Number of negative_categories (-1) must be non-negative.");
  EXPECT_EQ(InitOp(), want_error);
}
// categories == negative_categories must be rejected when the kernel is
// constructed.
TEST_F(TokenizedDenylistOpTest, CategoriesEqualNegativeCategories) {
  NodeDefBuilder builder("test_op", "TokenizedDenylist");
  builder.Input({"input", 0, DT_STRING})
      .Input({"token_count", 0, DT_INT64})
      .Attr("max_skip_size", 1)
      .Attr("denylist", {"a b c"})
      .Attr("denylist_category", {1})
      .Attr("categories", 1)
      .Attr("negative_categories", 1);
  TF_ASSERT_OK(builder.Finalize(node_def()));

  const auto want_error =
      InvalidArgument("Number of categories (1) must be greater than the "
                      "number of negative_categories (1).");
  EXPECT_EQ(InitOp(), want_error);
}
} // namespace
} // namespace seq_flow_lite
# Copyright 2022 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test denylist op and show example usage from python wrapper."""
import tensorflow as tf
from tf_ops import denylist_op # import seq_flow_lite module
class SkipgramDenylistTest(tf.test.TestCase):
  """Checks the skipgram_denylist wrapper on space-separated strings."""

  def test_correct(self):
    """The first string is flagged as category 1, the second as category 0."""
    inputs = ["q a q b q c q", "q a b q q c"]
    expected = [[0.0, 1.0], [1.0, 0.0]]

    actual = denylist_op.skipgram_denylist(
        input=inputs,
        max_skip_size=1,
        denylist=["a b c"],
        denylist_category=[1],
        categories=2,
        negative_categories=1)

    self.assertAllEqual(actual, expected)
class SubsequenceDenylistTest(tf.test.TestCase):
  """Checks the subsequence_denylist wrapper on unspaced strings."""

  def test_correct(self):
    """The first string is flagged as category 1, the second as category 0."""
    inputs = ["qaqbqcq", "qabqqc"]
    expected = [[0.0, 1.0], [1.0, 0.0]]

    actual = denylist_op.subsequence_denylist(
        input=inputs,
        max_skip_size=1,
        denylist=["a b c"],
        denylist_category=[1],
        categories=2,
        negative_categories=1)

    self.assertAllEqual(actual, expected)
class TokenizedDenylistTest(tf.test.TestCase):
  """Checks the tokenized_denylist wrapper on pre-tokenized input."""

  def test_correct(self):
    """The first row is flagged as category 1, the second as category 0."""
    # The second row ends with an empty token that its count of 6 excludes.
    tokens = [["q", "a", "q", "b", "q", "c", "q"],
              ["q", "a", "b", "q", "q", "c", ""]]
    token_counts = [7, 6]
    expected = [[0.0, 1.0], [1.0, 0.0]]

    actual = denylist_op.tokenized_denylist(
        input=tokens,
        token_count=token_counts,
        max_skip_size=1,
        denylist=["a b c"],
        denylist_category=[1],
        categories=2,
        negative_categories=1)

    self.assertAllEqual(actual, expected)
# Run the test cases above when this file is executed as a script.
if __name__ == "__main__":
  tf.test.main()
...@@ -26,7 +26,7 @@ limitations under the License. ...@@ -26,7 +26,7 @@ limitations under the License.
bool IsDigit(const std::string& text) { bool IsDigit(const std::string& text) {
Rune rune; Rune rune;
for (size_t i = 0; i < text.length();) { for (size_t i = 0; i < text.length();) {
const int bytes_read = chartorune(&rune, const_cast<char *>(text.data())); const int bytes_read = chartorune(&rune, const_cast<char*>(text.data()));
if (rune == Runeerror || bytes_read == 0) break; if (rune == Runeerror || bytes_read == 0) break;
if (rune >= static_cast<Rune>('0') && rune <= static_cast<Rune>('9')) { if (rune >= static_cast<Rune>('0') && rune <= static_cast<Rune>('9')) {
return true; return true;
...@@ -98,6 +98,29 @@ std::string ContractToken(const char* input_ptr, size_t len, size_t num_chars) { ...@@ -98,6 +98,29 @@ std::string ContractToken(const char* input_ptr, size_t len, size_t num_chars) {
return token; return token;
} }
// Rewrites `input` in place so that it has no leading or trailing ' '
// characters and every run of consecutive ' ' characters is collapsed to a
// single one. Only the ASCII space character is treated as a space.
void NormalizeSpaces(std::string& input) {
  const size_t length = input.length();
  size_t write_pos = 0;
  // True when at least one space has been read since the last character that
  // was written. The space is only emitted once another non-space character
  // shows up, which is what drops leading and trailing spaces.
  bool pending_space = false;
  for (size_t read_pos = 0; read_pos < length; ++read_pos) {
    const char ch = input[read_pos];
    if (ch == ' ') {
      pending_space = write_pos > 0;
      continue;
    }
    if (pending_space) {
      // Safe to write here: the pending space was read but not yet written,
      // so write_pos is strictly behind read_pos.
      input[write_pos++] = ' ';
      pending_space = false;
    }
    input[write_pos++] = ch;
  }
  input.resize(write_pos);
}
void ProjectionNormalizer::InitializeSeparators(const std::string& separators) { void ProjectionNormalizer::InitializeSeparators(const std::string& separators) {
for (size_t i = 0; i < separators.length(); ++i) { for (size_t i = 0; i < separators.length(); ++i) {
if (separators[i] != ' ') { if (separators[i] != ' ') {
...@@ -150,9 +173,14 @@ std::string ProjectionNormalizer::Normalize(const char* input_ptr, size_t len, ...@@ -150,9 +173,14 @@ std::string ProjectionNormalizer::Normalize(const char* input_ptr, size_t len,
normalized = ContractToken(normalized.data(), normalized.length(), 3); normalized = ContractToken(normalized.data(), normalized.length(), 3);
} }
if (normalize_spaces_) {
NormalizeSpaces(normalized);
}
if (!separators_.empty()) { if (!separators_.empty()) {
// Add space around separators_. // Add space around separators_.
normalized = NormalizeInternal(normalized.data(), normalized.length()); normalized = NormalizeInternal(normalized.data(), normalized.length());
} }
return normalized; return normalized;
} }
...@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TF_OPS_PROJECTION_NORMALIZER_UTIL_H_ #ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_PROJECTION_NORMALIZER_UTIL_H_
#define TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TF_OPS_PROJECTION_NORMALIZER_UTIL_H_ #define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_PROJECTION_NORMALIZER_UTIL_H_
#include <string> #include <string>
#include <unordered_set> #include <unordered_set>
...@@ -24,14 +24,17 @@ limitations under the License. ...@@ -24,14 +24,17 @@ limitations under the License.
// Normalizes the input with the given |separators| by adding a space before and // Normalizes the input with the given |separators| by adding a space before and
// after each separator. When |normalize_repetition| is true, it removes the // after each separator. When |normalize_repetition| is true, it removes the
// repeated characters (except numbers) which consecutively appeared more than // repeated characters (except numbers) which consecutively appeared more than
// twice in a word. // twice in a word. When |normalize_spaces| is true, it removes spaces from
// the beginning and ending of the input, as well as repeated spaces.
// Examples: arwwwww -> arww, good!!!!! -> good!!, hahaha => haha. // Examples: arwwwww -> arww, good!!!!! -> good!!, hahaha => haha.
class ProjectionNormalizer { class ProjectionNormalizer {
public: public:
explicit ProjectionNormalizer(const std::string& separators, explicit ProjectionNormalizer(const std::string& separators,
bool normalize_repetition = false) { bool normalize_repetition = false,
bool normalize_spaces = false)
: normalize_repetition_(normalize_repetition),
normalize_spaces_(normalize_spaces) {
InitializeSeparators(separators); InitializeSeparators(separators);
normalize_repetition_ = normalize_repetition;
} }
// Normalizes the repeated characters (except numbers) which consecutively // Normalizes the repeated characters (except numbers) which consecutively
...@@ -49,6 +52,7 @@ class ProjectionNormalizer { ...@@ -49,6 +52,7 @@ class ProjectionNormalizer {
std::unordered_set<char> separators_; std::unordered_set<char> separators_;
bool normalize_repetition_; bool normalize_repetition_;
bool normalize_spaces_;
}; };
#endif // TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TF_OPS_PROJECTION_NORMALIZER_UTIL_H_ #endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_PROJECTION_NORMALIZER_UTIL_H_
...@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TF_OPS_PROJECTION_TOKENIZER_UTIL_H_ #ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_PROJECTION_TOKENIZER_UTIL_H_
#define TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TF_OPS_PROJECTION_TOKENIZER_UTIL_H_ #define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_PROJECTION_TOKENIZER_UTIL_H_
#include <string> #include <string>
#include <unordered_set> #include <unordered_set>
...@@ -55,4 +55,4 @@ class ProjectionTokenizer { ...@@ -55,4 +55,4 @@ class ProjectionTokenizer {
std::unordered_set<char> separators_; std::unordered_set<char> separators_;
}; };
#endif // TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TF_OPS_PROJECTION_TOKENIZER_UTIL_H_ #endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_PROJECTION_TOKENIZER_UTIL_H_
...@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TF_OPS_PROJECTION_UTIL_H_ #ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_PROJECTION_UTIL_H_
#define TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TF_OPS_PROJECTION_UTIL_H_ #define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_PROJECTION_UTIL_H_
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
...@@ -156,4 +156,4 @@ std::vector<std::string> SplitByChar(const char* input_ptr, size_t len, ...@@ -156,4 +156,4 @@ std::vector<std::string> SplitByChar(const char* input_ptr, size_t len,
std::string JoinPairsBySpace(std::vector<std::pair<const char*, size_t>> words); std::string JoinPairsBySpace(std::vector<std::pair<const char*, size_t>> words);
#endif // TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TF_OPS_PROJECTION_UTIL_H_ #endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_PROJECTION_UTIL_H_
...@@ -109,11 +109,14 @@ class SequenceStringProjectionOp : public OpKernel { ...@@ -109,11 +109,14 @@ class SequenceStringProjectionOp : public OpKernel {
bool normalize_repetition; bool normalize_repetition;
OP_REQUIRES_OK(context, context->GetAttr("normalize_repetition", OP_REQUIRES_OK(context, context->GetAttr("normalize_repetition",
&normalize_repetition)); &normalize_repetition));
bool normalize_spaces;
OP_REQUIRES_OK(context,
context->GetAttr("normalize_spaces", &normalize_spaces));
std::string separators; std::string separators;
OP_REQUIRES_OK(context, context->GetAttr("token_separators", &separators)); OP_REQUIRES_OK(context, context->GetAttr("token_separators", &separators));
if (!separators.empty() || normalize_repetition) { if (!separators.empty() || normalize_repetition || normalize_spaces) {
projection_normalizer_ = absl::make_unique<ProjectionNormalizer>( projection_normalizer_ = absl::make_unique<ProjectionNormalizer>(
separators, normalize_repetition); separators, normalize_repetition, normalize_spaces);
} }
OP_REQUIRES_OK(context, context->GetAttr("add_first_cap_feature", OP_REQUIRES_OK(context, context->GetAttr("add_first_cap_feature",
...@@ -326,6 +329,7 @@ REGISTER_OP("SequenceStringProjection") ...@@ -326,6 +329,7 @@ REGISTER_OP("SequenceStringProjection")
.Attr("split_on_space: bool = True") .Attr("split_on_space: bool = True")
.Attr("token_separators: string = ''") .Attr("token_separators: string = ''")
.Attr("normalize_repetition: bool = false") .Attr("normalize_repetition: bool = false")
.Attr("normalize_spaces: bool = false")
.SetShapeFn([](InferenceContext* c) { .SetShapeFn([](InferenceContext* c) {
DimensionHandle size; DimensionHandle size;
...@@ -384,6 +388,10 @@ Attribute(s): ...@@ -384,6 +388,10 @@ Attribute(s):
- add_all_caps_feature: Specifies the probability with which a feature to the - add_all_caps_feature: Specifies the probability with which a feature to the
resulting projection tensor that helps discriminate if the input token is resulting projection tensor that helps discriminate if the input token is
ALLCAPS will be added. ALLCAPS will be added.
- normalize_repetition: When true normalizes repetition in text tokens before
fingerprinting.
- normalize_spaces: When true strips leading and trailing spaces and removes
repeated spaces.
Output(s): Output(s):
- projection: Floating point tensor with ternary values of shape - projection: Floating point tensor with ternary values of shape
......
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tf_ops/skipgram_finder.h" // seq_flow_lite
#include <cctype>
#include <deque>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "icu4c/source/common/unicode/uchar.h"
#include "icu4c/source/common/unicode/utf8.h"
namespace seq_flow_lite {
namespace {
// Normalizes `token` in place so that skipgram tokens and input tokens can be
// compared directly:
//   * every code point for which u_ispunct() is true is removed;
//   * every other code point is replaced by u_tolower() of itself, unless the
//     lowercase form needs more UTF-8 bytes than the original.
// Processing stops at the first sequence for which U8_NEXT yields a negative
// code point. The bytes after the offset U8_NEXT leaves behind are kept
// unchanged.
// NOTE(review): U8_NEXT appears to advance past the offending byte(s) before
// the loop exits, so those bytes would be dropped from the result - confirm
// against the ICU documentation.
void PreprocessToken(std::string& token) {
  // The token is rewritten in place through `s`. `out` never overtakes `in`
  // because each code point is written with at most as many bytes as were
  // read for it.
  char* s = &token[0];
  const int32_t size = token.size();
  int32_t in = 0;
  int32_t out = 0;
  while (in < size) {
    UChar32 c;
    const int32_t old_in = in;
    U8_NEXT(s, in, size, c);
    if (c < 0) {
      break;
    }
    if (u_ispunct(c)) continue;
    UChar32 cl = u_tolower(c);
    // This is a hack, but there are exactly two unicode characters whose
    // lowercase versions have longer UTF-8 encodings (0x23a to 0x2c65,
    // 0x23e to 0x2c66). So, to avoid sizing issues, they're not lowercased.
    if (U8_LENGTH(cl) > (in - old_in)) {
      cl = c;
    }
    U8_APPEND_UNSAFE(s, out, cl);
  }
  // Bytes [out, in) are stale leftovers of the in-place rewrite. Erasing them
  // keeps the normalized prefix followed by any unprocessed tail, and avoids
  // a raw memmove() (which needs <cstring>, not included by this file).
  token.erase(out, in - out);
}
} // namespace
// Registers `skipgram` under `category`.
//
// The skipgram is split on single spaces and stored in a trie-like structure
// whose edges are labelled with whole tokens instead of characters. Each node
// stands for the skipgram spelled by the tokens on the path leading to it and
// holds the categories of the skipgrams that end there. A token written as
// "<prefix>.*" is stored as a prefix edge, any other token as an exact edge.
void SkipgramFinder::AddSkipgram(absl::string_view skipgram, int category) {
  std::vector<std::string> pieces = absl::StrSplit(skipgram, ' ');
  TrieNode* node = &skipgram_trie_;
  for (std::string& piece : pieces) {
    const bool is_prefix = absl::EndsWith(piece, ".*");
    if (is_prefix) {
      // Strip the ".*" marker before normalizing the remaining prefix.
      piece.resize(piece.size() - 2);
    }
    PreprocessToken(piece);
    auto& children = is_prefix ? node->prefix_to_node : node->token_to_node;
    // operator[] returns the existing child or inserts a default-constructed
    // one, i.e. the same find-or-insert step for both edge kinds.
    node = &children[piece];
  }
  node->categories.insert(category);
}
// Splits `input` on single spaces, normalizes every token with
// PreprocessToken() and returns the categories reported by the token-based
// overload for the normalized tokens.
absl::flat_hash_set<int> SkipgramFinder::FindSkipgrams(
    absl::string_view input) const {
  std::vector<std::string> owned_tokens = absl::StrSplit(input, ' ');
  for (std::string& owned : owned_tokens) {
    PreprocessToken(owned);
  }
  // The views point into `owned_tokens`, which outlives the call below.
  const std::vector<absl::string_view> views(owned_tokens.begin(),
                                             owned_tokens.end());
  return FindSkipgrams(views);
}
// Walks `tokens` once and returns the categories of every stored skipgram
// that can be matched while skipping at most `max_skip_size_` tokens between
// two consecutive skipgram tokens. The tokens are looked up exactly as given;
// no normalization is applied here.
absl::flat_hash_set<int> SkipgramFinder::FindSkipgrams(
    const std::vector<absl::string_view>& tokens) const {
  absl::flat_hash_set<int> found;
  // Partially matched skipgrams: (index of the last matched token, trie node
  // reached so far). Entries are appended in token order, so the oldest one
  // is always at the front.
  std::deque<std::pair<int, const TrieNode*>> active;

  // Records a successful trie step: collect the categories stored on the
  // reached node and keep the node as a candidate for later tokens.
  const auto record = [&found, &active](int position,
                                        const TrieNode& reached) {
    found.insert(reached.categories.begin(), reached.categories.end());
    active.push_back(std::make_pair(position, &reached));
  };

  for (int position = 0; position < tokens.size(); position++) {
    const absl::string_view& token = tokens[position];

    // Every non-empty prefix of the token, stepping with U8_FWD_1, shortest
    // first (the whole token is the last entry).
    std::vector<absl::string_view> prefixes;
    const char* bytes = token.data();
    const int32_t byte_count = token.size();
    for (int32_t end = 0; end < byte_count;) {
      const int32_t previous_end = end;
      U8_FWD_1(bytes, end, byte_count);
      if (end == previous_end) break;
      prefixes.emplace_back(bytes, end);
    }

    // Forget partial matches that would have to skip more than
    // `max_skip_size_` tokens to reach the current one.
    while (!active.empty() &&
           active.front().first + max_skip_size_ + 1 < position) {
      active.pop_front();
    }

    // Try to extend every surviving partial match with the current token,
    // and finally try to start a new match from the trie root. Only the
    // entries present before this token are visited; the ones appended by
    // `record` belong to the current token.
    const size_t live = active.size();
    for (size_t k = 0; k <= live; k++) {
      const TrieNode& from = (k == live) ? skipgram_trie_ : *active[k].second;

      const auto exact = from.token_to_node.find(token);
      if (exact != from.token_to_node.end()) {
        record(position, exact->second);
      }

      // Longest prefix first.
      for (auto prefix = prefixes.rbegin(); prefix != prefixes.rend();
           ++prefix) {
        const auto partial = from.prefix_to_node.find(*prefix);
        if (partial != from.prefix_to_node.end()) {
          record(position, partial->second);
        }
      }
    }
  }
  return found;
}
} // namespace seq_flow_lite
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_SKIPGRAM_FINDER_H_
#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_SKIPGRAM_FINDER_H_
#include <string>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
namespace seq_flow_lite {
// SkipgramFinder finds skipgrams in strings.
//
// To use: First, add skipgrams using AddSkipgram() - each skipgram is
// associated with some category. Then, call FindSkipgrams() on a string,
// which will return the set of categories of the skipgrams in the string.
//
// Both the skipgrams and the input strings will be tokenized by splitting
// on spaces. Additionally, the tokens will be lowercased and have their
// punctuation characters removed.
class SkipgramFinder {
 public:
  // `max_skip_size` is the largest number of input tokens that may be skipped
  // between two consecutive tokens of a skipgram.
  explicit SkipgramFinder(int max_skip_size) : max_skip_size_(max_skip_size) {}
  // Adds a skipgram that SkipgramFinder should look for in input strings.
  // Tokens may use the regex '.*' as a suffix.
  void AddSkipgram(absl::string_view skipgram, int category);
  // Find all of the skipgrams in `input`, and return their categories.
  // `input` is split on spaces and each token is normalized before matching.
  absl::flat_hash_set<int> FindSkipgrams(absl::string_view input) const;
  // Find all of the skipgrams in `tokens`, and return their categories.
  // The tokens are matched as given, without any normalization.
  absl::flat_hash_set<int> FindSkipgrams(
      const std::vector<absl::string_view>& tokens) const;

 private:
  // A node of the token-level trie. The path from the root to a node spells
  // a skipgram (or a prefix of one) token by token.
  struct TrieNode {
    // Categories of the skipgrams that end at this node.
    absl::flat_hash_set<int> categories;
    // Maps tokens to the next node in the trie.
    absl::flat_hash_map<std::string, TrieNode> token_to_node;
    // Maps token prefixes (<prefix>.*) to the next node in the trie.
    absl::flat_hash_map<std::string, TrieNode> prefix_to_node;
  };
  // Root of the trie; filled by AddSkipgram().
  TrieNode skipgram_trie_;
  // Maximum number of tokens that may be skipped between matched tokens.
  int max_skip_size_;
};
} // namespace seq_flow_lite
#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_SKIPGRAM_FINDER_H_
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tf_ops/skipgram_finder.h" // seq_flow_lite
#include <string>
#include <vector>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "icu4c/source/common/unicode/uchar.h"
#include "icu4c/source/common/unicode/utf8.h"
namespace seq_flow_lite {
namespace {
using ::testing::UnorderedElementsAreArray;
void TestFindSkipgrams(const SkipgramFinder& skipgram_finder,
const std::vector<std::string>& tokens,
const std::vector<int>& categories,
const std::vector<int>& token_categories) {
EXPECT_THAT(skipgram_finder.FindSkipgrams(absl::StrJoin(tokens, " ")),
UnorderedElementsAreArray(categories));
std::vector<absl::string_view> sv_tokens;
sv_tokens.reserve(tokens.size());
for (const auto& token : tokens) {
sv_tokens.emplace_back(token.data(), token.size());
}
EXPECT_THAT(skipgram_finder.FindSkipgrams(sv_tokens),
UnorderedElementsAreArray(token_categories));
}
// Verifies, for every code point below 0x10000 other than 0x23a and 0x23e,
// that u_tolower() never yields a code point with a longer UTF-8 encoding.
TEST(SkipgramFinderTest, UCharToLower) {
  for (UChar32 code_point = 0; code_point < 0x10000; code_point++) {
    const bool known_exception = code_point == 0x23a || code_point == 0x23e;
    if (known_exception) continue;
    UChar32 lowered = u_tolower(code_point);
    EXPECT_GE(U8_LENGTH(code_point), U8_LENGTH(lowered))
        << code_point << " lowercases to " << lowered;
  }
}
// A single exact skipgram is found when at most one token sits between its
// tokens. Inputs with uppercase letters or punctuation only match through
// the string overload, which normalizes tokens.
TEST(SkipgramFinderTest, SingleExists) {
  SkipgramFinder finder(1);
  finder.AddSkipgram("q r s", 0);

  TestFindSkipgrams(finder, {"a", "q", "r", "s", "c"}, {0}, {0});
  TestFindSkipgrams(finder, {"a", "q", "xyz", "R!", "xy", "s", "c"}, {0}, {});
  TestFindSkipgrams(finder, {"a", "q", "r", "q", "R", "s.", "c"}, {0}, {});
}
// A single exact skipgram is not reported when the gap is too large, when a
// token is missing, or when the tokens appear in the wrong order.
TEST(SkipgramFinderTest, SingleNotExists) {
  SkipgramFinder finder(1);
  finder.AddSkipgram("q r s", 0);

  TestFindSkipgrams(finder, {"a", "q", "x", "x", "r", "x", "s", "c"}, {}, {});
  TestFindSkipgrams(finder, {"a", "q", "x", "r", "x", "c"}, {}, {});
  TestFindSkipgrams(finder, {"a", "r", "x", "s", "q", "c"}, {}, {});
}
// A skipgram whose first token is the prefix pattern "q.*" matches input
// tokens that start with "q".
TEST(SkipgramFinderTest, SinglePrefixExists) {
  SkipgramFinder finder(1);
  finder.AddSkipgram("q.* r s", 0);

  TestFindSkipgrams(finder, {"a", "qa", "r", "s", "c"}, {0}, {0});
  TestFindSkipgrams(finder, {"a", "q", "xyz", "R!", "xy", "s", "c"}, {0}, {});
  TestFindSkipgrams(finder, {"a", "qc", "r", "qd", "R", "s.", "c"}, {0}, {});
}
// The prefix pattern "q.*" does not match tokens that merely contain "q",
// and the remaining exact tokens still have to match in full.
TEST(SkipgramFinderTest, SinglePrefixNotExists) {
  SkipgramFinder finder(1);
  finder.AddSkipgram("q.* r s", 0);

  TestFindSkipgrams(finder, {"a", "aq", "r", "s", "c"}, {}, {});
  TestFindSkipgrams(finder, {"a", "aqc", "xyz", "R!", "xy", "s", "c"}, {}, {});
  TestFindSkipgrams(finder, {"a", "q", "ar", "q", "aR", "s.", "c"}, {}, {});
}
// Punctuation inside skipgram tokens is removed when the skipgram is added,
// and punctuation inside input tokens is removed by the string overload only.
TEST(SkipgramFinderTest, Punctuation) {
  SkipgramFinder finder(1);
  finder.AddSkipgram("a-b-c def", 0);

  TestFindSkipgrams(finder, {"q", "abc", "q", "d-e-f", "q"}, {0}, {});
  TestFindSkipgrams(finder, {"a", "'abc'", "q", "'def'", "q"}, {0}, {});
  TestFindSkipgrams(finder, {"q", "abc", "q", "def", "q"}, {0}, {0});
}
// Adds a skipgram containing a four-byte sequence followed by punctuation.
// The test has no expectations; it only requires AddSkipgram() to complete.
TEST(SkipgramFinderTest, HandlesMultibyteInput) {
  SkipgramFinder finder(1);
  finder.AddSkipgram("hello\363\243\243\243!", 0);
}
// Several skipgrams, including mixed case, punctuation and a prefix pattern,
// registered under distinct categories and matched against interleaved
// inputs.
TEST(SkipgramFinderTest, Multiple) {
  SkipgramFinder finder(1);
  finder.AddSkipgram("a b c", 0);
  finder.AddSkipgram("D e. F!", 2);
  finder.AddSkipgram("ghi jkl mno", 4);
  finder.AddSkipgram("S T U", 6);
  finder.AddSkipgram("x. y, z!", 8);
  finder.AddSkipgram("d.* e f", 10);

  TestFindSkipgrams(finder, {"a", "d", "b", "e", "c", "f"}, {0, 2, 10},
                    {0, 2, 10});
  TestFindSkipgrams(finder, {"a", "dq", "b", "e", "c", "f"}, {0, 10}, {0, 10});
  TestFindSkipgrams(finder, {"a", "d", "b", "eq", "c", "f"}, {0}, {0});
  TestFindSkipgrams(finder, {"a", "ghi", "b", "jkl", "c", "x", "mno"}, {0},
                    {0});
  TestFindSkipgrams(finder, {"ghi", "d", "jkl", "e", "mno", "f"}, {2, 4, 10},
                    {2, 4, 10});
  TestFindSkipgrams(finder, {"s", "x", "t", "y", "u", "z"}, {6, 8}, {6, 8});
}
// Lowercasing a character whose lowercase form has a shorter UTF-8 encoding.
TEST(SkipgramFinderTest, UnicodeLowercase) {
  // Precondition: the lowercase form really is shorter than the uppercase.
  UChar32 upper;
  U8_GET_UNSAFE("Ɦ", 0, upper);
  UChar32 lower = u_tolower(upper);
  EXPECT_GT(U8_LENGTH(upper), U8_LENGTH(lower));

  SkipgramFinder finder(1);
  finder.AddSkipgram("Ɦ", 0);

  // The uppercase input only matches through the normalizing string overload.
  TestFindSkipgrams(finder, {"Ɦ"}, {0}, {});
  TestFindSkipgrams(finder, {"ɦ"}, {0}, {0});
  TestFindSkipgrams(finder, {"h"}, {}, {});
}
} // namespace
} // namespace seq_flow_lite
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tf_ops/subsequence_finder.h" // seq_flow_lite
#include <deque>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "icu4c/source/common/unicode/uchar.h"
#include "icu4c/source/common/unicode/utf8.h"
namespace seq_flow_lite {
// Registers `subsequence` under `category`.
//
// The subsequence is decoded code point by code point and lowercased. A
// space starts a new token: the first code point of a token is stored on a
// `next_token` edge, every following one on a `continue_token` edge. If an
// invalid sequence is encountered the function returns without recording the
// category.
void SubsequenceFinder::AddSubsequence(absl::string_view subsequence,
                                       int category) {
  const char* bytes = subsequence.data();
  const int32_t byte_count = subsequence.length();
  TrieNode* node = &subsequence_trie_;
  bool at_token_start = true;
  for (int32_t offset = 0; offset < byte_count;) {
    UChar32 code_point;
    U8_NEXT(bytes, offset, byte_count, code_point);
    if (code_point < 0) return;
    code_point = u_tolower(code_point);
    if (code_point == ' ') {
      at_token_start = true;
      continue;
    }
    auto& edges = at_token_start ? node->next_token : node->continue_token;
    node = &edges[code_point];
    at_token_start = false;
  }
  node->categories.insert(category);
}
// Tries to follow the edge labelled `c` in `token_map`, the outgoing edges of
// a trie node that represents an in-progress subsequence. When the edge
// exists, the categories of the reached node are added to `categories`, and
// the node is queued in `continue_tokens` and/or `next_tokens` (tagged with
// `index`) depending on which kinds of outgoing edges it has.
void SubsequenceFinder::ProcessUChar32AndTrieNode(
    int index, UChar32 c,
    const absl::flat_hash_map<UChar32, TrieNode>& token_map,
    absl::flat_hash_set<int>* categories,
    std::deque<std::pair<int, const TrieNode*>>* next_tokens,
    std::vector<const TrieNode*>* continue_tokens) const {
  const auto edge = token_map.find(c);
  if (edge == token_map.end()) {
    return;
  }
  const TrieNode& reached = edge->second;
  categories->insert(reached.categories.begin(), reached.categories.end());
  if (!reached.continue_token.empty()) {
    continue_tokens->push_back(&reached);
  }
  if (!reached.next_token.empty()) {
    next_tokens->emplace_back(index, &reached);
  }
}
// Scans `input` code point by code point (lowercased) and returns the
// categories of every stored subsequence found in it. Scanning stops at the
// first invalid sequence and returns what was found up to that point.
absl::flat_hash_set<int> SubsequenceFinder::FindSubsequences(
    absl::string_view input) const {
  absl::flat_hash_set<int> found;
  // Subsequences that have finished a token and are waiting for the first
  // character of their next token, paired with the index of the last
  // character they matched.
  std::deque<std::pair<int, const TrieNode*>> awaiting_next_token;
  // Subsequences that are in the middle of a token. `mid_token` holds the
  // ones to try against the current character, `mid_token_upcoming` collects
  // the ones to try against the following character.
  std::vector<const TrieNode*> mid_token;
  std::vector<const TrieNode*> mid_token_upcoming;

  const char* bytes = input.data();
  const int32_t byte_count = input.length();
  int32_t offset = 0;
  for (int position = 0; offset < byte_count; position++) {
    UChar32 code_point;
    U8_NEXT(bytes, offset, byte_count, code_point);
    if (code_point < 0) return found;
    code_point = u_tolower(code_point);

    // Forget subsequences that would have to skip more than `max_skip_size_`
    // characters between their last token and the current character.
    while (!awaiting_next_token.empty() &&
           awaiting_next_token.front().first + max_skip_size_ + 1 <
               position) {
      awaiting_next_token.pop_front();
    }

    // Subsequences starting a new token. Only the entries present before
    // this character are visited; new ones may be appended while iterating.
    const size_t pending = awaiting_next_token.size();
    for (size_t i = 0; i < pending; i++) {
      ProcessUChar32AndTrieNode(position, code_point,
                                awaiting_next_token[i].second->next_token,
                                &found, &awaiting_next_token,
                                &mid_token_upcoming);
    }

    // Subsequences continuing their current token.
    for (const TrieNode* node : mid_token) {
      ProcessUChar32AndTrieNode(position, code_point, node->continue_token,
                                &found, &awaiting_next_token,
                                &mid_token_upcoming);
    }

    // A brand-new subsequence starting at this character.
    ProcessUChar32AndTrieNode(position, code_point,
                              subsequence_trie_.next_token, &found,
                              &awaiting_next_token, &mid_token_upcoming);

    mid_token.swap(mid_token_upcoming);
    mid_token_upcoming.clear();
  }
  return found;
}
} // namespace seq_flow_lite
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_SUBSEQUENCE_FINDER_H_
#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_SUBSEQUENCE_FINDER_H_
#include <deque>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "icu4c/source/common/unicode/uchar.h"
namespace seq_flow_lite {
// SubsequenceFinder finds subsequences in UTF-8 strings.
//
// Specifically, given a subsequence t_1 t_2 ... t_n, we will check if a
// string matches '.*t_1.{0,N}t_2.{0,N} ... .{0,N}t_n.*', where N is the
// maximum skip size.
//
// To use: First, add subsequences using AddSubsequence() - each subsequence
// is associated with some category. Then call FindSubsequences() on a string,
// which will return the set of categories of the subsequences in the string.
//
// The subsequences will be tokenized by splitting on spaces. Both subsequences
// and input strings will be normalized by lowercasing.
class SubsequenceFinder {
 public:
  // `max_skip_size` is the largest number of input characters that may be
  // skipped between two consecutive tokens of a subsequence.
  explicit SubsequenceFinder(int max_skip_size)
      : max_skip_size_(max_skip_size) {}
  // Adds a subsequence that SubsequenceFinder should look for in input strings.
  void AddSubsequence(absl::string_view subsequence, int category);
  // Find all of the subsequences in `input`, and return their categories.
  absl::flat_hash_set<int> FindSubsequences(absl::string_view input) const;

 private:
  // This trie tracks the next character needed to:
  // * continue the current token
  // * start the next token
  struct TrieNode {
    // Categories of the subsequences that end at this node.
    absl::flat_hash_set<int> categories;
    // Edges to follow for the next character of the current token.
    absl::flat_hash_map<UChar32, TrieNode> continue_token;
    // Edges to follow for the first character of the next token.
    absl::flat_hash_map<UChar32, TrieNode> next_token;
  };
  // Looks up `c` in `token_map`; on a hit, adds the reached node's categories
  // to `categories` and queues the node in `next_tokens` (tagged with
  // `index`) and/or `continue_tokens` when it has edges of the matching kind.
  void ProcessUChar32AndTrieNode(
      int index, UChar32 c,
      const absl::flat_hash_map<UChar32, TrieNode>& token_map,
      absl::flat_hash_set<int>* categories,
      std::deque<std::pair<int, const TrieNode*>>* next_tokens,
      std::vector<const TrieNode*>* continue_tokens) const;
  // Root of the trie; filled by AddSubsequence().
  TrieNode subsequence_trie_;
  // Maximum number of characters that may be skipped between tokens.
  int max_skip_size_;
};
} // namespace seq_flow_lite
#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_SUBSEQUENCE_FINDER_H_
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tf_ops/subsequence_finder.h" // seq_flow_lite
#include <gmock/gmock.h>
#include <gtest/gtest.h>
namespace seq_flow_lite {
namespace {
using ::testing::UnorderedElementsAre;
// A single registered subsequence is reported for inputs in which its tokens
// are adjacent, separated by three other characters, or written in upper case.
TEST(SubsequenceFinderTest, SingleExists) {
  SubsequenceFinder finder(/*max_skip_size=*/3);
  finder.AddSubsequence("ab cd", /*category=*/0);

  // The two tokens are directly adjacent.
  EXPECT_THAT(finder.FindSubsequences("abcd"), UnorderedElementsAre(0));
  // Exactly three characters sit between the two tokens.
  EXPECT_THAT(finder.FindSubsequences("ab012cd"), UnorderedElementsAre(0));
  // Upper-case input with a space between the tokens.
  EXPECT_THAT(finder.FindSubsequences("AB CD"), UnorderedElementsAre(0));
}
// A single registered subsequence is not reported when a token is broken up,
// when too many characters separate the tokens, or when a token's characters
// are out of order.
TEST(SubsequenceFinderTest, SingleNotExists) {
  SubsequenceFinder finder(/*max_skip_size=*/3);
  finder.AddSubsequence("ab cd", /*category=*/0);

  // The first token "ab" is interrupted by a space.
  EXPECT_THAT(finder.FindSubsequences("a bcd"), UnorderedElementsAre());
  // Four characters sit between the tokens, one more than max_skip_size.
  EXPECT_THAT(finder.FindSubsequences("ab0123cd"), UnorderedElementsAre());
  // The second token appears as "dc" rather than "cd".
  EXPECT_THAT(finder.FindSubsequences("abdc"), UnorderedElementsAre());
}
// With several subsequences registered, every one present in the input is
// reported, including overlapping and interleaved ones.
TEST(SubsequenceFinderTest, Multiple) {
  SubsequenceFinder finder(/*max_skip_size=*/3);
  finder.AddSubsequence("a b c d", /*category=*/0);
  finder.AddSubsequence("q r s", /*category=*/2);
  finder.AddSubsequence("b c d e", /*category=*/4);

  // "a b c d" and "b c d e" overlap in the same input.
  EXPECT_THAT(finder.FindSubsequences("a__b__c__d__e"),
              UnorderedElementsAre(0, 4));
  // "a b c d" and "q r s" are interleaved character by character.
  EXPECT_THAT(finder.FindSubsequences("aqbrcsd"),
              UnorderedElementsAre(0, 2));
  // "b c d e" and "q r s" are interleaved with spaces; there is no "a".
  EXPECT_THAT(finder.FindSubsequences("b q c r d s e"),
              UnorderedElementsAre(2, 4));
}
// Matching also works on multi-byte UTF-8 input.
TEST(SubsequenceFinderTest, Utf8) {
  SubsequenceFinder finder(/*max_skip_size=*/3);
  finder.AddSubsequence("一二 三四 五六", /*category=*/0);

  // Three characters sit between each pair of consecutive tokens.
  EXPECT_THAT(finder.FindSubsequences("一二おはよ三四こんに五六"),
              UnorderedElementsAre(0));
  // The middle token "三四" is split by a space.
  EXPECT_THAT(finder.FindSubsequences("一二三 四五六"),
              UnorderedElementsAre());
}
} // namespace
} // namespace seq_flow_lite
...@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TF_OPS_TEXT_DISTORTER_H_ #ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_TEXT_DISTORTER_H_
#define TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TF_OPS_TEXT_DISTORTER_H_ #define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_TEXT_DISTORTER_H_
#include <assert.h> #include <assert.h>
...@@ -40,4 +40,4 @@ class TextDistorter { ...@@ -40,4 +40,4 @@ class TextDistorter {
UChar32 random_char_ = 0; UChar32 random_char_ = 0;
}; };
#endif // TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TF_OPS_TEXT_DISTORTER_H_ #endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_TEXT_DISTORTER_H_
...@@ -122,4 +122,4 @@ REGISTER_OP("UniformCausalAttn") ...@@ -122,4 +122,4 @@ REGISTER_OP("UniformCausalAttn")
}) })
.Doc(R"doc( .Doc(R"doc(
Dummy uniform causal attn op. Dummy uniform causal attn op.
)doc"; )doc");
...@@ -121,9 +121,9 @@ cc_library( ...@@ -121,9 +121,9 @@ cc_library(
hdrs = ["tflite_qrnn_pooling.h"], hdrs = ["tflite_qrnn_pooling.h"],
copts = tflite_copts(), copts = tflite_copts(),
deps = [ deps = [
"//third_party/absl/base:core_headers", "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
"//third_party/tensorflow/lite/kernels:builtin_ops", "//tflite_ops:quantization_util", # sequence projection
"//third_party/tensorflow_models/seq_flow_lite/tflite_ops:quantization_util", "@com_google_absl//absl/base:core_headers",
], ],
alwayslink = 1, alwayslink = 1,
) )
...@@ -132,7 +132,7 @@ cc_library( ...@@ -132,7 +132,7 @@ cc_library(
name = "tflite_decoder_cache", name = "tflite_decoder_cache",
hdrs = ["tflite_decoder_cache.h"], hdrs = ["tflite_decoder_cache.h"],
deps = [ deps = [
"//third_party/tensorflow/lite/c:common", "@org_tensorflow//tensorflow/lite/c:common",
], ],
alwayslink = 1, alwayslink = 1,
) )
...@@ -144,12 +144,12 @@ cc_library( ...@@ -144,12 +144,12 @@ cc_library(
copts = tflite_copts(), copts = tflite_copts(),
deps = [ deps = [
":tflite_decoder_cache", ":tflite_decoder_cache",
"//third_party/flatbuffers", "@org_tensorflow//tensorflow/lite/c:common",
"//third_party/tensorflow/lite/c:common", "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
"//third_party/tensorflow/lite/kernels:builtin_ops", "@org_tensorflow//tensorflow/lite/kernels:kernel_util",
"//third_party/tensorflow/lite/kernels:kernel_util", "@org_tensorflow//tensorflow/lite/kernels/internal:tensor",
"//third_party/tensorflow/lite/kernels/internal:tensor", "//tflite_ops:quantization_util", # sequence projection
"//third_party/tensorflow_models/seq_flow_lite/tflite_ops:quantization_util", "@flatbuffers",
], ],
alwayslink = 1, alwayslink = 1,
) )
...@@ -160,11 +160,11 @@ cc_test( ...@@ -160,11 +160,11 @@ cc_test(
srcs = ["tflite_decoder_handler_test.cc"], srcs = ["tflite_decoder_handler_test.cc"],
deps = [ deps = [
":tflite_decoder_handler", ":tflite_decoder_handler",
"//testing/base/public:gunit", "@org_tensorflow//tensorflow/lite:framework",
"//third_party/flatbuffers", "@org_tensorflow//tensorflow/lite/c:common",
"//third_party/tensorflow/lite:framework", "@org_tensorflow//tensorflow/lite/kernels:test_util",
"//third_party/tensorflow/lite/c:common", "@com_google_googletest//:gtest",
"//third_party/tensorflow/lite/kernels:test_util", "@flatbuffers",
], ],
) )
...@@ -176,10 +176,10 @@ cc_library( ...@@ -176,10 +176,10 @@ cc_library(
deps = [ deps = [
"//base", "//base",
"//third_party/absl/strings", "//third_party/absl/strings",
"//third_party/tensorflow/lite/c:common", "@org_tensorflow//tensorflow/lite/c:common",
"//third_party/tensorflow/lite/kernels/internal:tensor", "@org_tensorflow//tensorflow/lite/kernels/internal:tensor",
"//third_party/tensorflow/lite/kernels/internal:types", "@org_tensorflow//tensorflow/lite/kernels/internal:types",
"//third_party/tensorflow_models/seq_flow_lite/tflite_ops:quantization_util", "//tflite_ops:quantization_util", # sequence projection
], ],
) )
...@@ -189,14 +189,14 @@ cc_test( ...@@ -189,14 +189,14 @@ cc_test(
copts = tflite_copts(), copts = tflite_copts(),
deps = [ deps = [
":beam_search", ":beam_search",
"//testing/base/public:gunit_main",
"//third_party/absl/strings", "//third_party/absl/strings",
"//third_party/tensorflow/lite/c:c_api_types", "@org_tensorflow//tensorflow/lite/c:c_api_types",
"//third_party/tensorflow/lite/c:common", "@org_tensorflow//tensorflow/lite/c:common",
"//third_party/tensorflow/lite/kernels/internal:legacy_reference_base", "@org_tensorflow//tensorflow/lite/kernels/internal:legacy_reference_base",
"//third_party/tensorflow/lite/kernels/internal:optimized_base", "@org_tensorflow//tensorflow/lite/kernels/internal:optimized_base",
"//third_party/tensorflow/lite/kernels/internal:tensor", "@org_tensorflow//tensorflow/lite/kernels/internal:tensor",
"//third_party/tensorflow/lite/kernels/internal:types", "@org_tensorflow//tensorflow/lite/kernels/internal:types",
"//third_party/tensorflow_models/seq_flow_lite/tflite_ops:quantization_util", "//tflite_ops:quantization_util", # sequence projection
"@com_google_googletest//:gtest_main",
], ],
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment