Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
20cc2190
Unverified
Commit
20cc2190
authored
Aug 24, 2022
by
pyoung2778
Committed by
GitHub
Aug 24, 2022
Browse files
Check in seq_flow_lite (#10750)
parent
fdecf385
Changes
62
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1698 additions
and
80 deletions
+1698
-80
research/seq_flow_lite/models/transformer_encoder.py
research/seq_flow_lite/models/transformer_encoder.py
+2
-2
research/seq_flow_lite/models/transformer_uniform_attn_decoder.py
.../seq_flow_lite/models/transformer_uniform_attn_decoder.py
+6
-6
research/seq_flow_lite/tf_ops/BUILD
research/seq_flow_lite/tf_ops/BUILD
+102
-26
research/seq_flow_lite/tf_ops/denylist_op.cc
research/seq_flow_lite/tf_ops/denylist_op.cc
+438
-0
research/seq_flow_lite/tf_ops/denylist_op_test.cc
research/seq_flow_lite/tf_ops/denylist_op_test.cc
+292
-0
research/seq_flow_lite/tf_ops/denylist_op_test.py
research/seq_flow_lite/tf_ops/denylist_op_test.py
+63
-0
research/seq_flow_lite/tf_ops/projection_normalizer_util.cc
research/seq_flow_lite/tf_ops/projection_normalizer_util.cc
+29
-1
research/seq_flow_lite/tf_ops/projection_normalizer_util.h
research/seq_flow_lite/tf_ops/projection_normalizer_util.h
+10
-6
research/seq_flow_lite/tf_ops/projection_tokenizer_util.h
research/seq_flow_lite/tf_ops/projection_tokenizer_util.h
+3
-3
research/seq_flow_lite/tf_ops/projection_util.h
research/seq_flow_lite/tf_ops/projection_util.h
+3
-3
research/seq_flow_lite/tf_ops/sequence_string_projection.cc
research/seq_flow_lite/tf_ops/sequence_string_projection.cc
+10
-2
research/seq_flow_lite/tf_ops/skipgram_finder.cc
research/seq_flow_lite/tf_ops/skipgram_finder.cc
+183
-0
research/seq_flow_lite/tf_ops/skipgram_finder.h
research/seq_flow_lite/tf_ops/skipgram_finder.h
+66
-0
research/seq_flow_lite/tf_ops/skipgram_finder_test.cc
research/seq_flow_lite/tf_ops/skipgram_finder_test.cc
+160
-0
research/seq_flow_lite/tf_ops/subsequence_finder.cc
research/seq_flow_lite/tf_ops/subsequence_finder.cc
+143
-0
research/seq_flow_lite/tf_ops/subsequence_finder.h
research/seq_flow_lite/tf_ops/subsequence_finder.h
+76
-0
research/seq_flow_lite/tf_ops/subsequence_finder_test.cc
research/seq_flow_lite/tf_ops/subsequence_finder_test.cc
+81
-0
research/seq_flow_lite/tf_ops/text_distorter.h
research/seq_flow_lite/tf_ops/text_distorter.h
+3
-3
research/seq_flow_lite/tf_ops/tf_custom_ops.cc
research/seq_flow_lite/tf_ops/tf_custom_ops.cc
+1
-1
research/seq_flow_lite/tflite_ops/BUILD
research/seq_flow_lite/tflite_ops/BUILD
+27
-27
No files found.
research/seq_flow_lite/models/transformer_encoder.py
View file @
20cc2190
...
@@ -18,8 +18,8 @@
...
@@ -18,8 +18,8 @@
from
absl
import
logging
from
absl
import
logging
import
tensorflow
as
tf
import
tensorflow
as
tf
from
layers
import
base_layers
from
layers
import
base_layers
# import seq_flow_lite module
from
layers
import
transformer_layers
from
layers
import
transformer_layers
# import seq_flow_lite module
class
Model
(
tf
.
keras
.
layers
.
Layer
):
class
Model
(
tf
.
keras
.
layers
.
Layer
):
...
...
research/seq_flow_lite/models/transformer_uniform_attn_decoder.py
View file @
20cc2190
...
@@ -20,12 +20,12 @@ from absl import logging
...
@@ -20,12 +20,12 @@ from absl import logging
from
tensor2tensor.utils
import
beam_search
from
tensor2tensor.utils
import
beam_search
import
tensorflow
as
tf
import
tensorflow
as
tf
from
layers
import
base_layers
from
layers
import
base_layers
# import seq_flow_lite module
from
layers
import
dense_layers
from
layers
import
dense_layers
# import seq_flow_lite module
from
layers
import
embedding_layers
from
layers
import
embedding_layers
# import seq_flow_lite module
from
layers
import
normalization_layers
from
layers
import
normalization_layers
# import seq_flow_lite module
from
layers
import
quantization_layers
from
layers
import
quantization_layers
# import seq_flow_lite module
from
layers
import
transformer_layers
from
layers
import
transformer_layers
# import seq_flow_lite module
class
TransformerUniformAttnDecoder
(
base_layers
.
BaseLayer
):
class
TransformerUniformAttnDecoder
(
base_layers
.
BaseLayer
):
...
...
research/seq_flow_lite/tf_ops/BUILD
View file @
20cc2190
...
@@ -11,20 +11,23 @@ package(
...
@@ -11,20 +11,23 @@ package(
)
)
cc_library
(
cc_library
(
name
=
"sequence_string_projection_op"
,
name
=
"projection_normalizer_util"
,
srcs
=
[
srcs
=
[
"projection_normalizer_util.cc"
],
"sequence_string_projection.cc"
,
hdrs
=
[
"projection_normalizer_util.h"
],
deps
=
[
":projection_util"
,
"@utf_archive//:utf"
,
],
],
)
cc_library
(
name
=
"projection_tokenizer_util"
,
srcs
=
[
"projection_tokenizer_util.cc"
],
hdrs
=
[
"projection_tokenizer_util.h"
],
deps
=
[
deps
=
[
":projection_normalizer_util"
,
":projection_tokenizer_util"
,
":projection_util"
,
":projection_util"
,
":text_distorter"
,
"@utf_archive//:utf"
,
"@com_google_absl//absl/container:flat_hash_map"
,
"@tensorflow_includes//:includes"
,
"@tensorflow_solib//:framework_lib"
,
],
],
alwayslink
=
1
,
)
)
cc_library
(
cc_library
(
...
@@ -37,22 +40,46 @@ cc_library(
...
@@ -37,22 +40,46 @@ cc_library(
)
)
cc_library
(
cc_library
(
name
=
"
projection_tokenizer_util
"
,
name
=
"
skipgram_finder
"
,
srcs
=
[
"
projection_tokenizer_util
.cc"
],
srcs
=
[
"
skipgram_finder
.cc"
],
hdrs
=
[
"
projection_tokenizer_util
.h"
],
hdrs
=
[
"
skipgram_finder
.h"
],
deps
=
[
deps
=
[
":projection_util"
,
"@com_google_absl//absl/container:flat_hash_map"
,
"@utf_archive//:utf"
,
"@com_google_absl//absl/container:flat_hash_set"
,
"@com_google_absl//absl/strings"
,
"@icu4c//:icu4c"
,
],
)
cc_test
(
name
=
"skipgram_finder_test"
,
srcs
=
[
"skipgram_finder_test.cc"
],
deps
=
[
":skipgram_finder"
,
"@com_google_absl//absl/strings"
,
"@com_google_googletest//:gtest_main"
,
"@icu4c//:icu4c"
,
],
],
)
)
cc_library
(
cc_library
(
name
=
"
projection_normalizer_util
"
,
name
=
"
subsequence_finder
"
,
srcs
=
[
"
projection_normalizer_util
.cc"
],
srcs
=
[
"
subsequence_finder
.cc"
],
hdrs
=
[
"
projection_normalizer_util
.h"
],
hdrs
=
[
"
subsequence_finder
.h"
],
deps
=
[
deps
=
[
":projection_util"
,
"@com_google_absl//absl/container:flat_hash_map"
,
"@utf_archive//:utf"
,
"@com_google_absl//absl/container:flat_hash_set"
,
"@com_google_absl//absl/strings"
,
"@icu4c//:icu4c"
,
],
)
cc_test
(
name
=
"subsequence_finder_test"
,
srcs
=
[
"subsequence_finder_test.cc"
],
deps
=
[
":subsequence_finder"
,
"@com_google_googletest//:gtest_main"
,
],
],
)
)
...
@@ -67,6 +94,55 @@ cc_library(
...
@@ -67,6 +94,55 @@ cc_library(
],
],
)
)
cc_library
(
name
=
"denylist_op"
,
srcs
=
[
"denylist_op.cc"
],
deps
=
[
":skipgram_finder"
,
":subsequence_finder"
,
"@com_google_absl//absl/cleanup"
,
"@com_google_absl//absl/container:flat_hash_set"
,
"@com_google_absl//absl/memory"
,
"@tensorflow_includes//:includes"
,
"@tensorflow_solib//:framework_lib"
,
],
alwayslink
=
1
,
)
gen_op_wrapper_py
(
name
=
"denylist_op_py"
,
out
=
"denylist_op.py"
,
kernel_lib
=
":denylist_op"
,
)
py_test
(
name
=
"denylist_op_py_test"
,
srcs
=
[
"denylist_op_test.py"
],
main
=
"denylist_op_test.py"
,
python_version
=
"PY3"
,
srcs_version
=
"PY3"
,
deps
=
[
":denylist_op_py"
,
],
)
cc_library
(
name
=
"sequence_string_projection_op"
,
srcs
=
[
"sequence_string_projection.cc"
,
],
deps
=
[
":projection_normalizer_util"
,
":projection_tokenizer_util"
,
":projection_util"
,
":text_distorter"
,
"@com_google_absl//absl/container:flat_hash_map"
,
"@tensorflow_includes//:includes"
,
"@tensorflow_solib//:framework_lib"
,
],
alwayslink
=
1
,
)
cc_test
(
cc_test
(
name
=
"sequence_string_projection_test"
,
name
=
"sequence_string_projection_test"
,
size
=
"small"
,
size
=
"small"
,
...
@@ -78,6 +154,12 @@ cc_test(
...
@@ -78,6 +154,12 @@ cc_test(
],
],
)
)
gen_op_wrapper_py
(
name
=
"sequence_string_projection_op_py"
,
out
=
"sequence_string_projection_op.py"
,
kernel_lib
=
":sequence_string_projection_op"
,
)
cc_library
(
cc_library
(
name
=
"sequence_string_projection_op_v2"
,
name
=
"sequence_string_projection_op_v2"
,
srcs
=
[
srcs
=
[
...
@@ -111,12 +193,6 @@ gen_op_wrapper_py(
...
@@ -111,12 +193,6 @@ gen_op_wrapper_py(
kernel_lib
=
":sequence_string_projection_op_v2"
,
kernel_lib
=
":sequence_string_projection_op_v2"
,
)
)
gen_op_wrapper_py
(
name
=
"sequence_string_projection_op_py"
,
out
=
"sequence_string_projection_op.py"
,
kernel_lib
=
":sequence_string_projection_op"
,
)
cc_library
(
cc_library
(
name
=
"tf_custom_ops"
,
name
=
"tf_custom_ops"
,
srcs
=
[
"tf_custom_ops.cc"
],
srcs
=
[
"tf_custom_ops.cc"
],
...
...
research/seq_flow_lite/tf_ops/denylist_op.cc
0 → 100644
View file @
20cc2190
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>
#include "absl/cleanup/cleanup.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tf_ops/skipgram_finder.h" // seq_flow_lite
#include "tf_ops/subsequence_finder.h" // seq_flow_lite
namespace
seq_flow_lite
{
using
::
tensorflow
::
OpKernel
;
using
::
tensorflow
::
OpKernelConstruction
;
using
::
tensorflow
::
OpKernelContext
;
using
::
tensorflow
::
Status
;
using
::
tensorflow
::
Tensor
;
using
::
tensorflow
::
TensorShape
;
using
::
tensorflow
::
errors
::
InvalidArgument
;
using
::
tensorflow
::
shape_inference
::
InferenceContext
;
using
::
tensorflow
::
shape_inference
::
ShapeHandle
;
// Description of the outputs and attributes for the Denylist ops.
const
char
kDescription
[]
=
R"(
output: A floating point tensor that contains a prediction vector for each
input string. The vector will either be:
* [1, 1, ..., 0, 0, ...] if no denylisted skipgrams are found.
(All negative categories are 1.0 and all positive categories are 0.0.)
* an indicator vector if any denylisted skipgrams are found.
(0.0 if no skipgrams belonging to the category were found and 1.0 otherwise)
max_skip_size: The maximum number of tokens that can be skipped when generating
skipgrams.
denylist: A string vector containing denylisted skipgrams.
denylist_category: An int32 vector containing the category of the corresponding
skipgram in the denylist.
categories: An int32 scalar. This is the total number of categories.
All categories in denylist_category must be in [0, categories).
negative_categories: An int32 scalar. The total number of categories that
should be set if no entries in the denylist are triggered. These
negative categories are assumed to be [0, negative_categories).
)"
;
// The base class for all Denylist ops. It does two things:
// 1) It defines the output tensor of the op and it defines the attributes
// needed to specify the denylist and convert denylist categories into
// output vectors.
// 2) It defines a Compute() function. The compute function is responsible
// for filling in the output tensor, while the subclass is responsible
// for processing the input.
class
DenylistOpBase
:
public
OpKernel
{
public:
explicit
DenylistOpBase
(
OpKernelConstruction
*
context
)
:
OpKernel
(
context
)
{
OP_REQUIRES_OK
(
context
,
context
->
GetAttr
(
"categories"
,
&
categories_
));
OP_REQUIRES_OK
(
context
,
context
->
GetAttr
(
"negative_categories"
,
&
negative_categories_
));
OP_REQUIRES
(
context
,
categories_
>
0
,
InvalidArgument
(
"Number of categories ("
,
categories_
,
") must be positive."
));
OP_REQUIRES
(
context
,
negative_categories_
>=
0
,
InvalidArgument
(
"Number of negative_categories ("
,
negative_categories_
,
") must be non-negative."
));
OP_REQUIRES
(
context
,
negative_categories_
<
categories_
,
InvalidArgument
(
"Number of categories ("
,
categories_
,
") must be greater than the "
"number of negative_categories ("
,
negative_categories_
,
")."
));
OP_REQUIRES_OK
(
context
,
context
->
GetAttr
(
"max_skip_size"
,
&
max_skip_size_
));
OP_REQUIRES_OK
(
context
,
context
->
GetAttr
(
"denylist"
,
&
denylist_
));
OP_REQUIRES_OK
(
context
,
context
->
GetAttr
(
"denylist_category"
,
&
denylist_category_
));
OP_REQUIRES
(
context
,
denylist_
.
size
()
==
denylist_category_
.
size
(),
InvalidArgument
(
"denylist length ("
,
denylist_
.
size
(),
") != denylist_category length ("
,
denylist_category_
.
size
(),
")"
));
int
max
=
*
std
::
max_element
(
denylist_category_
.
begin
(),
denylist_category_
.
end
());
OP_REQUIRES
(
context
,
max
<
categories_
,
InvalidArgument
(
"max element of denylist_category ("
,
max
,
") >= categories ("
,
categories_
,
")"
));
int
min
=
*
std
::
min_element
(
denylist_category_
.
begin
(),
denylist_category_
.
end
());
OP_REQUIRES
(
context
,
min
>=
0
,
InvalidArgument
(
"min element of denylist_category ("
,
min
,
") < 0"
));
}
void
Compute
(
OpKernelContext
*
context
)
override
{
auto
compute_context
=
InitializeComputeContext
(
context
);
if
(
compute_context
==
nullptr
)
{
return
;
}
auto
context_cleaner
=
absl
::
MakeCleanup
([
this
,
compute_context
]
{
this
->
FinalizeComputeContext
(
compute_context
);
});
Tensor
*
output_tensor
;
TensorShape
output_shape
=
InputStringsShape
(
compute_context
);
output_shape
.
AddDim
(
categories_
);
OP_REQUIRES_OK
(
context
,
context
->
allocate_output
(
"output"
,
output_shape
,
&
output_tensor
));
auto
output_values
=
output_tensor
->
flat
<
float
>
();
for
(
int
i
=
0
;
i
<
NumInputStrings
(
compute_context
);
i
++
)
{
auto
category
=
GetCategories
(
i
,
compute_context
);
int
base_index
=
i
*
categories_
;
if
(
category
.
empty
())
{
for
(
int
j
=
0
;
j
<
categories_
;
j
++
)
{
output_values
(
base_index
+
j
)
=
j
<
negative_categories_
?
1.0
:
0.0
;
}
}
else
{
for
(
int
j
=
0
;
j
<
categories_
;
j
++
)
{
output_values
(
base_index
+
j
)
=
category
.
contains
(
j
)
?
1.0
:
0.0
;
}
}
}
}
protected:
int
max_skip_size
()
{
return
max_skip_size_
;
}
int
denylist_size
()
{
return
denylist_
.
size
();
}
const
std
::
string
&
denylist
(
int
i
)
{
return
denylist_
[
i
];
}
int32_t
denylist_category
(
int
i
)
{
return
denylist_category_
[
i
];
}
private:
// Called at the beginning of Compute(). This function should process
// the input and return a context object that can be used to identify
// the denylist categories of each input string.
virtual
void
*
InitializeComputeContext
(
OpKernelContext
*
context
)
=
0
;
// Called at the end of Compute(). Frees the context object.
virtual
void
FinalizeComputeContext
(
void
*
context
)
=
0
;
// Returns the shape of the input tensor, if it only consisted of strings.
// If the input tensor is strings, this is the shape of the input tensor.
// If the input tensor is tokens, this is the shape of the input tensor,
// minus the innermost dimension.
virtual
TensorShape
InputStringsShape
(
void
*
context
)
=
0
;
// Returns the number of strings in the input tensor.
virtual
int
NumInputStrings
(
void
*
context
)
=
0
;
// Returns the denylist categories of the index-th string.
virtual
absl
::
flat_hash_set
<
int
>
GetCategories
(
int
index
,
void
*
context
)
=
0
;
int32_t
categories_
;
int32_t
negative_categories_
;
int
max_skip_size_
;
std
::
vector
<
std
::
string
>
denylist_
;
std
::
vector
<
int32_t
>
denylist_category_
;
};
// A base class for Denylist ops that expect a string tensor input.
class
StringDenylistOp
:
public
DenylistOpBase
{
public:
explicit
StringDenylistOp
(
OpKernelConstruction
*
context
)
:
DenylistOpBase
(
context
)
{}
private:
void
*
InitializeComputeContext
(
OpKernelContext
*
context
)
override
{
const
Tensor
*
input_tensor
;
auto
status
=
context
->
input
(
"input"
,
&
input_tensor
);
if
(
!
status
.
ok
())
{
context
->
CtxFailureWithWarning
(
__FILE__
,
__LINE__
,
status
);
return
nullptr
;
}
return
new
ComputeContext
(
input_tensor
);
}
void
FinalizeComputeContext
(
void
*
context
)
override
{
delete
static_cast
<
ComputeContext
*>
(
context
);
}
TensorShape
InputStringsShape
(
void
*
context
)
override
{
return
static_cast
<
ComputeContext
*>
(
context
)
->
input_tensor
->
shape
();
}
int
NumInputStrings
(
void
*
context
)
override
{
return
static_cast
<
ComputeContext
*>
(
context
)
->
input_tensor_values
.
size
();
}
absl
::
flat_hash_set
<
int
>
GetCategories
(
int
index
,
void
*
context
)
override
{
return
FindTerms
(
static_cast
<
ComputeContext
*>
(
context
)
->
input_tensor_values
(
index
));
}
struct
ComputeContext
{
ComputeContext
(
const
Tensor
*
input_tensor
)
:
input_tensor
(
input_tensor
),
input_tensor_values
(
input_tensor
->
flat
<::
tensorflow
::
tstring
>
())
{}
const
Tensor
*
input_tensor
;
::
tensorflow
::
TTypes
<::
tensorflow
::
tstring
>::
ConstFlat
input_tensor_values
;
};
// Returns the set of denylist categories for the input string.
virtual
absl
::
flat_hash_set
<
int
>
FindTerms
(
const
std
::
string
&
input
)
=
0
;
};
// A denylist op that uses the SkipgramFinder on string inputs.
class
SkipgramDenylistOp
:
public
StringDenylistOp
{
public:
explicit
SkipgramDenylistOp
(
OpKernelConstruction
*
context
)
:
StringDenylistOp
(
context
)
{
skipgram_finder_
=
std
::
make_unique
<
SkipgramFinder
>
(
max_skip_size
());
for
(
int
i
=
0
;
i
<
denylist_size
();
i
++
)
{
skipgram_finder_
->
AddSkipgram
(
denylist
(
i
),
denylist_category
(
i
));
}
}
private:
absl
::
flat_hash_set
<
int
>
FindTerms
(
const
std
::
string
&
input
)
override
{
return
skipgram_finder_
->
FindSkipgrams
(
input
);
}
std
::
unique_ptr
<
SkipgramFinder
>
skipgram_finder_
;
};
REGISTER_KERNEL_BUILDER
(
Name
(
"SkipgramDenylist"
).
Device
(
::
tensorflow
::
DEVICE_CPU
),
SkipgramDenylistOp
);
// Shape inference function for Denylist ops with string inputs.
Status
StringDenylistShapeFn
(
InferenceContext
*
context
)
{
int32_t
categories
;
TF_RETURN_IF_ERROR
(
context
->
GetAttr
(
"categories"
,
&
categories
));
ShapeHandle
output_shape
;
TF_RETURN_IF_ERROR
(
context
->
Concatenate
(
context
->
input
(
0
),
context
->
MakeShape
({
categories
}),
&
output_shape
));
context
->
set_output
(
0
,
output_shape
);
return
::
tensorflow
::
Status
::
OK
();
}
REGISTER_OP
(
"SkipgramDenylist"
)
.
Input
(
"input: string"
)
.
Output
(
"output: float"
)
.
Attr
(
"max_skip_size: int"
)
.
Attr
(
"denylist: list(string)"
)
.
Attr
(
"denylist_category: list(int)"
)
.
Attr
(
"categories: int"
)
.
Attr
(
"negative_categories: int"
)
.
SetShapeFn
(
StringDenylistShapeFn
)
.
Doc
(
absl
::
StrCat
(
"Generates dense prediction vectors for input strings "
"using a skipgram denylist."
,
"
\n\n
"
,
"input: A string tensor."
,
"
\n\n
"
,
kDescription
));
// A Denylist op that uses the SubsequenceFinder on string inputs.
class
SubsequenceDenylistOp
:
public
StringDenylistOp
{
public:
explicit
SubsequenceDenylistOp
(
OpKernelConstruction
*
context
)
:
StringDenylistOp
(
context
)
{
subsequence_finder_
=
std
::
make_unique
<
SubsequenceFinder
>
(
max_skip_size
());
for
(
int
i
=
0
;
i
<
denylist_size
();
i
++
)
{
subsequence_finder_
->
AddSubsequence
(
denylist
(
i
),
denylist_category
(
i
));
}
}
private:
absl
::
flat_hash_set
<
int
>
FindTerms
(
const
std
::
string
&
input
)
override
{
return
subsequence_finder_
->
FindSubsequences
(
input
);
}
std
::
unique_ptr
<
SubsequenceFinder
>
subsequence_finder_
;
};
REGISTER_KERNEL_BUILDER
(
Name
(
"SubsequenceDenylist"
).
Device
(
::
tensorflow
::
DEVICE_CPU
),
SubsequenceDenylistOp
);
REGISTER_OP
(
"SubsequenceDenylist"
)
.
Input
(
"input: string"
)
.
Output
(
"output: float"
)
.
Attr
(
"max_skip_size: int"
)
.
Attr
(
"denylist: list(string)"
)
.
Attr
(
"denylist_category: list(int)"
)
.
Attr
(
"categories: int"
)
.
Attr
(
"negative_categories: int"
)
.
SetShapeFn
(
StringDenylistShapeFn
)
.
Doc
(
absl
::
StrCat
(
"Generates dense prediction vectors for inputs using a "
"subsequence denylist."
,
"
\n\n
"
,
"input: A string tensor."
,
"
\n\n
"
,
kDescription
));
// A denylist op that uses the SkipgramFinder on tokenized string inputs.
// The inputs are a pair of tensors: a token tensor of type string and
// a token count tensor of type T.
template
<
typename
T
>
class
TokenizedDenylistOp
:
public
DenylistOpBase
{
public:
explicit
TokenizedDenylistOp
(
OpKernelConstruction
*
context
)
:
DenylistOpBase
(
context
)
{
skipgram_finder_
=
std
::
make_unique
<
SkipgramFinder
>
(
max_skip_size
());
for
(
int
i
=
0
;
i
<
denylist_size
();
i
++
)
{
skipgram_finder_
->
AddSkipgram
(
denylist
(
i
),
denylist_category
(
i
));
}
}
private:
void
*
InitializeComputeContext
(
OpKernelContext
*
context
)
override
{
const
Tensor
*
input_tensor
;
{
auto
status
=
context
->
input
(
"input"
,
&
input_tensor
);
if
(
!
status
.
ok
())
{
context
->
CtxFailureWithWarning
(
__FILE__
,
__LINE__
,
status
);
return
nullptr
;
}
}
const
Tensor
*
token_count_tensor
;
{
auto
status
=
context
->
input
(
"token_count"
,
&
token_count_tensor
);
if
(
!
status
.
ok
())
{
context
->
CtxFailureWithWarning
(
__FILE__
,
__LINE__
,
status
);
return
nullptr
;
}
}
return
new
ComputeContext
(
input_tensor
,
token_count_tensor
);
}
void
FinalizeComputeContext
(
void
*
context
)
override
{
delete
static_cast
<
ComputeContext
*>
(
context
);
}
TensorShape
InputStringsShape
(
void
*
context
)
override
{
return
static_cast
<
ComputeContext
*>
(
context
)
->
shape
;
}
int
NumInputStrings
(
void
*
context
)
override
{
return
static_cast
<
ComputeContext
*>
(
context
)
->
size
;
}
absl
::
flat_hash_set
<
int
>
GetCategories
(
int
index
,
void
*
x
)
override
{
ComputeContext
*
context
=
static_cast
<
ComputeContext
*>
(
x
);
int64_t
num_tokens
=
context
->
token_count_flat
(
index
);
std
::
vector
<
absl
::
string_view
>
tokens
;
tokens
.
reserve
(
num_tokens
);
int64_t
start
=
index
*
context
->
max_tokens
;
for
(
int64_t
i
=
start
;
i
<
start
+
num_tokens
;
i
++
)
{
tokens
.
emplace_back
(
context
->
token_flat
(
i
).
data
(),
context
->
token_flat
(
i
).
size
());
}
return
skipgram_finder_
->
FindSkipgrams
(
tokens
);
}
struct
ComputeContext
{
ComputeContext
(
const
Tensor
*
token_tensor
,
const
Tensor
*
token_count_tensor
)
:
token_flat
(
token_tensor
->
flat
<::
tensorflow
::
tstring
>
()),
token_count_flat
(
token_count_tensor
->
flat
<
T
>
())
{
shape
=
token_tensor
->
shape
();
max_tokens
=
shape
.
dim_size
(
shape
.
dims
()
-
1
);
shape
.
RemoveLastDims
(
1
);
size
=
1
;
for
(
int64_t
i
=
0
;
i
<
shape
.
dims
();
i
++
)
{
size
=
size
*
shape
.
dim_size
(
i
);
}
}
const
typename
::
tensorflow
::
TTypes
<::
tensorflow
::
tstring
>::
ConstFlat
token_flat
;
const
typename
::
tensorflow
::
TTypes
<
T
>::
ConstFlat
token_count_flat
;
TensorShape
shape
;
int64_t
size
;
int64_t
max_tokens
;
};
std
::
unique_ptr
<
SkipgramFinder
>
skipgram_finder_
;
};
REGISTER_KERNEL_BUILDER
(
Name
(
"TokenizedDenylist"
)
.
Device
(
::
tensorflow
::
DEVICE_CPU
)
.
TypeConstraint
<
int32_t
>
(
"Ttoken_count"
),
TokenizedDenylistOp
<
int32_t
>
);
REGISTER_KERNEL_BUILDER
(
Name
(
"TokenizedDenylist"
)
.
Device
(
::
tensorflow
::
DEVICE_CPU
)
.
TypeConstraint
<
int64_t
>
(
"Ttoken_count"
),
TokenizedDenylistOp
<
int64_t
>
);
// Shape inference function for Denylist ops with tokenized string inputs.
Status
TokenizedDenylistShapeFn
(
InferenceContext
*
context
)
{
int32_t
categories
;
TF_RETURN_IF_ERROR
(
context
->
GetAttr
(
"categories"
,
&
categories
));
ShapeHandle
string_tensor_shape
;
TF_RETURN_IF_ERROR
(
context
->
Subshape
(
context
->
input
(
0
),
0
,
-
1
,
&
string_tensor_shape
));
ShapeHandle
output_shape
;
TF_RETURN_IF_ERROR
(
context
->
Concatenate
(
string_tensor_shape
,
context
->
MakeShape
({
categories
}),
&
output_shape
));
context
->
set_output
(
0
,
output_shape
);
return
::
tensorflow
::
Status
::
OK
();
}
REGISTER_OP
(
"TokenizedDenylist"
)
.
Input
(
"input: string"
)
.
Input
(
"token_count: Ttoken_count"
)
.
Output
(
"output: float"
)
.
Attr
(
"max_skip_size: int"
)
.
Attr
(
"denylist: list(string)"
)
.
Attr
(
"denylist_category: list(int)"
)
.
Attr
(
"categories: int"
)
.
Attr
(
"negative_categories: int"
)
.
Attr
(
"Ttoken_count: {int32, int64}"
)
.
SetShapeFn
(
TokenizedDenylistShapeFn
)
.
Doc
(
absl
::
StrCat
(
"Generates dense prediction vectors for tokens using a "
"skipgram denylist."
,
"
\n\n
"
,
"input: A string tensor of tokens."
,
"
\n\n
"
,
kDescription
));
}
// namespace seq_flow_lite
research/seq_flow_lite/tf_ops/denylist_op_test.cc
0 → 100644
View file @
20cc2190
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.proto.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace
seq_flow_lite
{
namespace
{
using
::
tensorflow
::
DT_FLOAT
;
using
::
tensorflow
::
DT_INT32
;
using
::
tensorflow
::
DT_INT64
;
using
::
tensorflow
::
DT_STRING
;
using
::
tensorflow
::
NodeDefBuilder
;
using
::
tensorflow
::
OpsTestBase
;
using
::
tensorflow
::
Tensor
;
using
::
tensorflow
::
TensorShape
;
using
::
tensorflow
::
errors
::
InvalidArgument
;
using
::
tensorflow
::
test
::
ExpectTensorEqual
;
using
::
tensorflow
::
test
::
FillValues
;
class
SkipgramDenylistOpTest
:
public
OpsTestBase
{};
TEST_F
(
SkipgramDenylistOpTest
,
Correct
)
{
TF_ASSERT_OK
(
NodeDefBuilder
(
"test_op"
,
"SkipgramDenylist"
)
.
Input
({
"input"
,
0
,
DT_STRING
})
.
Attr
(
"max_skip_size"
,
1
)
.
Attr
(
"denylist"
,
{
"a b c"
})
.
Attr
(
"denylist_category"
,
{
1
})
.
Attr
(
"categories"
,
2
)
.
Attr
(
"negative_categories"
,
1
)
.
Finalize
(
node_def
()));
TF_ASSERT_OK
(
InitOp
());
AddInputFromArray
<::
tensorflow
::
tstring
>
(
TensorShape
({
2
}),
{
"q a q b q c q"
,
"q a b q q c"
});
TF_ASSERT_OK
(
RunOpKernel
());
const
Tensor
&
output
=
*
GetOutput
(
0
);
Tensor
expected
(
allocator
(),
DT_FLOAT
,
TensorShape
({
2
,
2
}));
FillValues
<
float
>
(
&
expected
,
{
0.0
,
1.0
,
1.0
,
0.0
});
ExpectTensorEqual
<
float
>
(
expected
,
output
);
}
TEST_F
(
SkipgramDenylistOpTest
,
Prefix
)
{
TF_ASSERT_OK
(
NodeDefBuilder
(
"test_op"
,
"SkipgramDenylist"
)
.
Input
({
"input"
,
0
,
DT_STRING
})
.
Attr
(
"max_skip_size"
,
1
)
.
Attr
(
"denylist"
,
{
"a b.* c"
})
.
Attr
(
"denylist_category"
,
{
1
})
.
Attr
(
"categories"
,
2
)
.
Attr
(
"negative_categories"
,
1
)
.
Finalize
(
node_def
()));
TF_ASSERT_OK
(
InitOp
());
AddInputFromArray
<::
tensorflow
::
tstring
>
(
TensorShape
({
2
}),
{
"q a q bq q c q"
,
"q a bq q q c"
});
TF_ASSERT_OK
(
RunOpKernel
());
const
Tensor
&
output
=
*
GetOutput
(
0
);
Tensor
expected
(
allocator
(),
DT_FLOAT
,
TensorShape
({
2
,
2
}));
FillValues
<
float
>
(
&
expected
,
{
0.0
,
1.0
,
1.0
,
0.0
});
ExpectTensorEqual
<
float
>
(
expected
,
output
);
}
TEST_F
(
SkipgramDenylistOpTest
,
ZeroCategories
)
{
TF_ASSERT_OK
(
NodeDefBuilder
(
"test_op"
,
"SkipgramDenylist"
)
.
Input
({
"input"
,
0
,
DT_STRING
})
.
Attr
(
"max_skip_size"
,
1
)
.
Attr
(
"denylist"
,
{
"a b c"
})
.
Attr
(
"denylist_category"
,
{
1
})
.
Attr
(
"categories"
,
0
)
.
Attr
(
"negative_categories"
,
0
)
.
Finalize
(
node_def
()));
EXPECT_EQ
(
InitOp
(),
InvalidArgument
(
"Number of categories (0) must be positive."
));
}
TEST_F
(
SkipgramDenylistOpTest
,
NegativeCategoriesLessThanZero
)
{
TF_ASSERT_OK
(
NodeDefBuilder
(
"test_op"
,
"SkipgramDenylist"
)
.
Input
({
"input"
,
0
,
DT_STRING
})
.
Attr
(
"max_skip_size"
,
1
)
.
Attr
(
"denylist"
,
{
"a b c"
})
.
Attr
(
"denylist_category"
,
{
1
})
.
Attr
(
"categories"
,
1
)
.
Attr
(
"negative_categories"
,
-
1
)
.
Finalize
(
node_def
()));
EXPECT_EQ
(
InitOp
(),
InvalidArgument
(
"Number of negative_categories (-1) must be non-negative."
));
}
TEST_F
(
SkipgramDenylistOpTest
,
CategoriesEqualNegativeCategories
)
{
TF_ASSERT_OK
(
NodeDefBuilder
(
"test_op"
,
"SkipgramDenylist"
)
.
Input
({
"input"
,
0
,
DT_STRING
})
.
Attr
(
"max_skip_size"
,
1
)
.
Attr
(
"denylist"
,
{
"a b c"
})
.
Attr
(
"denylist_category"
,
{
1
})
.
Attr
(
"categories"
,
1
)
.
Attr
(
"negative_categories"
,
1
)
.
Finalize
(
node_def
()));
EXPECT_EQ
(
InitOp
(),
InvalidArgument
(
"Number of categories (1) must be greater than the "
"number of negative_categories (1)."
));
}
class
SubsequenceDenylistOpTest
:
public
OpsTestBase
{};
TEST_F
(
SubsequenceDenylistOpTest
,
Correct
)
{
TF_ASSERT_OK
(
NodeDefBuilder
(
"test_op"
,
"SubsequenceDenylist"
)
.
Input
({
"input"
,
0
,
DT_STRING
})
.
Attr
(
"max_skip_size"
,
1
)
.
Attr
(
"denylist"
,
{
"a b c"
})
.
Attr
(
"denylist_category"
,
{
1
})
.
Attr
(
"categories"
,
2
)
.
Attr
(
"negative_categories"
,
1
)
.
Finalize
(
node_def
()));
TF_ASSERT_OK
(
InitOp
());
AddInputFromArray
<::
tensorflow
::
tstring
>
(
TensorShape
({
2
}),
{
"qaqbqcq"
,
"qabqqc"
});
TF_ASSERT_OK
(
RunOpKernel
());
const
Tensor
&
output
=
*
GetOutput
(
0
);
Tensor
expected
(
allocator
(),
DT_FLOAT
,
TensorShape
({
2
,
2
}));
FillValues
<
float
>
(
&
expected
,
{
0.0
,
1.0
,
1.0
,
0.0
});
ExpectTensorEqual
<
float
>
(
expected
,
output
);
}
TEST_F
(
SubsequenceDenylistOpTest
,
ZeroCategories
)
{
TF_ASSERT_OK
(
NodeDefBuilder
(
"test_op"
,
"SubsequenceDenylist"
)
.
Input
({
"input"
,
0
,
DT_STRING
})
.
Attr
(
"max_skip_size"
,
1
)
.
Attr
(
"denylist"
,
{
"a b c"
})
.
Attr
(
"denylist_category"
,
{
1
})
.
Attr
(
"categories"
,
0
)
.
Attr
(
"negative_categories"
,
0
)
.
Finalize
(
node_def
()));
EXPECT_EQ
(
InitOp
(),
InvalidArgument
(
"Number of categories (0) must be positive."
));
}
TEST_F
(
SubsequenceDenylistOpTest
,
NegativeCategoriesLessThanZero
)
{
TF_ASSERT_OK
(
NodeDefBuilder
(
"test_op"
,
"SubsequenceDenylist"
)
.
Input
({
"input"
,
0
,
DT_STRING
})
.
Attr
(
"max_skip_size"
,
1
)
.
Attr
(
"denylist"
,
{
"a b c"
})
.
Attr
(
"denylist_category"
,
{
1
})
.
Attr
(
"categories"
,
1
)
.
Attr
(
"negative_categories"
,
-
1
)
.
Finalize
(
node_def
()));
EXPECT_EQ
(
InitOp
(),
InvalidArgument
(
"Number of negative_categories (-1) must be non-negative."
));
}
TEST_F
(
SubsequenceDenylistOpTest
,
CategoriesEqualNegativeCategories
)
{
TF_ASSERT_OK
(
NodeDefBuilder
(
"test_op"
,
"SubsequenceDenylist"
)
.
Input
({
"input"
,
0
,
DT_STRING
})
.
Attr
(
"max_skip_size"
,
1
)
.
Attr
(
"denylist"
,
{
"a b c"
})
.
Attr
(
"denylist_category"
,
{
1
})
.
Attr
(
"categories"
,
1
)
.
Attr
(
"negative_categories"
,
1
)
.
Finalize
(
node_def
()));
EXPECT_EQ
(
InitOp
(),
InvalidArgument
(
"Number of categories (1) must be greater than the "
"number of negative_categories (1)."
));
}
class
TokenizedDenylistOpTest
:
public
OpsTestBase
{};
// Runs TokenizedDenylist on pre-tokenized input with an int64 `token_count`
// input and verifies the per-category scores: row 0 contains the denylisted
// skipgram "a b c" (category 1), row 1 does as well.
TEST_F(TokenizedDenylistOpTest, CorrectInt64TokenCount) {
  TF_ASSERT_OK(NodeDefBuilder("test_op", "TokenizedDenylist")
                   .Input({"input", 0, DT_STRING})
                   .Input({"token_count", 0, DT_INT64})
                   .Attr("max_skip_size", 1)
                   .Attr("denylist", {"a b c"})
                   .Attr("denylist_category", {1})
                   .Attr("categories", 2)
                   .Attr("negative_categories", 1)
                   .Finalize(node_def()));
  TF_ASSERT_OK(InitOp());
  // Second row has only 6 valid tokens; the trailing "" is padding.
  AddInputFromArray<::tensorflow::tstring>(
      TensorShape({2, 7}), {"q", "a", "q", "b", "q", "c", "q",  //
                            "q", "a", "b", "q", "q", "c", ""});
  AddInputFromArray<int64_t>(TensorShape({2}), {7, 6});
  TF_ASSERT_OK(RunOpKernel());
  const Tensor& output = *GetOutput(0);
  Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 2}));
  FillValues<float>(&expected, {0.0, 1.0, 1.0, 0.0});
  ExpectTensorEqual<float>(expected, output);
}
// Same scenario as CorrectInt64TokenCount, but with the `token_count`
// input supplied as int32 to cover both accepted integer dtypes.
TEST_F(TokenizedDenylistOpTest, CorrectInt32TokenCount) {
  TF_ASSERT_OK(NodeDefBuilder("test_op", "TokenizedDenylist")
                   .Input({"input", 0, DT_STRING})
                   .Input({"token_count", 0, DT_INT32})
                   .Attr("max_skip_size", 1)
                   .Attr("denylist", {"a b c"})
                   .Attr("denylist_category", {1})
                   .Attr("categories", 2)
                   .Attr("negative_categories", 1)
                   .Finalize(node_def()));
  TF_ASSERT_OK(InitOp());
  // Second row has only 6 valid tokens; the trailing "" is padding.
  AddInputFromArray<::tensorflow::tstring>(
      TensorShape({2, 7}), {"q", "a", "q", "b", "q", "c", "q",  //
                            "q", "a", "b", "q", "q", "c", ""});
  AddInputFromArray<int32_t>(TensorShape({2}), {7, 6});
  TF_ASSERT_OK(RunOpKernel());
  const Tensor& output = *GetOutput(0);
  Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 2}));
  FillValues<float>(&expected, {0.0, 1.0, 1.0, 0.0});
  ExpectTensorEqual<float>(expected, output);
}
// The `categories` attribute must be positive; zero is rejected at
// op-initialization time.
TEST_F(TokenizedDenylistOpTest, ZeroCategories) {
  TF_ASSERT_OK(NodeDefBuilder("test_op", "TokenizedDenylist")
                   .Input({"input", 0, DT_STRING})
                   .Input({"token_count", 0, DT_INT64})
                   .Attr("max_skip_size", 1)
                   .Attr("denylist", {"a b c"})
                   .Attr("denylist_category", {1})
                   .Attr("categories", 0)
                   .Attr("negative_categories", 0)
                   .Finalize(node_def()));
  EXPECT_EQ(InitOp(),
            InvalidArgument("Number of categories (0) must be positive."));
}
// Initializing TokenizedDenylist with a negative `negative_categories`
// attribute must fail with InvalidArgument.
TEST_F(TokenizedDenylistOpTest, NegativeCategoriesLessThanZero) {
  TF_ASSERT_OK(NodeDefBuilder("test_op", "TokenizedDenylist")
                   .Input({"input", 0, DT_STRING})
                   .Input({"token_count", 0, DT_INT64})
                   .Attr("max_skip_size", 1)
                   .Attr("denylist", {"a b c"})
                   .Attr("denylist_category", {1})
                   .Attr("categories", 1)
                   .Attr("negative_categories", -1)
                   .Finalize(node_def()));
  EXPECT_EQ(InitOp(),
            InvalidArgument(
                "Number of negative_categories (-1) must be non-negative."));
}
// `categories` must be strictly greater than `negative_categories`;
// equality is rejected at op-initialization time.
TEST_F(TokenizedDenylistOpTest, CategoriesEqualNegativeCategories) {
  TF_ASSERT_OK(NodeDefBuilder("test_op", "TokenizedDenylist")
                   .Input({"input", 0, DT_STRING})
                   .Input({"token_count", 0, DT_INT64})
                   .Attr("max_skip_size", 1)
                   .Attr("denylist", {"a b c"})
                   .Attr("denylist_category", {1})
                   .Attr("categories", 1)
                   .Attr("negative_categories", 1)
                   .Finalize(node_def()));
  EXPECT_EQ(InitOp(),
            InvalidArgument("Number of categories (1) must be greater than the "
                            "number of negative_categories (1)."));
}
}
// namespace
}
// namespace seq_flow_lite
research/seq_flow_lite/tf_ops/denylist_op_test.py
0 → 100644
View file @
20cc2190
# Copyright 2022 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test denylist op and show example usage from python wrapper."""
import
tensorflow
as
tf
from
tf_ops
import
denylist_op
# import seq_flow_lite module
class SkipgramDenylistTest(tf.test.TestCase):
  """Exercises the skipgram_denylist op through its python wrapper."""

  def test_correct(self):
    """Each string contains the denylisted skipgram, with skips of one."""
    texts = ["q a q b q c q", "q a b q q c"]
    scores = denylist_op.skipgram_denylist(
        input=texts,
        max_skip_size=1,
        denylist=["a b c"],
        denylist_category=[1],
        categories=2,
        negative_categories=1)
    self.assertAllEqual(scores, [[0.0, 1.0], [1.0, 0.0]])
class SubsequenceDenylistTest(tf.test.TestCase):
  """Exercises the subsequence_denylist op through its python wrapper."""

  def test_correct(self):
    """Character-level subsequences are found inside unbroken strings."""
    texts = ["qaqbqcq", "qabqqc"]
    scores = denylist_op.subsequence_denylist(
        input=texts,
        max_skip_size=1,
        denylist=["a b c"],
        denylist_category=[1],
        categories=2,
        negative_categories=1)
    self.assertAllEqual(scores, [[0.0, 1.0], [1.0, 0.0]])
class TokenizedDenylistTest(tf.test.TestCase):
  """Exercises the tokenized_denylist op through its python wrapper."""

  def test_correct(self):
    """Pre-tokenized rows with explicit token counts are scored correctly."""
    tokens = [
        ["q", "a", "q", "b", "q", "c", "q"],
        ["q", "a", "b", "q", "q", "c", ""],  # last token is padding
    ]
    scores = denylist_op.tokenized_denylist(
        input=tokens,
        token_count=[7, 6],
        max_skip_size=1,
        denylist=["a b c"],
        denylist_category=[1],
        categories=2,
        negative_categories=1)
    self.assertAllEqual(scores, [[0.0, 1.0], [1.0, 0.0]])
# Run all test cases under TensorFlow's test runner.
if __name__ == "__main__":
  tf.test.main()
research/seq_flow_lite/tf_ops/projection_normalizer_util.cc
View file @
20cc2190
...
@@ -26,7 +26,7 @@ limitations under the License.
...
@@ -26,7 +26,7 @@ limitations under the License.
bool
IsDigit
(
const
std
::
string
&
text
)
{
bool
IsDigit
(
const
std
::
string
&
text
)
{
Rune
rune
;
Rune
rune
;
for
(
size_t
i
=
0
;
i
<
text
.
length
();)
{
for
(
size_t
i
=
0
;
i
<
text
.
length
();)
{
const
int
bytes_read
=
chartorune
(
&
rune
,
const_cast
<
char
*>
(
text
.
data
()));
const
int
bytes_read
=
chartorune
(
&
rune
,
const_cast
<
char
*>
(
text
.
data
()));
if
(
rune
==
Runeerror
||
bytes_read
==
0
)
break
;
if
(
rune
==
Runeerror
||
bytes_read
==
0
)
break
;
if
(
rune
>=
static_cast
<
Rune
>
(
'0'
)
&&
rune
<=
static_cast
<
Rune
>
(
'9'
))
{
if
(
rune
>=
static_cast
<
Rune
>
(
'0'
)
&&
rune
<=
static_cast
<
Rune
>
(
'9'
))
{
return
true
;
return
true
;
...
@@ -98,6 +98,29 @@ std::string ContractToken(const char* input_ptr, size_t len, size_t num_chars) {
...
@@ -98,6 +98,29 @@ std::string ContractToken(const char* input_ptr, size_t len, size_t num_chars) {
return
token
;
return
token
;
}
}
// Collapses runs of spaces in `input` to a single space and strips any
// leading and trailing spaces, rewriting the string in place.
void NormalizeSpaces(std::string& input) {
  size_t out = 0;              // write cursor into the same buffer
  bool pending_space = false;  // a space is owed before the next word char
  for (size_t in = 0; in < input.length(); ++in) {
    const char c = input[in];
    if (c == ' ') {
      // Never emit a leading space; interior runs collapse to one pending.
      if (out > 0) pending_space = true;
    } else {
      if (pending_space) {
        input[out++] = ' ';
        pending_space = false;
      }
      input[out++] = c;
    }
  }
  // Trailing spaces stay pending and are simply never emitted.
  input.resize(out);
}
void
ProjectionNormalizer
::
InitializeSeparators
(
const
std
::
string
&
separators
)
{
void
ProjectionNormalizer
::
InitializeSeparators
(
const
std
::
string
&
separators
)
{
for
(
size_t
i
=
0
;
i
<
separators
.
length
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
separators
.
length
();
++
i
)
{
if
(
separators
[
i
]
!=
' '
)
{
if
(
separators
[
i
]
!=
' '
)
{
...
@@ -150,9 +173,14 @@ std::string ProjectionNormalizer::Normalize(const char* input_ptr, size_t len,
...
@@ -150,9 +173,14 @@ std::string ProjectionNormalizer::Normalize(const char* input_ptr, size_t len,
normalized
=
ContractToken
(
normalized
.
data
(),
normalized
.
length
(),
3
);
normalized
=
ContractToken
(
normalized
.
data
(),
normalized
.
length
(),
3
);
}
}
if
(
normalize_spaces_
)
{
NormalizeSpaces
(
normalized
);
}
if
(
!
separators_
.
empty
())
{
if
(
!
separators_
.
empty
())
{
// Add space around separators_.
// Add space around separators_.
normalized
=
NormalizeInternal
(
normalized
.
data
(),
normalized
.
length
());
normalized
=
NormalizeInternal
(
normalized
.
data
(),
normalized
.
length
());
}
}
return
normalized
;
return
normalized
;
}
}
research/seq_flow_lite/tf_ops/projection_normalizer_util.h
View file @
20cc2190
...
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
==============================================================================*/
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TF_OPS_PROJECTION_NORMALIZER_UTIL_H_
#ifndef TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TF_OPS_PROJECTION_NORMALIZER_UTIL_H_
#define TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TF_OPS_PROJECTION_NORMALIZER_UTIL_H_
#define TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TF_OPS_PROJECTION_NORMALIZER_UTIL_H_
#include <string>
#include <string>
#include <unordered_set>
#include <unordered_set>
...
@@ -24,14 +24,17 @@ limitations under the License.
...
@@ -24,14 +24,17 @@ limitations under the License.
// Normalizes the input with the given |separators| by adding a space before and
// Normalizes the input with the given |separators| by adding a space before and
// after each separator. When |normalize_repetition| is true, it removes the
// after each separator. When |normalize_repetition| is true, it removes the
// repeated characters (except numbers) which consecutively appeared more than
// repeated characters (except numbers) which consecutively appeared more than
// twice in a word.
// twice in a word. When |normalize_spaces| is true, it removes spaces from
// the beginning and ending of the input, as well as repeated spaces.
// Examples: arwwwww -> arww, good!!!!! -> good!!, hahaha => haha.
// Examples: arwwwww -> arww, good!!!!! -> good!!, hahaha => haha.
class
ProjectionNormalizer
{
class
ProjectionNormalizer
{
public:
public:
explicit
ProjectionNormalizer
(
const
std
::
string
&
separators
,
explicit
ProjectionNormalizer
(
const
std
::
string
&
separators
,
bool
normalize_repetition
=
false
)
{
bool
normalize_repetition
=
false
,
bool
normalize_spaces
=
false
)
:
normalize_repetition_
(
normalize_repetition
),
normalize_spaces_
(
normalize_spaces
)
{
InitializeSeparators
(
separators
);
InitializeSeparators
(
separators
);
normalize_repetition_
=
normalize_repetition
;
}
}
// Normalizes the repeated characters (except numbers) which consecutively
// Normalizes the repeated characters (except numbers) which consecutively
...
@@ -49,6 +52,7 @@ class ProjectionNormalizer {
...
@@ -49,6 +52,7 @@ class ProjectionNormalizer {
std
::
unordered_set
<
char
>
separators_
;
std
::
unordered_set
<
char
>
separators_
;
bool
normalize_repetition_
;
bool
normalize_repetition_
;
bool
normalize_spaces_
;
};
};
#endif // TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TF_OPS_PROJECTION_NORMALIZER_UTIL_H_
#endif // TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TF_OPS_PROJECTION_NORMALIZER_UTIL_H_
research/seq_flow_lite/tf_ops/projection_tokenizer_util.h
View file @
20cc2190
...
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
==============================================================================*/
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TF_OPS_PROJECTION_TOKENIZER_UTIL_H_
#ifndef TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TF_OPS_PROJECTION_TOKENIZER_UTIL_H_
#define TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TF_OPS_PROJECTION_TOKENIZER_UTIL_H_
#define TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TF_OPS_PROJECTION_TOKENIZER_UTIL_H_
#include <string>
#include <string>
#include <unordered_set>
#include <unordered_set>
...
@@ -55,4 +55,4 @@ class ProjectionTokenizer {
...
@@ -55,4 +55,4 @@ class ProjectionTokenizer {
std
::
unordered_set
<
char
>
separators_
;
std
::
unordered_set
<
char
>
separators_
;
};
};
#endif // TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TF_OPS_PROJECTION_TOKENIZER_UTIL_H_
#endif // TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TF_OPS_PROJECTION_TOKENIZER_UTIL_H_
research/seq_flow_lite/tf_ops/projection_util.h
View file @
20cc2190
...
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
==============================================================================*/
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TF_OPS_PROJECTION_UTIL_H_
#ifndef TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TF_OPS_PROJECTION_UTIL_H_
#define TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TF_OPS_PROJECTION_UTIL_H_
#define TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TF_OPS_PROJECTION_UTIL_H_
#include <memory>
#include <memory>
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_map>
...
@@ -156,4 +156,4 @@ std::vector<std::string> SplitByChar(const char* input_ptr, size_t len,
...
@@ -156,4 +156,4 @@ std::vector<std::string> SplitByChar(const char* input_ptr, size_t len,
std
::
string
JoinPairsBySpace
(
std
::
vector
<
std
::
pair
<
const
char
*
,
size_t
>>
words
);
std
::
string
JoinPairsBySpace
(
std
::
vector
<
std
::
pair
<
const
char
*
,
size_t
>>
words
);
#endif // TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TF_OPS_PROJECTION_UTIL_H_
#endif // TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TF_OPS_PROJECTION_UTIL_H_
research/seq_flow_lite/tf_ops/sequence_string_projection.cc
View file @
20cc2190
...
@@ -109,11 +109,14 @@ class SequenceStringProjectionOp : public OpKernel {
...
@@ -109,11 +109,14 @@ class SequenceStringProjectionOp : public OpKernel {
bool
normalize_repetition
;
bool
normalize_repetition
;
OP_REQUIRES_OK
(
context
,
context
->
GetAttr
(
"normalize_repetition"
,
OP_REQUIRES_OK
(
context
,
context
->
GetAttr
(
"normalize_repetition"
,
&
normalize_repetition
));
&
normalize_repetition
));
bool
normalize_spaces
;
OP_REQUIRES_OK
(
context
,
context
->
GetAttr
(
"normalize_spaces"
,
&
normalize_spaces
));
std
::
string
separators
;
std
::
string
separators
;
OP_REQUIRES_OK
(
context
,
context
->
GetAttr
(
"token_separators"
,
&
separators
));
OP_REQUIRES_OK
(
context
,
context
->
GetAttr
(
"token_separators"
,
&
separators
));
if
(
!
separators
.
empty
()
||
normalize_repetition
)
{
if
(
!
separators
.
empty
()
||
normalize_repetition
||
normalize_spaces
)
{
projection_normalizer_
=
absl
::
make_unique
<
ProjectionNormalizer
>
(
projection_normalizer_
=
absl
::
make_unique
<
ProjectionNormalizer
>
(
separators
,
normalize_repetition
);
separators
,
normalize_repetition
,
normalize_spaces
);
}
}
OP_REQUIRES_OK
(
context
,
context
->
GetAttr
(
"add_first_cap_feature"
,
OP_REQUIRES_OK
(
context
,
context
->
GetAttr
(
"add_first_cap_feature"
,
...
@@ -326,6 +329,7 @@ REGISTER_OP("SequenceStringProjection")
...
@@ -326,6 +329,7 @@ REGISTER_OP("SequenceStringProjection")
.
Attr
(
"split_on_space: bool = True"
)
.
Attr
(
"split_on_space: bool = True"
)
.
Attr
(
"token_separators: string = ''"
)
.
Attr
(
"token_separators: string = ''"
)
.
Attr
(
"normalize_repetition: bool = false"
)
.
Attr
(
"normalize_repetition: bool = false"
)
.
Attr
(
"normalize_spaces: bool = false"
)
.
SetShapeFn
([](
InferenceContext
*
c
)
{
.
SetShapeFn
([](
InferenceContext
*
c
)
{
DimensionHandle
size
;
DimensionHandle
size
;
...
@@ -384,6 +388,10 @@ Attribute(s):
...
@@ -384,6 +388,10 @@ Attribute(s):
- add_all_caps_feature: Specifies the probability with which a feature to the
- add_all_caps_feature: Specifies the probability with which a feature to the
resulting projection tensor that helps discriminate if the input token is
resulting projection tensor that helps discriminate if the input token is
ALLCAPS will be added.
ALLCAPS will be added.
- normalize_repetition: When true normalizes repetition in text tokens before
fingerprinting.
- normalize_spaces: When true strips leading and trailing spaces and removes
repeated spaces.
Output(s):
Output(s):
- projection: Floating point tensor with ternary values of shape
- projection: Floating point tensor with ternary values of shape
...
...
research/seq_flow_lite/tf_ops/skipgram_finder.cc
0 → 100644
View file @
20cc2190
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tf_ops/skipgram_finder.h"  // seq_flow_lite

#include <cctype>
#include <cstdint>
#include <cstring>
#include <deque>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "icu4c/source/common/unicode/uchar.h"
#include "icu4c/source/common/unicode/utf8.h"
namespace
seq_flow_lite
{
namespace
{
// Lowercases `token` and strips punctuation, editing the buffer in place.
// Decodes UTF-8 one code point at a time; the read offset `in` always stays
// at or ahead of the write offset `out`, which is safe because the rewritten
// code point is never encoded with more bytes than the original.
void PreprocessToken(std::string& token) {
  char* s = const_cast<char*>(token.data());
  int32_t size = token.size();
  int32_t in = 0;   // read offset, in bytes
  int32_t out = 0;  // write offset, in bytes; out <= in
  while (in < size) {
    UChar32 c;
    int32_t old_in = in;
    U8_NEXT(s, in, size, c);
    if (c < 0) {
      // Invalid UTF-8: stop converting; the remainder is copied verbatim.
      break;
    }
    if (u_ispunct(c)) continue;  // drop punctuation entirely
    UChar32 cl = u_tolower(c);
    // This is a hack, but there are exactly two unicode characters whose
    // lowercase versions have longer UTF-8 encodings (0x23a to 0x2c65,
    // 0x23e to 0x2c66). So, to avoid sizing issues, they're not lowercased.
    if (U8_LENGTH(cl) > (in - old_in)) {
      cl = c;
    }
    U8_APPEND_UNSAFE(s, out, cl);
  }
  // Move any bytes left unread (only on invalid UTF-8) down to the output.
  size_t remaining = token.size() - in;
  if (remaining > 0) {
    memmove(s + out, s + in, remaining);
    out += remaining;
  }
  token.resize(out);
}
}
// namespace
// Registers `skipgram` (space-separated tokens, each optionally suffixed
// with ".*" for prefix matching) under `category`.
void SkipgramFinder::AddSkipgram(absl::string_view skipgram, int category) {
  std::vector<std::string> tokens = absl::StrSplit(skipgram, ' ');

  // Store the skipgram in a trie-like structure that uses tokens as the
  // edge labels, instead of characters. Each node represents a skipgram made
  // from the tokens used to reach the node, and stores the categories the
  // skipgram is associated with.
  TrieNode* cur = &skipgram_trie_;
  for (auto& token : tokens) {
    if (absl::EndsWith(token, ".*")) {
      // "<prefix>.*" tokens go into a separate edge map so that lookup can
      // match them against token prefixes rather than whole tokens.
      token.resize(token.size() - 2);
      PreprocessToken(token);
      auto iter = cur->prefix_to_node.find(token);
      if (iter != cur->prefix_to_node.end()) {
        cur = &iter->second;
      } else {
        // emplace() so the TrieNode is constructed in place exactly once.
        cur = &cur->prefix_to_node
                   .emplace(std::piecewise_construct,
                            std::forward_as_tuple(token), std::make_tuple<>())
                   .first->second;
      }
      continue;
    }

    PreprocessToken(token);
    auto iter = cur->token_to_node.find(token);
    if (iter != cur->token_to_node.end()) {
      cur = &iter->second;
    } else {
      cur = &cur->token_to_node
                 .emplace(std::piecewise_construct,
                          std::forward_as_tuple(token), std::make_tuple<>())
                 .first->second;
    }
  }
  // The final node represents the full skipgram.
  cur->categories.insert(category);
}
// Tokenizes `input` on spaces, normalizes each token (lowercase, strip
// punctuation), and delegates to the token-vector overload.
absl::flat_hash_set<int> SkipgramFinder::FindSkipgrams(
    absl::string_view input) const {
  std::vector<std::string> words = absl::StrSplit(input, ' ');

  std::vector<absl::string_view> views;
  views.reserve(words.size());
  for (auto& word : words) {
    PreprocessToken(word);
    views.emplace_back(word.data(), word.size());
  }
  return FindSkipgrams(views);
}
// Scans `tokens` left to right, maintaining the set of partially-matched
// skipgrams, and returns the categories of every fully-matched skipgram.
// Tokens are matched verbatim (callers must pre-normalize them).
absl::flat_hash_set<int> SkipgramFinder::FindSkipgrams(
    const std::vector<absl::string_view>& tokens) const {
  absl::flat_hash_set<int> categories;

  // Tracks skipgram prefixes and the index of their last token.
  std::deque<std::pair<int, const TrieNode*>> indices_and_skipgrams;

  for (int token_i = 0; token_i < tokens.size(); token_i++) {
    const absl::string_view& token = tokens[token_i];

    // All UTF-8 code point boundary prefixes of `token`, shortest first;
    // used to match "<prefix>.*" skipgram tokens.
    std::vector<absl::string_view> token_prefixes;
    {
      const char* s = token.data();
      int32_t l = token.size();
      int32_t n = 0;
      while (n < l) {
        int32_t n_old = n;
        U8_FWD_1(s, n, l);
        if (n == n_old) break;  // malformed UTF-8: stop extending prefixes
        token_prefixes.emplace_back(s, n);
      }
    }

    // Drop any skipgrams prefixes which would skip more than `max_skip_size_`
    // tokens between the end of the prefix and the current token.
    while (!indices_and_skipgrams.empty()) {
      if (indices_and_skipgrams.front().first + max_skip_size_ + 1 < token_i) {
        indices_and_skipgrams.pop_front();
      } else {
        break;  // deque is ordered by index, so the rest are still live
      }
    }

    // Check if we can form a valid skipgram prefix (or skipgram) by adding
    // the current token to any of the existing skipgram prefixes, or
    // if the current token is a valid skipgram prefix (or skipgram).
    // Note: `size` is captured before the loop so entries pushed below are
    // not re-examined for this token.
    size_t size = indices_and_skipgrams.size();
    for (size_t skipgram_i = 0; skipgram_i <= size; skipgram_i++) {
      // skipgram_i == size means "start a new skipgram at the trie root".
      const auto& node = skipgram_i < size
                             ? *indices_and_skipgrams[skipgram_i].second
                             : skipgram_trie_;

      // Exact-token edge.
      auto iter = node.token_to_node.find(token);
      if (iter != node.token_to_node.end()) {
        categories.insert(iter->second.categories.begin(),
                          iter->second.categories.end());
        indices_and_skipgrams.push_back(
            std::make_pair(token_i, &iter->second));
      }

      // Prefix ("<prefix>.*") edges, longest prefix first.
      for (auto token_prefix = token_prefixes.rbegin();
           token_prefix != token_prefixes.rend(); token_prefix++) {
        auto iter = node.prefix_to_node.find(*token_prefix);
        if (iter != node.prefix_to_node.end()) {
          categories.insert(iter->second.categories.begin(),
                            iter->second.categories.end());
          indices_and_skipgrams.push_back(
              std::make_pair(token_i, &iter->second));
        }
      }
    }
  }
  return categories;
}
}
// namespace seq_flow_lite
research/seq_flow_lite/tf_ops/skipgram_finder.h
0 → 100644
View file @
20cc2190
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_SKIPGRAM_FINDER_H_
#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_SKIPGRAM_FINDER_H_
#include <string>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
namespace
seq_flow_lite
{
// SkipgramFinder finds skipgrams in strings.
//
// To use: First, add skipgrams using AddSkipgram() - each skipgram is
// associated with some category. Then, call FindSkipgrams() on a string,
// which will return the set of categories of the skipgrams in the string.
//
// Both the skipgrams and the input strings will be tokenzied by splitting
// on spaces. Additionally, the tokens will be lowercased and have any
// trailing punctuation removed.
class SkipgramFinder {
 public:
  // `max_skip_size` is the maximum number of tokens that may be skipped
  // between two consecutive skipgram tokens in a match.
  explicit SkipgramFinder(int max_skip_size) : max_skip_size_(max_skip_size) {}

  // Adds a skipgram that SkipgramFinder should look for in input strings.
  // Tokens may use the regex '.*' as a suffix.
  void AddSkipgram(absl::string_view skipgram, int category);

  // Find all of the skipgrams in `input`, and return their categories.
  // `input` is split on spaces and each token is normalized first.
  absl::flat_hash_set<int> FindSkipgrams(absl::string_view input) const;

  // Find all of the skipgrams in `tokens`, and return their categories.
  // Tokens are matched verbatim (no normalization).
  absl::flat_hash_set<int> FindSkipgrams(
      const std::vector<absl::string_view>& tokens) const;

 private:
  // One node per skipgram prefix; edges are labeled with whole tokens.
  struct TrieNode {
    // Categories of the skipgram that ends at this node (empty if none).
    absl::flat_hash_set<int> categories;

    // Maps tokens to the next node in the trie.
    absl::flat_hash_map<std::string, TrieNode> token_to_node;

    // Maps token prefixes (<prefix>.*) to the next node in the trie.
    absl::flat_hash_map<std::string, TrieNode> prefix_to_node;
  };

  TrieNode skipgram_trie_;  // root of the skipgram trie
  int max_skip_size_;       // maximum tokens skippable between matches
};
}
// namespace seq_flow_lite
#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_SKIPGRAM_FINDER_H_
research/seq_flow_lite/tf_ops/skipgram_finder_test.cc
0 → 100644
View file @
20cc2190
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tf_ops/skipgram_finder.h" // seq_flow_lite
#include <string>
#include <vector>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "icu4c/source/common/unicode/uchar.h"
#include "icu4c/source/common/unicode/utf8.h"
namespace
seq_flow_lite
{
namespace
{
using
::
testing
::
UnorderedElementsAreArray
;
// Runs both FindSkipgrams overloads on the same token sequence and checks
// their resulting category sets.  `categories` is the expectation for the
// string overload (which normalizes tokens); `token_categories` for the
// vector overload (which matches tokens verbatim, so case and punctuation
// matter).
void TestFindSkipgrams(const SkipgramFinder& skipgram_finder,
                       const std::vector<std::string>& tokens,
                       const std::vector<int>& categories,
                       const std::vector<int>& token_categories) {
  EXPECT_THAT(skipgram_finder.FindSkipgrams(absl::StrJoin(tokens, " ")),
              UnorderedElementsAreArray(categories));

  std::vector<absl::string_view> sv_tokens;
  sv_tokens.reserve(tokens.size());
  for (const auto& token : tokens) {
    sv_tokens.emplace_back(token.data(), token.size());
  }
  EXPECT_THAT(skipgram_finder.FindSkipgrams(sv_tokens),
              UnorderedElementsAreArray(token_categories));
}
// Test that u_tolower() will only increase the number of bytes in the
// UTF-8 encoding in two specific cases.
// Verifies the assumption PreprocessToken() relies on: across the BMP,
// u_tolower() never produces a longer UTF-8 encoding, except for the two
// code points (0x23a, 0x23e) that PreprocessToken() special-cases.
TEST(SkipgramFinderTest, UCharToLower) {
  for (UChar32 c = 0; c < 0x10000; c++) {
    if (c == 0x23a || c == 0x23e) continue;  // the two known exceptions
    UChar32 l = u_tolower(c);
    EXPECT_GE(U8_LENGTH(c), U8_LENGTH(l)) << c << " lowercases to " << l;
  }
}
// A single skipgram "q r s" with max skip 1 matches when at most one token
// separates consecutive skipgram tokens.  Normalization (lowercasing,
// punctuation stripping) applies only to the string overload, hence the
// empty expectations for the verbatim token overload.
TEST(SkipgramFinderTest, SingleExists) {
  SkipgramFinder skipgram_finder(1);
  std::string s("q r s");
  skipgram_finder.AddSkipgram(s, 0);
  TestFindSkipgrams(skipgram_finder, {"a", "q", "r", "s", "c"}, {0}, {0});
  TestFindSkipgrams(skipgram_finder, {"a", "q", "xyz", "R!", "xy", "s", "c"},
                    {0}, {});
  TestFindSkipgrams(skipgram_finder, {"a", "q", "r", "q", "R", "s.", "c"},
                    {0}, {});
}
// No match when more than one token is skipped between skipgram tokens,
// when the skipgram is incomplete, or when its tokens appear out of order.
TEST(SkipgramFinderTest, SingleNotExists) {
  SkipgramFinder skipgram_finder(1);
  std::string s("q r s");
  skipgram_finder.AddSkipgram(s, 0);
  TestFindSkipgrams(skipgram_finder, {"a", "q", "x", "x", "r", "x", "s", "c"},
                    {}, {});
  TestFindSkipgrams(skipgram_finder, {"a", "q", "x", "r", "x", "c"}, {}, {});
  TestFindSkipgrams(skipgram_finder, {"a", "r", "x", "s", "q", "c"}, {}, {});
}
// A "q.*" skipgram token matches any input token that starts with "q".
TEST(SkipgramFinderTest, SinglePrefixExists) {
  SkipgramFinder skipgram_finder(1);
  std::string s("q.* r s");
  skipgram_finder.AddSkipgram(s, 0);
  TestFindSkipgrams(skipgram_finder, {"a", "qa", "r", "s", "c"}, {0}, {0});
  TestFindSkipgrams(skipgram_finder, {"a", "q", "xyz", "R!", "xy", "s", "c"},
                    {0}, {});
  TestFindSkipgrams(skipgram_finder, {"a", "qc", "r", "qd", "R", "s.", "c"},
                    {0}, {});
}
// "q.*" is a prefix match only: tokens merely containing "q" do not match.
TEST(SkipgramFinderTest, SinglePrefixNotExists) {
  SkipgramFinder skipgram_finder(1);
  std::string s("q.* r s");
  skipgram_finder.AddSkipgram(s, 0);
  TestFindSkipgrams(skipgram_finder, {"a", "aq", "r", "s", "c"}, {}, {});
  TestFindSkipgrams(skipgram_finder, {"a", "aqc", "xyz", "R!", "xy", "s", "c"},
                    {}, {});
  TestFindSkipgrams(skipgram_finder, {"a", "q", "ar", "q", "aR", "s.", "c"},
                    {}, {});
}
// Punctuation is stripped during normalization, so "a-b-c" (as a skipgram
// token) and "abc"/"'abc'" (as input tokens) all normalize to "abc".  The
// verbatim token overload matches only the already-clean form.
TEST(SkipgramFinderTest, Punctuation) {
  SkipgramFinder skipgram_finder(1);
  std::string s("a-b-c def");
  skipgram_finder.AddSkipgram(s, 0);
  TestFindSkipgrams(skipgram_finder, {"q", "abc", "q", "d-e-f", "q"}, {0}, {});
  TestFindSkipgrams(skipgram_finder, {"a", "'abc'", "q", "'def'", "q"}, {0},
                    {});
  TestFindSkipgrams(skipgram_finder, {"q", "abc", "q", "def", "q"}, {0}, {0});
}
// Regression test: AddSkipgram must not crash or corrupt memory on
// multi-byte UTF-8 input (the octal escapes encode a 4-byte code point).
TEST(SkipgramFinderTest, HandlesMultibyteInput) {
  SkipgramFinder skipgram_finder(1);
  std::string s("hello \363\243\243\243!");
  skipgram_finder.AddSkipgram(s, 0);
}
// Several skipgrams with distinct categories, including overlapping matches
// ("a b c" and "d e f"/"d.* e f" interleaved) and prefix skipgrams.
TEST(SkipgramFinderTest, Multiple) {
  SkipgramFinder skipgram_finder(1);
  std::string s1("a b c");
  std::string s2("D e. F!");
  std::string s3("ghi jkl mno");
  std::string s4("S T U");
  std::string s5("x. y, z!");
  std::string s6("d.* e f");
  skipgram_finder.AddSkipgram(s1, 0);
  skipgram_finder.AddSkipgram(s2, 2);
  skipgram_finder.AddSkipgram(s3, 4);
  skipgram_finder.AddSkipgram(s4, 6);
  skipgram_finder.AddSkipgram(s5, 8);
  skipgram_finder.AddSkipgram(s6, 10);
  // "d" matches both "D e. F!" (exact, post-normalization) and "d.* e f".
  TestFindSkipgrams(skipgram_finder, {"a", "d", "b", "e", "c", "f"},
                    {0, 2, 10}, {0, 2, 10});
  // "dq" matches only the "d.*" prefix skipgram.
  TestFindSkipgrams(skipgram_finder, {"a", "dq", "b", "e", "c", "f"}, {0, 10},
                    {0, 10});
  TestFindSkipgrams(skipgram_finder, {"a", "d", "b", "eq", "c", "f"}, {0},
                    {0});
  TestFindSkipgrams(skipgram_finder, {"a", "ghi", "b", "jkl", "c", "x", "mno"},
                    {0}, {0});
  TestFindSkipgrams(skipgram_finder, {"ghi", "d", "jkl", "e", "mno", "f"},
                    {2, 4, 10}, {2, 4, 10});
  TestFindSkipgrams(skipgram_finder, {"s", "x", "t", "y", "u", "z"}, {6, 8},
                    {6, 8});
}
// Exercises the case where lowercasing shrinks the UTF-8 encoding: the
// skipgram "Ɦ" is stored lowercased as "ɦ", so only the lowercase input
// token matches verbatim.
TEST(SkipgramFinderTest, UnicodeLowercase) {
  // Check that the lowercase has a smaller UTF-8 encoding than the uppercase.
  UChar32 cu;
  U8_GET_UNSAFE("Ɦ", 0, cu);
  UChar32 cl = u_tolower(cu);
  EXPECT_GT(U8_LENGTH(cu), U8_LENGTH(cl));

  SkipgramFinder skipgram_finder(1);
  std::string s("Ɦ");
  skipgram_finder.AddSkipgram(s, 0);
  TestFindSkipgrams(skipgram_finder, {"Ɦ"}, {0}, {});
  TestFindSkipgrams(skipgram_finder, {"ɦ"}, {0}, {0});
  TestFindSkipgrams(skipgram_finder, {"h"}, {}, {});
}
}
// namespace
}
// namespace seq_flow_lite
research/seq_flow_lite/tf_ops/subsequence_finder.cc
0 → 100644
View file @
20cc2190
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tf_ops/subsequence_finder.h" // seq_flow_lite
#include <deque>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "icu4c/source/common/unicode/uchar.h"
#include "icu4c/source/common/unicode/utf8.h"
namespace
seq_flow_lite
{
// Inserts `subsequence` into the trie rooted at `subsequence_trie_`, marking
// the terminal node with `category`. The subsequence is lowercased and split
// on spaces: the first character of each token descends through `next_token`
// edges, subsequent characters of the same token through `continue_token`
// edges. Insertion is abandoned if the input contains invalid UTF-8 (note
// that trie nodes created before the invalid byte are left in place).
void SubsequenceFinder::AddSubsequence(absl::string_view subsequence,
                                       int category) {
  const char* bytes = subsequence.data();
  const int32_t byte_count = subsequence.length();
  TrieNode* node = &subsequence_trie_;
  bool at_token_start = true;
  for (int32_t offset = 0; offset < byte_count;) {
    UChar32 codepoint;
    U8_NEXT(bytes, offset, byte_count, codepoint);
    if (codepoint < 0) return;  // Invalid UTF-8: give up on this subsequence.
    codepoint = u_tolower(codepoint);
    if (codepoint == ' ') {
      // A space separates tokens; the next non-space character starts one.
      at_token_start = true;
    } else if (at_token_start) {
      node = &node->next_token[codepoint];
      at_token_start = false;
    } else {
      node = &node->continue_token[codepoint];
    }
  }
  node->categories.insert(category);
}
// Given a UChar32 and the child map of a trie node representing an
// in-progress subsequence, determine whether the character extends the
// subsequence, and if so record the consequences: completed categories go
// into `categories`, and the matched child node is queued for further
// extension via `next_tokens` and/or `continue_tokens`.
void SubsequenceFinder::ProcessUChar32AndTrieNode(
    int index, UChar32 c,
    const absl::flat_hash_map<UChar32, TrieNode>& token_map,
    absl::flat_hash_set<int>* categories,
    std::deque<std::pair<int, const TrieNode*>>* next_tokens,
    std::vector<const TrieNode*>* continue_tokens) const {
  const auto match = token_map.find(c);
  if (match == token_map.end()) return;
  const TrieNode& node = match->second;
  // Every subsequence that terminates at this node is now fully matched.
  categories->insert(node.categories.begin(), node.categories.end());
  // The node may still be extended within the current token...
  if (!node.continue_token.empty()) {
    continue_tokens->push_back(&node);
  }
  // ...and/or by starting a new token; for the latter, remember the index of
  // the character that got us here so the skip distance can be enforced.
  if (!node.next_token.empty()) {
    next_tokens->emplace_back(index, &node);
  }
}
// Scans `input` one Unicode code point at a time and returns the categories
// of every registered subsequence whose tokens all occur in order, with at
// most `max_skip_size_` characters between the end of one token and the
// start of the next. On invalid UTF-8, returns the categories accumulated
// up to that point.
absl::flat_hash_set<int> SubsequenceFinder::FindSubsequences(
    absl::string_view input) const {
  absl::flat_hash_set<int> categories;
  // Tracks subsequences in progress that are starting the next token,
  // as well as the index of their last character (used to enforce the
  // maximum skip distance below).
  std::deque<std::pair<int, const TrieNode*>> next_tokens;
  // Tracks subsequences in progress that are looking for the next character
  // in their current token. `current_continue_tokens` is the current set of
  // subsequences being processed, while `future_continue_tokens` is the set
  // of subsequences to process for the next character.
  std::vector<const TrieNode*> current_continue_tokens;
  std::vector<const TrieNode*> future_continue_tokens;
  const char* s = input.data();
  int32_t length = input.length();
  int32_t n = 0;       // byte offset into `s`, advanced by U8_NEXT
  int index = 0;       // code-point index of the character being processed
  while (n < length) {
    UChar32 c;
    U8_NEXT(s, n, length, c);
    if (c < 0) return categories;  // invalid UTF-8: stop scanning
    c = u_tolower(c);
    // Drop any subsequences which would need to skip more than `max_skip_size_`
    // characters between the end of their last token and the current character.
    // Entries are in increasing index order, so pruning stops at the first
    // survivor.
    while (!next_tokens.empty()) {
      if (next_tokens.front().first + max_skip_size_ + 1 < index) {
        next_tokens.pop_front();
      } else {
        break;
      }
    }
    // Check subsequences starting a new token. The size is captured first
    // because ProcessUChar32AndTrieNode may append to `next_tokens`; entries
    // added for this character must not be reprocessed in the same pass.
    size_t size = next_tokens.size();
    for (size_t i = 0; i < size; i++) {
      ProcessUChar32AndTrieNode(index, c, next_tokens[i].second->next_token,
                                &categories, &next_tokens,
                                &future_continue_tokens);
    }
    // Check subsequences continuing a token.
    for (const TrieNode* continue_token : current_continue_tokens) {
      ProcessUChar32AndTrieNode(index, c, continue_token->continue_token,
                                &categories, &next_tokens,
                                &future_continue_tokens);
    }
    // Check if we can start a new subsequence.
    ProcessUChar32AndTrieNode(index, c, subsequence_trie_.next_token,
                              &categories, &next_tokens,
                              &future_continue_tokens);
    // Continue-token matches found for this character become the working set
    // for the next character.
    current_continue_tokens.swap(future_continue_tokens);
    future_continue_tokens.clear();
    index++;
  }
  return categories;
}
}
// namespace seq_flow_lite
research/seq_flow_lite/tf_ops/subsequence_finder.h
0 → 100644
View file @
20cc2190
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_SUBSEQUENCE_FINDER_H_
#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_SUBSEQUENCE_FINDER_H_
#include <deque>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "icu4c/source/common/unicode/uchar.h"
namespace
seq_flow_lite
{
// SubsequenceFinder finds subsequences in UTF-8 strings.
//
// Specifically, given a subsequence t_1 t_2 ... t_n, we will check if a
// string matches '.*t_1.{0,N}t_2.{0,N} ... .{0,N}t_n.*', where N is the
// maximum skip size.
//
// To use: First, add subsequences using AddSubsequence() - each subsequence
// is associated with some category. Then call FindSubsequences() on a string,
which will return the set of categories of the subsequences in the string.
//
// The subsequences will be tokenized by splitting on spaces. Both subsequences
// and input strings will be normalized by lowercasing.
class SubsequenceFinder {
 public:
  // `max_skip_size` is the maximum number of characters allowed between
  // the end of one token of a subsequence and the start of the next.
  explicit SubsequenceFinder(int max_skip_size)
      : max_skip_size_(max_skip_size) {}

  // Adds a subsequence that SubsequenceFinder should look for in input
  // strings. The subsequence is lowercased and split into tokens on spaces;
  // `category` is the value reported by FindSubsequences() when it matches.
  void AddSubsequence(absl::string_view subsequence, int category);

  // Find all of the subsequences in `input`, and return their categories.
  absl::flat_hash_set<int> FindSubsequences(absl::string_view input) const;

 private:
  // This trie tracks the next character needed to:
  // * continue the current token
  // * start the next token
  struct TrieNode {
    // Categories of the subsequences that terminate at this node.
    absl::flat_hash_set<int> categories;
    // Children keyed by the next character of the token in progress.
    absl::flat_hash_map<UChar32, TrieNode> continue_token;
    // Children keyed by the first character of the following token.
    absl::flat_hash_map<UChar32, TrieNode> next_token;
  };

  // Checks whether character `c` (at code-point position `index`) extends a
  // partial match via `token_map`; records completed categories in
  // `categories` and queues extendable nodes in `next_tokens` /
  // `continue_tokens`.
  void ProcessUChar32AndTrieNode(
      int index, UChar32 c,
      const absl::flat_hash_map<UChar32, TrieNode>& token_map,
      absl::flat_hash_set<int>* categories,
      std::deque<std::pair<int, const TrieNode*>>* next_tokens,
      std::vector<const TrieNode*>* continue_tokens) const;

  // Root of the trie holding every added subsequence.
  TrieNode subsequence_trie_;
  // Maximum allowed gap (in characters) between consecutive tokens.
  int max_skip_size_;
};
}
// namespace seq_flow_lite
#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TF_OPS_SUBSEQUENCE_FINDER_H_
research/seq_flow_lite/tf_ops/subsequence_finder_test.cc
0 → 100644
View file @
20cc2190
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tf_ops/subsequence_finder.h" // seq_flow_lite
#include <gmock/gmock.h>
#include <gtest/gtest.h>
namespace
seq_flow_lite
{
namespace
{
using
::
testing
::
UnorderedElementsAre
;
TEST(SubsequenceFinderTest, SingleExists) {
  SubsequenceFinder finder(3);
  finder.AddSubsequence("ab cd", 0);
  // Tokens may be adjacent, separated by up to 3 characters, or separated by
  // a space; matching is case-insensitive.
  EXPECT_THAT(finder.FindSubsequences("abcd"), UnorderedElementsAre(0));
  EXPECT_THAT(finder.FindSubsequences("ab012cd"), UnorderedElementsAre(0));
  EXPECT_THAT(finder.FindSubsequences("AB CD"), UnorderedElementsAre(0));
}
TEST(SubsequenceFinderTest, SingleNotExists) {
  SubsequenceFinder finder(3);
  finder.AddSubsequence("ab cd", 0);
  // No match when a token is split ("a bcd"), when the gap exceeds the
  // maximum skip size of 3 ("ab0123cd"), or when characters are reordered.
  EXPECT_THAT(finder.FindSubsequences("a bcd"), UnorderedElementsAre());
  EXPECT_THAT(finder.FindSubsequences("ab0123cd"), UnorderedElementsAre());
  EXPECT_THAT(finder.FindSubsequences("abdc"), UnorderedElementsAre());
}
TEST(SubsequenceFinderTest, Multiple) {
  SubsequenceFinder finder(3);
  finder.AddSubsequence("a b c d", 0);
  finder.AddSubsequence("q r s", 2);
  finder.AddSubsequence("b c d e", 4);
  // Each input should report exactly the categories of the subsequences it
  // contains, even when multiple subsequences interleave.
  EXPECT_THAT(finder.FindSubsequences("a__b__c__d__e"),
              UnorderedElementsAre(0, 4));
  EXPECT_THAT(finder.FindSubsequences("aqbrcsd"), UnorderedElementsAre(0, 2));
  EXPECT_THAT(finder.FindSubsequences("b q c r d s e"),
              UnorderedElementsAre(2, 4));
}
TEST(SubsequenceFinderTest, Utf8) {
  SubsequenceFinder finder(3);
  finder.AddSubsequence("一二 三四 五六", 0);
  // The skip limit is measured in code points, not bytes, so multi-byte
  // CJK filler between tokens still matches.
  EXPECT_THAT(finder.FindSubsequences("一二おはよ三四こんに五六"),
              UnorderedElementsAre(0));
  // Tokens split at the wrong boundaries do not match.
  EXPECT_THAT(finder.FindSubsequences("一二三 四五六"),
              UnorderedElementsAre());
}
}
// namespace
}
// namespace seq_flow_lite
research/seq_flow_lite/tf_ops/text_distorter.h
View file @
20cc2190
...
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
==============================================================================*/
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TF_OPS_TEXT_DISTORTER_H_
#ifndef TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TF_OPS_TEXT_DISTORTER_H_
#define TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TF_OPS_TEXT_DISTORTER_H_
#define TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TF_OPS_TEXT_DISTORTER_H_
#include <assert.h>
#include <assert.h>
...
@@ -40,4 +40,4 @@ class TextDistorter {
...
@@ -40,4 +40,4 @@ class TextDistorter {
UChar32
random_char_
=
0
;
UChar32
random_char_
=
0
;
};
};
#endif // TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TF_OPS_TEXT_DISTORTER_H_
#endif // TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TF_OPS_TEXT_DISTORTER_H_
research/seq_flow_lite/tf_ops/tf_custom_ops.cc
View file @
20cc2190
...
@@ -122,4 +122,4 @@ REGISTER_OP("UniformCausalAttn")
...
@@ -122,4 +122,4 @@ REGISTER_OP("UniformCausalAttn")
})
})
.
Doc
(
R"doc(
.
Doc
(
R"doc(
Dummy uniform causal attn op.
Dummy uniform causal attn op.
)doc"
;
)doc"
)
;
research/seq_flow_lite/tflite_ops/BUILD
View file @
20cc2190
...
@@ -121,9 +121,9 @@ cc_library(
...
@@ -121,9 +121,9 @@ cc_library(
hdrs
=
[
"tflite_qrnn_pooling.h"
],
hdrs
=
[
"tflite_qrnn_pooling.h"
],
copts
=
tflite_copts
(),
copts
=
tflite_copts
(),
deps
=
[
deps
=
[
"
//third_party/absl/base:core_header
s"
,
"
@org_tensorflow//tensorflow/lite/kernels:builtin_op
s"
,
"//t
hird_party/tensorflow/lite/kernels:buil
tin_
ops"
,
"//t
flite_ops:quantiza
ti
o
n_
util"
,
# sequence projection
"
//third_party/tensorflow_models/seq_flow_lite/tflite_ops:quantization_util
"
,
"
@com_google_absl//absl/base:core_headers
"
,
],
],
alwayslink
=
1
,
alwayslink
=
1
,
)
)
...
@@ -132,7 +132,7 @@ cc_library(
...
@@ -132,7 +132,7 @@ cc_library(
name
=
"tflite_decoder_cache"
,
name
=
"tflite_decoder_cache"
,
hdrs
=
[
"tflite_decoder_cache.h"
],
hdrs
=
[
"tflite_decoder_cache.h"
],
deps
=
[
deps
=
[
"
//third_party
/tensorflow/lite/c:common"
,
"
@org_tensorflow/
/tensorflow/lite/c:common"
,
],
],
alwayslink
=
1
,
alwayslink
=
1
,
)
)
...
@@ -144,12 +144,12 @@ cc_library(
...
@@ -144,12 +144,12 @@ cc_library(
copts
=
tflite_copts
(),
copts
=
tflite_copts
(),
deps
=
[
deps
=
[
":tflite_decoder_cache"
,
":tflite_decoder_cache"
,
"
//third_party/flatbuffers
"
,
"
@org_tensorflow//tensorflow/lite/c:common
"
,
"
//third_party
/tensorflow/lite/
c:common
"
,
"
@org_tensorflow/
/tensorflow/lite/
kernels:builtin_ops
"
,
"
//third_party
/tensorflow/lite/kernels:
builtin_ops
"
,
"
@org_tensorflow/
/tensorflow/lite/kernels:
kernel_util
"
,
"
//third_party
/tensorflow/lite/kernels
:kernel_util
"
,
"
@org_tensorflow/
/tensorflow/lite/kernels
/internal:tensor
"
,
"//t
hird_party/tensorflow/lite/kernels/internal:tensor"
,
"//t
flite_ops:quantization_util"
,
# sequence projection
"
//third_party/tensorflow_models/seq_flow_lite/tflite_ops:quantization_util
"
,
"
@flatbuffers
"
,
],
],
alwayslink
=
1
,
alwayslink
=
1
,
)
)
...
@@ -160,11 +160,11 @@ cc_test(
...
@@ -160,11 +160,11 @@ cc_test(
srcs
=
[
"tflite_decoder_handler_test.cc"
],
srcs
=
[
"tflite_decoder_handler_test.cc"
],
deps
=
[
deps
=
[
":tflite_decoder_handler"
,
":tflite_decoder_handler"
,
"
//testing/base/public:gunit
"
,
"
@org_tensorflow//tensorflow/lite:framework
"
,
"
//third_party/flatbuffers
"
,
"
@org_tensorflow//tensorflow/lite/c:common
"
,
"
//third_party/tensorflow/lite:framework
"
,
"
@org_tensorflow//tensorflow/lite/kernels:test_util
"
,
"
//third_party/tensorflow/lite/c:common
"
,
"
@com_google_googletest//:gtest
"
,
"
//third_party/tensorflow/lite/kernels:test_util
"
,
"
@flatbuffers
"
,
],
],
)
)
...
@@ -176,10 +176,10 @@ cc_library(
...
@@ -176,10 +176,10 @@ cc_library(
deps
=
[
deps
=
[
"//base"
,
"//base"
,
"//third_party/absl/strings"
,
"//third_party/absl/strings"
,
"
//third_party
/tensorflow/lite/c:common"
,
"
@org_tensorflow/
/tensorflow/lite/c:common"
,
"
//third_party
/tensorflow/lite/kernels/internal:tensor"
,
"
@org_tensorflow/
/tensorflow/lite/kernels/internal:tensor"
,
"
//third_party
/tensorflow/lite/kernels/internal:types"
,
"
@org_tensorflow/
/tensorflow/lite/kernels/internal:types"
,
"//
third_party/tensorflow_models/seq_flow_lite/
tflite_ops:quantization_util"
,
"//tflite_ops:quantization_util"
,
# sequence projection
],
],
)
)
...
@@ -189,14 +189,14 @@ cc_test(
...
@@ -189,14 +189,14 @@ cc_test(
copts
=
tflite_copts
(),
copts
=
tflite_copts
(),
deps
=
[
deps
=
[
":beam_search"
,
":beam_search"
,
"//testing/base/public:gunit_main"
,
"//third_party/absl/strings"
,
"//third_party/absl/strings"
,
"//third_party/tensorflow/lite/c:c_api_types"
,
"@org_tensorflow//tensorflow/lite/c:c_api_types"
,
"//third_party/tensorflow/lite/c:common"
,
"@org_tensorflow//tensorflow/lite/c:common"
,
"//third_party/tensorflow/lite/kernels/internal:legacy_reference_base"
,
"@org_tensorflow//tensorflow/lite/kernels/internal:legacy_reference_base"
,
"//third_party/tensorflow/lite/kernels/internal:optimized_base"
,
"@org_tensorflow//tensorflow/lite/kernels/internal:optimized_base"
,
"//third_party/tensorflow/lite/kernels/internal:tensor"
,
"@org_tensorflow//tensorflow/lite/kernels/internal:tensor"
,
"//third_party/tensorflow/lite/kernels/internal:types"
,
"@org_tensorflow//tensorflow/lite/kernels/internal:types"
,
"//third_party/tensorflow_models/seq_flow_lite/tflite_ops:quantization_util"
,
"//tflite_ops:quantization_util"
,
# sequence projection
"@com_google_googletest//:gtest_main"
,
],
],
)
)
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment