ModelZoo / ResNet50_tensorflow · Commit 51f4ecad

Unverified commit 51f4ecad, authored Nov 25, 2020 by prabhukaliamoorthi; committed by GitHub, Nov 25, 2020. Parent: 7310b0f8.

Add demo app and update op handlers (#9503)
Changes: 29 · Showing 9 changed files with 381 additions and 96 deletions (+381, −96)
research/seq_flow_lite/tf_ops/sequence_string_projection_op_v2.cc    +25  −12
research/seq_flow_lite/tf_ops/text_distorter.h                       +1   −0
research/seq_flow_lite/tf_ops/tf_custom_ops.cc                       +35  −5
research/seq_flow_lite/tflite_ops/BUILD                              +3   −2
research/seq_flow_lite/tflite_ops/layer_norm.cc                      +57  −29
research/seq_flow_lite/tflite_ops/sequence_string_projection.cc      +100 −20
research/seq_flow_lite/tflite_ops/sequence_string_projection_test.cc +121 −6
research/seq_flow_lite/trainer.py                                    +22  −12
research/seq_flow_lite/utils/tflite_utils.py                         +17  −10
research/seq_flow_lite/tf_ops/sequence_string_projection_op_v2.cc
@@ -40,12 +40,16 @@ constexpr char kEndTokenTSP[] = "<EOS>";
 constexpr float kMappingTable[4] = {0, 1, -1, 0};
 constexpr int kIncrement = 32;
 
+template <typename T>
 class SequenceStringProjectionOpV2 : public OpKernel {
  public:
   explicit SequenceStringProjectionOpV2(OpKernelConstruction* context)
       : OpKernel(context) {
     OP_REQUIRES_OK(context, context->GetAttr("feature_size", &feature_size_));
-    hasher_ = absl::make_unique<Hasher>(feature_size_);
+    std::string hashtype;
+    OP_REQUIRES_OK(context, context->GetAttr("hashtype", &hashtype));
+    hasher_ = absl::WrapUnique<Hasher>(
+        Hasher::CreateHasher(feature_size_, hashtype));
     float distortion_probability = 0.0;
     OP_REQUIRES_OK(context, context->GetAttr("distortion_probability",
@@ -88,13 +92,13 @@ class SequenceStringProjectionOpV2 : public OpKernel {
         ctx, TensorShapeUtils::IsVector(seq_len->shape()),
         InvalidArgument("`sequence_length` must be a vector, got shape: ",
                         seq_len->shape().DebugString()));
-    auto seq_len_vector = seq_len->vec<int32>();
-    OP_REQUIRES(ctx, seq_len_vector.size() == batch_size,
+    auto seq_len_flat = seq_len->flat<T>();
+    OP_REQUIRES(ctx, seq_len_flat.size() == batch_size,
                 InvalidArgument("`sequence_length` should have batch size number "
                                 "of elements, got size ",
-                                seq_len_vector.size(), ", batch size is ",
-                                batch_size));
+                                seq_len_flat.size(), ", batch size is ",
+                                batch_size));
     Tensor* output_tensor = nullptr;
     OP_REQUIRES_OK(
@@ -106,7 +110,7 @@ class SequenceStringProjectionOpV2 : public OpKernel {
     std::vector<uint64_t> hash_codes;
     for (int64 i = 0; i < batch_size; ++i) {
-      const int64 num_tokens = seq_len_vector(i);
+      const int64 num_tokens = seq_len_flat(i);
       OP_REQUIRES(ctx, num_tokens >= 0,
                   InvalidArgument("`sequence_length` should have values "
                                   "greater than or equal to 0"));
@@ -159,20 +163,27 @@ class SequenceStringProjectionOpV2 : public OpKernel {
   int bos_tag_;
 };
 
-REGISTER_KERNEL_BUILDER(
-    Name("SequenceStringProjectionV2").Device(::tensorflow::DEVICE_CPU),
-    SequenceStringProjectionOpV2);
+REGISTER_KERNEL_BUILDER(Name("SequenceStringProjectionV2")
+                            .Device(::tensorflow::DEVICE_CPU)
+                            .TypeConstraint<int64>("Tsequence_length"),
+                        SequenceStringProjectionOpV2<int64>);
+REGISTER_KERNEL_BUILDER(Name("SequenceStringProjectionV2")
+                            .Device(::tensorflow::DEVICE_CPU)
+                            .TypeConstraint<int32>("Tsequence_length"),
+                        SequenceStringProjectionOpV2<int32>);
 
 REGISTER_OP("SequenceStringProjectionV2")
     .Input("input: string")
-    .Input("sequence_length: int32")
+    .Input("sequence_length: Tsequence_length")
     .Output("projection: float32")
     .Attr("feature_size: int")
     .Attr("distortion_probability: float = 0.0")
     .Attr("vocabulary: string = ''")
+    .Attr("hashtype: string = 'murmur'")
     .Attr("add_bos_tag: bool = False")
     .Attr("add_eos_tag: bool = False")
     .Attr("normalize_repetition: bool = False")
+    .Attr("Tsequence_length: {int32, int64}")
     .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
       DimensionHandle size;
@@ -181,9 +192,10 @@ REGISTER_OP("SequenceStringProjectionV2")
       const int kMaxFeatureSize = 4096;
       CHECK_GT(feature_size, 0);
       CHECK_LE(feature_size, kMaxFeatureSize);
-      auto batch_size = c->Dim(c->input(0), 0);
-      c->set_output(0, c->MakeShape({batch_size, InferenceContext::kUnknownDim,
-                                     feature_size}));
+      ShapeHandle output_shape;
+      TF_RETURN_IF_ERROR(c->Concatenate(
+          c->input(0), c->MakeShape({feature_size}), &output_shape));
+      c->set_output(0, output_shape);
       return tensorflow::Status::OK();
     })
     .Doc(R"doc(
@@ -209,6 +221,7 @@ Attribute(s):
   will be allowed in the input text before fingerprinting. Expressed another
   way the vocabulary is an optional character allowlist for the
   input tokens. It helps normalize the text.
+- hashtype: Hashing method to use for projection.
 - add_bos_tag: When true inserts a begin of sentence tag.
 - add_eos_tag: When true inserts a end of sentence tag.
 - normalize_repetition: When true normalizes repetition in text tokens before
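Note: taken together, these changes make the op polymorphic over the sequence_length dtype and let callers pick the hashing method, and shape inference now appends {feature_size} to the input's shape instead of hard-coding a rank-3 result. A minimal usage sketch, assuming the custom-op library has been built as a loadable .so and that the generated Python wrapper follows TensorFlow's usual snake_case naming (both assumptions; neither appears in this diff):

    import tensorflow as tf

    # Assumed path; the wrapper name follows TF's CamelCase -> snake_case rule.
    ops = tf.load_op_library("sequence_string_projection_op_v2.so")
    projection = ops.sequence_string_projection_v2(
        input=[["hello", "world"]],
        sequence_length=tf.constant([2], dtype=tf.int64),  # int32 also accepted now
        feature_size=64,
        hashtype="murmur")  # new attribute; defaults to 'murmur'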
research/seq_flow_lite/tf_ops/text_distorter.h
@@ -32,6 +32,7 @@ class TextDistorter {
     assert(distortion_probability_ <= 1.0);
   }
   std::string DistortText(icu::UnicodeString* uword);
+  bool BernouilleSample(float p) { return (generator_.RandFloat() <= p); }
 
  private:
   tensorflow::random::PhiloxRandom philox_;
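Note: the new helper draws a Bernoulli sample by comparing a uniform random float against p (the spelling BernouilleSample is the source's own). A one-line Python equivalent for reference:

    import random

    def bernoulli_sample(p: float) -> bool:
        # True with probability p, mirroring TextDistorter::BernouilleSample.
        return random.random() <= p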
research/seq_flow_lite/tf_ops/tf_custom_ops.cc
@@ -32,8 +32,8 @@ REGISTER_KERNEL_BUILDER(Name("PoolingOp").Device(::tensorflow::DEVICE_CPU),
 REGISTER_OP("PoolingOp")
     .Input("multiplier: float32")
     .Input("constant: float32")
-    .Input("forward: float32")
     .Output("state: float32")
+    .Attr("forward: bool")
     .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
       c->set_output(0, c->input(0));
       return tensorflow::Status::OK();
@@ -75,14 +75,14 @@ class LayerNormOp : public tensorflow::OpKernel {
   void Compute(tensorflow::OpKernelContext* ctx) override {}
 };
 
-REGISTER_KERNEL_BUILDER(Name("LayerNormV2").Device(::tensorflow::DEVICE_CPU),
+REGISTER_KERNEL_BUILDER(Name("LayerNorm").Device(::tensorflow::DEVICE_CPU),
                         LayerNormOp);
 
-REGISTER_OP("LayerNormV2")
+REGISTER_OP("LayerNorm")
     .Input("tensor: float32")
     .Input("scale: float32")
-    .Input("output: float32")
-    .Attr("axes: list(int)")
+    .Input("offset: float32")
+    .Input("axes: int32")
     .Output("result: float32")
     .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
       c->set_output(0, c->input(0));
@@ -91,3 +91,33 @@ REGISTER_OP("LayerNormV2")
     .Doc(R"doc(
 Dummy layer norm op.
 )doc");
+
+class UniformCausalAttnOp : public tensorflow::OpKernel {
+ public:
+  explicit UniformCausalAttnOp(tensorflow::OpKernelConstruction* context)
+      : tensorflow::OpKernel(context) {}
+  void Compute(tensorflow::OpKernelContext* ctx) override {}
+};
+
+REGISTER_KERNEL_BUILDER(
+    Name("UniformCausalAttn").Device(::tensorflow::DEVICE_CPU),
+    UniformCausalAttnOp);
+
+REGISTER_OP("UniformCausalAttn")
+    .Input("input: float32")
+    .Input("time_step: int32")
+    .Input("selected_beams: int32")
+    .Attr("feature_size: int")
+    .Attr("beam_size: int")
+    .Output("output: float32")
+    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
+      auto batch_size = c->Dim(c->input(0), 0);
+      int32 feature_size;
+      TF_RETURN_IF_ERROR(c->GetAttr("feature_size", &feature_size));
+      c->set_output(0, c->MakeShape({batch_size, 1, feature_size}));
+      return tensorflow::Status::OK();
+    })
+    .Doc(R"doc(
+Dummy uniform causal attn op.
+)doc");
research/seq_flow_lite/tflite_ops/BUILD
@@ -36,8 +36,9 @@ cc_test(
"@org_tensorflow//tensorflow/lite/core/api"
,
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops"
,
"@org_tensorflow//tensorflow/lite/kernels:test_util"
,
"//tf_ops:projection_util"
,
# sequence projection
# "//tf_ops:sequence_string_projection_op" # sequence projection
"//tf_ops:sequence_string_projection_op_v2"
,
# sequence projection
#
"//tf_ops:sequence_string_projection_op_v2" # sequence projection
],
)
...
...
@@ -83,7 +84,7 @@ cc_library(
     deps = [
         ":quantization_util",
         "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
-        "@flatbuffers",
+        "@org_tensorflow//tensorflow/lite/kernels:kernel_util",
     ],
     alwayslink = 1,
 )
research/seq_flow_lite/tflite_ops/layer_norm.cc
@@ -17,8 +17,8 @@ limitations under the License.
 #include <unordered_set>
 #include <vector>
 
 #include "flatbuffers/flexbuffers.h"  // flatbuffer
 #include "tflite_ops/quantization_util.h"  // seq_flow_lite
+#include "tensorflow/lite/kernels/kernel_util.h"
 
 namespace tflite {
 namespace ops {
@@ -213,6 +213,34 @@ TfLiteStatus FlexibleLayerNorm(const TfLiteTensor* input, const float scale,
   return kTfLiteOk;
 }
 
+TfLiteStatus DefaultLayerNormFloat(const TfLiteTensor* input, const float scale,
+                                   const float offset, TfLiteTensor* output) {
+  const int input_rank = input->dims->size;
+  const int num_features = input->dims->data[input_rank - 1];
+  const int time_steps =
+      static_cast<int>(GetNumberOfSteps(input) / num_features);
+  float* out_ptr = output->data.f;
+  for (int i = 0; i < time_steps; ++i) {
+    float sum_x = 0;
+    float sum_xx = 0;
+    for (int j = 0, index = i * num_features; j < num_features; ++j, ++index) {
+      sum_x += input->data.f[index];
+      sum_xx += input->data.f[index] * input->data.f[index];
+    }
+    const float exp_xx = sum_xx / num_features;
+    const float exp_x = sum_x / num_features;
+    const float variance = exp_xx - exp_x * exp_x;
+    const float inverse_stddev = 1 / sqrt(variance + 1e-6);
+    const float multiplier = inverse_stddev * scale;
+    const float bias = offset - exp_x * inverse_stddev * scale;
+    for (int j = 0, index = i * num_features; j < num_features; ++j, ++index) {
+      out_ptr[index] = input->data.f[index] * multiplier + bias;
+    }
+  }
+  return kTfLiteOk;
+}
+
 TfLiteStatus DefaultLayerNorm(const TfLiteTensor* input, const float scale,
                               const float offset, TfLiteTensor* output) {
   const int input_rank = input->dims->size;
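Note: the new float path computes mean and variance of each length-num_features slice in a single pass of running sums, using Var[x] = E[x^2] - E[x]^2, then applies y = scale * (x - E[x]) / sqrt(Var[x] + 1e-6) + offset, folded into a per-slice multiplier and bias so the inner loop is one multiply-add per element. A reference sketch of the same math in NumPy (for checking, not part of the commit):

    import numpy as np

    def default_layer_norm_float(x: np.ndarray, scale: float, offset: float) -> np.ndarray:
        # x has shape [time_steps, num_features]; normalize over the last axis.
        exp_x = x.mean(axis=-1, keepdims=True)          # E[x]
        exp_xx = (x * x).mean(axis=-1, keepdims=True)   # E[x^2]
        variance = exp_xx - exp_x * exp_x               # Var[x] = E[x^2] - E[x]^2
        inverse_stddev = 1.0 / np.sqrt(variance + 1e-6)
        return x * (inverse_stddev * scale) + (offset - exp_x * inverse_stddev * scale)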
@@ -250,25 +278,40 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteTensor* input =
       &context->tensors[node->inputs->data[kInputIndex]];
   TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputIndex]];
-  const float scale =
-      PodDequantize(context->tensors[node->inputs->data[kScaleIndex]], 0);
-  const float offset =
-      PodDequantize(context->tensors[node->inputs->data[kOffsetIndex]], 0);
-  const std::vector<int>& axes =
-      *reinterpret_cast<std::vector<int>*>(node->user_data);
-  const size_t num_axis = axes.size();
+  TfLiteTensor scale_tensor = context->tensors[node->inputs->data[kScaleIndex]];
+  TfLiteTensor offset_tensor =
+      context->tensors[node->inputs->data[kOffsetIndex]];
+  float scale = 1.0;
+  float offset = 0.0;
+  if (input->type == kTfLiteUInt8) {
+    scale = PodDequantize(scale_tensor, 0);
+    offset = PodDequantize(offset_tensor, 0);
+  } else {
+    scale = scale_tensor.data.f[0];
+    offset = offset_tensor.data.f[0];
+  }
+  TfLiteTensor* axis = &context->tensors[node->inputs->data[kAxisIndex]];
+  int num_axis = static_cast<int>(tflite::NumElements(axis));
   // For backward compatibility reasons, we handle the default layer norm for
   // last channel as below.
-  if (num_axis == 1 && (axes[0] == -1 || axes[0] == (input->dims->size - 1))) {
-    return DefaultLayerNorm(input, scale, offset, output);
+  if (num_axis == 1 && (axis->data.i32[0] == -1 ||
+                        axis->data.i32[0] == (input->dims->size - 1))) {
+    if (input->type == kTfLiteUInt8) {
+      return DefaultLayerNorm(input, scale, offset, output);
+    } else if (input->type == kTfLiteFloat32) {
+      return DefaultLayerNormFloat(input, scale, offset, output);
+    } else {
+      TF_LITE_ENSURE_MSG(context, false, "Input should be eith Uint8 or Float32.");
+    }
   }
   std::vector<int> resolved_axis(num_axis);
   // Resolve axis.
   int num_resolved_axis = 0;
-  if (!ResolveAxis(input->dims->size, axes.data(), num_axis, &resolved_axis[0],
-                   &num_resolved_axis)) {
+  if (!ResolveAxis(input->dims->size, axis->data.i32, num_axis,
+                   &resolved_axis[0], &num_resolved_axis)) {
     return kTfLiteError;
   }
@@ -276,25 +319,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
                            num_resolved_axis, output);
 }
 
-void* Init(TfLiteContext* context, const char* buffer, size_t length) {
-  const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
-  const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
-  std::vector<int>* axes = new std::vector<int>();
-  auto axes_fb = m["axes"].AsTypedVector();
-  for (int i = 0; i < axes_fb.size(); ++i) {
-    axes->push_back(axes_fb[i].AsInt32());
-  }
-  return axes;
-}
-
-void Free(TfLiteContext* context, void* buffer) {
-  delete reinterpret_cast<std::vector<int>*>(buffer);
-}
-
 }  // namespace
 
 TfLiteRegistration* Register_LAYER_NORM() {
-  static TfLiteRegistration r = {Init, Free, Resize, Eval};
+  static TfLiteRegistration r = {nullptr, nullptr, Resize, Eval};
   return &r;
 }
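Note: with the axes now delivered as an int32 input tensor instead of flexbuffer custom options, the op needs no per-node state, so the Init/Free pair is dropped from the registration. A rough Python restatement of the new Eval dispatch (types written as strings; an illustrative sketch, not the kernel itself):

    def eval_layer_norm_dispatch(input_type: str, axis: list, rank: int) -> str:
        # Fast path: a single axis naming the last channel.
        if len(axis) == 1 and axis[0] in (-1, rank - 1):
            if input_type == "uint8":
                return "DefaultLayerNorm"       # existing quantized fast path
            if input_type == "float32":
                return "DefaultLayerNormFloat"  # new float fast path
            raise ValueError("Input should be either Uint8 or Float32.")
        return "FlexibleLayerNorm"              # general path over resolved axes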
research/seq_flow_lite/tflite_ops/sequence_string_projection.cc
@@ -67,6 +67,14 @@ namespace sequence_string_projection {
  *    true. Defaults to true.
  *  attribute[7]: add_bos_tag, add a begin of sequence tag to the output when
  *    true. Defaults to false.
+ *  attribute[8]: add_first_cap_feature, when set to 1.0f add a feature to the
+ *    resulting projection tensor that helps discriminate if the
+ *    input token is Camel case. Otherwise leaves the projection
+ *    output unmodified.
+ *  attribute[9]: add_all_caps_feature, when set to 1.0f add a feature to the
+ *    resulting projection tensor that helps discriminate if the
+ *    input token is ALLCAPS. Otherwise leaves the projection
+ *    output unmodified.
  * Output:
  *  tensor[0]: computed projections.
  *    float32[true number of tokens][feature size]
@@ -87,22 +95,32 @@ enum class EosTag { kGenerate, kNone };
 class ProjectionParams {
  public:
   ProjectionParams(int feature_size, const std::string& vocabulary,
-                   int max_splits, bool split_on_space, int word_novelty_bits,
+                   const std::string& hashtype, int max_splits,
+                   bool split_on_space, int word_novelty_bits,
                    int doc_size_levels, BosTag add_bos_tag, EosTag add_eos_tag,
                    bool exclude_nonalphaspace_unicodes,
                    const std::string& token_separators,
-                   bool normalize_repetition)
+                   bool normalize_repetition, bool add_first_cap_feature,
+                   bool add_all_caps_feature)
       : feature_size_(feature_size),
         unicode_handler_(vocabulary, exclude_nonalphaspace_unicodes),
-        hasher_(feature_size),
+        hasher_(Hasher::CreateHasher(feature_size, hashtype)),
         max_splits_(max_splits),
         split_on_space_(split_on_space),
         word_novelty_bits_(word_novelty_bits),
         doc_size_levels_(doc_size_levels),
         add_bos_tag_(add_bos_tag == BosTag::kGenerate),
-        add_eos_tag_(add_eos_tag == EosTag::kGenerate) {
+        add_eos_tag_(add_eos_tag == EosTag::kGenerate),
+        add_first_cap_feature_(add_first_cap_feature),
+        add_all_caps_feature_(add_all_caps_feature) {
     assert(max_splits_ == -1 || max_splits_ > 0);
     assert(word_novelty_bits >= 0 && word_novelty_bits <= 7);
+    // hasher_ can be nullptr if the hashtype is invalid. But there is a similar
+    // check in tensorflow op when the model is created. So this failure will
+    // never happen if the model was successfully trained. Still adding a check
+    // here since you can edit the model post training, which is the only
+    // situation when this assertion will fail.
+    assert(hasher_ != nullptr);
     if (word_novelty_bits_ != 0) {
       assert(feature_size_ >= 1);
     }
@@ -113,8 +131,8 @@ class ProjectionParams {
     word_novelty_offset_ = 2.0f / (1 << word_novelty_bits_);
     if (!token_separators.empty() || normalize_repetition) {
-      projection_normalizer_.reset(
-          new ProjectionNormalizer(token_separators, normalize_repetition));
+      projection_normalizer_ = std::make_unique<ProjectionNormalizer>(
+          token_separators, normalize_repetition);
     }
   }
   virtual ~ProjectionParams() {}
@@ -129,6 +147,8 @@ class ProjectionParams {
     *data = PodQuantize(word_novelty_feature, 127.0f, 127);
   }
   bool DocSizeFeatureEnabled() const { return (doc_size_levels_ != 0); }
+  bool FirstCap() const { return add_first_cap_feature_; }
+  bool AllCaps() const { return add_all_caps_feature_; }
   int BosToken() const { return add_bos_tag_ ? 1 : 0; }
   int EosToken() const { return add_eos_tag_ ? 1 : 0; }
   void DocSizeFeature(float* data, int num_tokens) {
@@ -144,13 +164,15 @@ class ProjectionParams {
     *data = PodQuantize(doc_size_feature, 127.0f, 127);
   }
   void Hash(const std::string& word, std::vector<uint64_t>* hash_codes) {
-    hasher_.GetHashCodes(word, hash_codes);
+    hasher_->GetHashCodes(word, hash_codes);
   }
   // Lower cases the input text and eliminates all unsupported
   // unicodes in it if a vocabulary is provided.
   std::string LowerCaseUTF8WithSupportedUnicodes(
-      std::pair<const char*, size_t> source) const {
-    return unicode_handler_.LowerCaseUTF8WithSupportedUnicodes(source);
+      std::pair<const char*, size_t> source, bool* first_cap,
+      bool* all_caps) const {
+    return unicode_handler_.LowerCaseUTF8WithSupportedUnicodes(source, first_cap,
+                                                               all_caps);
   }
   // Splits the input text into a set of tokens. Uses space as the delimiter
   // when split_on_space is True and unicode boundaries as the delimiter
@@ -190,13 +212,15 @@ class ProjectionParams {
  private:
   int feature_size_;
   ProjectionUnicodeHandler unicode_handler_;
-  Hasher hasher_;
+  std::unique_ptr<Hasher> hasher_;
   int max_splits_;
   bool split_on_space_;
   int word_novelty_bits_;
   int doc_size_levels_;
   bool add_bos_tag_;
   bool add_eos_tag_;
+  bool add_first_cap_feature_;
+  bool add_all_caps_feature_;
   float word_novelty_offset_;
   std::string normalized_input_;
@@ -208,14 +232,17 @@ class ProjectionParams {
 class ProjectionParamsV2 : public ProjectionParams {
  public:
   ProjectionParamsV2(int feature_size, const std::string& vocabulary,
-                     BosTag add_bos_tag, EosTag add_eos_tag,
-                     bool normalize_repetition)
-      : ProjectionParams(feature_size, vocabulary, /*max_splits = */ -1,
+                     const std::string& hashtype, BosTag add_bos_tag,
+                     EosTag add_eos_tag, bool normalize_repetition)
+      : ProjectionParams(feature_size, vocabulary, hashtype,
+                         /*max_splits = */ -1,
                          /* split_on_space = */ true,
                          /*word_novelty_bits = */ 0,
                          /*doc_size_levels = */ 0, add_bos_tag, add_eos_tag,
                          /*exclude_nonalphaspace_unicodes = */ false,
-                         /*token_separators = */ "", normalize_repetition) {}
+                         /*token_separators = */ "", normalize_repetition,
+                         /*add_first_cap_feature = */ false,
+                         /*add_all_caps_feature = */ false) {}
   ~ProjectionParamsV2() override {}
   TfLiteStatus PreprocessInput(TfLiteTensor* input_t,
@@ -271,6 +298,8 @@ inline bool IsDynamicTensor(const TfLiteTensor* tensor) {
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
   const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
+  const std::string hashtype =
+      m["hashtype"].IsNull() ? kMurmurHash : m["hashtype"].AsString().str();
   const int word_novelty_bits =
       m["word_novelty_bits"].IsNull() ? 0 : m["word_novelty_bits"].AsInt32();
   const int doc_size_levels =
@@ -279,6 +308,28 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
       m["add_bos_tag"].IsNull() ? false : m["add_bos_tag"].AsBool();
   const bool add_eos_tag =
       m["add_eos_tag"].IsNull() ? true : m["add_eos_tag"].AsBool();
+  float add_first_cap_feature = m["add_first_cap_feature"].IsNull()
+                                    ? 0.0f
+                                    : m["add_first_cap_feature"].AsFloat();
+  float add_all_caps_feature = m["add_all_caps_feature"].IsNull()
+                                   ? 0.0f
+                                   : m["add_all_caps_feature"].AsFloat();
+  if (add_first_cap_feature != 0.0f && add_first_cap_feature != 1.0f) {
+    context->ReportError(context,
+                         "add_first_cap_feature is %f, it should be 0.0 or 1.0., "
+                         "resetting it to 1.0f\n",
+                         add_first_cap_feature);
+    add_first_cap_feature = 1.0f;
+  }
+  if (add_all_caps_feature != 0.0f && add_all_caps_feature != 1.0f) {
+    context->ReportError(context,
+                         "add_all_caps_feature is %f, it should be 0.0 or 1.0., "
+                         "resetting it to 1.0f\n",
+                         add_all_caps_feature);
+    add_all_caps_feature = 1.0f;
+  }
   // Old models that use the op may not have this attribute set, for those
   // models the default value of false will be used.
   const bool exclude_nonalphaspace_unicodes =
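Note: the two new flexbuffer attributes are floats but are treated as booleans; anything other than 0.0 or 1.0 is reported and coerced to 1.0f. In Python terms (an illustrative restatement, not library code):

    def sanitize_cap_flag(value: float, name: str) -> float:
        # Restates the validation in Init(): the attribute is a float used as a
        # boolean, so any other value is reported and coerced to 1.0.
        if value not in (0.0, 1.0):
            print("%s is %f, it should be 0.0 or 1.0., resetting it to 1.0f"
                  % (name, value))
            return 1.0
        return value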
@@ -288,22 +339,35 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   const std::string token_separators =
       m["token_separators"].IsNull() ? "" : m["token_separators"].ToString();
   const bool normalize_repetition = m["normalize_repetition"].AsBool();
+  if (!Hasher::SupportedHashType(hashtype)) {
+    context->ReportError(context, "Unsupported hashtype %s\n", hashtype.c_str());
+    return nullptr;
+  }
   return new ProjectionParams(
-      m["feature_size"].AsInt32(), m["vocabulary"].AsString().str(),
+      m["feature_size"].AsInt32(), m["vocabulary"].AsString().str(), hashtype,
       m["max_splits"].AsInt32(), m["split_on_space"].AsBool(),
       word_novelty_bits, doc_size_levels,
       add_bos_tag ? BosTag::kGenerate : BosTag::kNone,
       add_eos_tag ? EosTag::kGenerate : EosTag::kNone,
-      exclude_nonalphaspace_unicodes, token_separators, normalize_repetition);
+      exclude_nonalphaspace_unicodes, token_separators, normalize_repetition,
+      add_first_cap_feature == 1.0f, add_all_caps_feature == 1.0f);
 }
 
 void* InitV2(TfLiteContext* context, const char* buffer, size_t length) {
   const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
   const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
+  const std::string hashtype =
+      m["hashtype"].IsNull() ? kMurmurHash : m["hashtype"].AsString().str();
+  if (!Hasher::SupportedHashType(hashtype)) {
+    context->ReportError(context, "Unsupported hashtype %s\n", hashtype.c_str());
+    return nullptr;
+  }
   return new ProjectionParamsV2(
-      m["feature_size"].AsInt32(), m["vocabulary"].AsString().str(),
+      m["feature_size"].AsInt32(), m["vocabulary"].AsString().str(), hashtype,
       m["add_bos_tag"].AsBool() ? BosTag::kGenerate : BosTag::kNone,
       m["add_eos_tag"].AsBool() ? EosTag::kGenerate : EosTag::kNone,
       m["normalize_repetition"].AsBool());
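Note: both Init paths fall back to murmur (kMurmurHash) when the "hashtype" key is absent, so models exported before this change keep loading; an unsupported value makes Init return nullptr and the op fail at invoke. A one-line sketch of the fallback (illustrative only):

    def resolve_hashtype(custom_options: dict) -> str:
        # Mirrors Init()/InitV2(): missing "hashtype" defaults to murmur.
        return custom_options.get("hashtype", "murmur")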
@@ -322,6 +386,8 @@ TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) {
 constexpr int kHashCodeBits = 64;
 constexpr int kMapBits = 2;
 constexpr int kIncrement = kHashCodeBits / kMapBits;
+constexpr int kMapHigh = 1;
+constexpr int kMapLow = 2;
 
 template <typename T>
 void TypedEval(const T* mapping_table, ProjectionParams* params, T* data) {
@@ -336,10 +402,12 @@ void TypedEval(const T* mapping_table, ProjectionParams* params, T* data) {
     const int num_tokens = tokens.size() + params->EosToken();
     for (int j = -params->BosToken(), offset0 = 0; j < num_tokens; ++j) {
       std::string word;
+      bool first_cap, all_caps;
       if (j < 0) {
         word = kBeginToken;
       } else if (j < tokens.size()) {
-        word = params->LowerCaseUTF8WithSupportedUnicodes(tokens[j]);
+        word = params->LowerCaseUTF8WithSupportedUnicodes(tokens[j], &first_cap,
+                                                          &all_caps);
         word = params->PreprocessToken(word);
       } else {
         word = kEndToken;
@@ -355,17 +423,29 @@ void TypedEval(const T* mapping_table, ProjectionParams* params, T* data) {
       }
       offset0 += params->FeatureSize();
       if (params->WordNoveltyEnabled() && !hash_codes.empty()) {
-        params->WordNoveltyFeature(&data[offset0 - 1],
+        params->WordNoveltyFeature(&data[offset0 - kWordNoveltyOffset],
                                    word_counter[hash_codes[0]]++);
       }
       if (params->DocSizeFeatureEnabled()) {
-        data[offset0 - 2] = doc_size_feature;
+        data[offset0 - kDocSizeOffset] = doc_size_feature;
       }
+      if (params->FirstCap()) {
+        data[offset0 - kFirstCapOffset] =
+            mapping_table[first_cap ? kMapHigh : kMapLow];
+      }
+      if (params->AllCaps()) {
+        data[offset0 - kAllCapsOffset] =
+            mapping_table[all_caps ? kMapHigh : kMapLow];
+      }
     }
   }
 
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   auto* params = reinterpret_cast<ProjectionParams*>(node->user_data);
+  if (params == nullptr) {
+    context->ReportError(context, "Empty user data.");
+    return kTfLiteError;
+  }
   TF_LITE_ENSURE_OK(context, params->PreprocessInput(
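Note: each token's feature row reserves its trailing slots for the auxiliary features (word novelty, document size, and now the two capitalization features), written at fixed offsets back from the end of the row; the k*Offset constants are defined elsewhere in the file and their values are not visible in this hunk. The capitalization features reuse mapping_table, indexed with kMapHigh or kMapLow, to emit the quantized on/off values. A rough sketch with placeholder offsets (hypothetical values, for illustration only):

    def write_trailing_features(row, first_cap, all_caps, mapping_table,
                                first_cap_offset=3, all_caps_offset=4):
        # Placeholder offsets; the real kFirstCapOffset / kAllCapsOffset live
        # in the C++ file outside this hunk.
        row[len(row) - first_cap_offset] = mapping_table[1 if first_cap else 2]
        row[len(row) - all_caps_offset] = mapping_table[1 if all_caps else 2]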
research/seq_flow_lite/tflite_ops/sequence_string_projection_test.cc
@@ -16,13 +16,14 @@ limitations under the License.
 #include <vector>
 
-#include "tflite_ops/tf_tflite_diff_test_util.h"  // seq_flow_lite
 #include "flatbuffers/flexbuffers.h"  // flatbuffer
 #include "tensorflow/lite/interpreter.h"
 #include "tensorflow/lite/kernels/register.h"
 #include "tensorflow/lite/kernels/test_util.h"
 #include "tensorflow/lite/model.h"
 #include "tensorflow/lite/string_util.h"
+#include "tf_ops/projection_util.h"  // seq_flow_lite
+#include "tflite_ops/tf_tflite_diff_test_util.h"  // seq_flow_lite
 
 namespace tflite {
@@ -45,7 +46,8 @@ class SequenceStringProjectionModel : public SingleOpModel {
       bool split_on_space, int max_splits, int word_novelty_bits,
       int doc_size_levels, bool add_eos_tag, TensorType output_type,
       const std::string& token_separators = "",
-      bool normalize_repetition = false) {
+      bool normalize_repetition = false, float add_first_cap = 0.0,
+      float add_all_caps = 0.0, const string& hashtype = kMurmurHash) {
     flexbuffers::Builder fbb;
     fbb.Map([&] {
       fbb.Int("feature_size", 4);
@@ -56,7 +58,10 @@ class SequenceStringProjectionModel : public SingleOpModel {
       fbb.Bool("split_on_space", split_on_space);
       fbb.Bool("add_eos_tag", add_eos_tag);
       fbb.String("token_separators", token_separators);
+      fbb.String("hashtype", hashtype);
       fbb.Bool("normalize_repetition", normalize_repetition);
+      fbb.Float("add_first_cap_feature", add_first_cap);
+      fbb.Float("add_all_caps_feature", add_all_caps);
     });
     fbb.Finish();
     output_ = AddOutput({output_type, {}});
@@ -74,7 +79,7 @@ class SequenceStringProjectionModel : public SingleOpModel {
     PopulateStringTensor(input_, {input});
     CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
         << "Cannot allocate tensors";
-    return interpreter_->Invoke();
+    return SingleOpModel::InvokeUnchecked();
   }
 
   template <typename T>
@@ -91,6 +96,12 @@ class SequenceStringProjectionModel : public SingleOpModel {
   int output_;
 };
 
+TEST(SequenceStringProjectionTest, IncorrectHashtype) {
+  SequenceStringProjectionModel m(true, -1, 0, 0, true, TensorType_UINT8, "",
+                                  false, 0.0, 0.0, "unsupported");
+  EXPECT_EQ(m.InvokeFailable(" "), kTfLiteError);
+}
+
 TEST(SequenceStringProjectionTest, RegularInputUint8) {
   std::vector<std::pair<std::string, std::vector<uint8_t>>> testcase = {
       {"hello", {127, 255, 255, 127, 127, 255, 127, 127}},
@@ -273,6 +284,39 @@ TEST(SequenceStringProjectionTest, EmptyInput) {
   EXPECT_EQ(with_eos.InvokeFailable("hello"), kTfLiteOk);
 }
 
+TEST(SequenceStringProjectionTest, FirstCap) {
+  SequenceStringProjectionModel op(
+      /*split_on_space=*/true, /*max_splits=*/-1, /*word_novelty_bits=*/0,
+      /*doc_size_levels=*/0, /*add_eos_tag=*/false,
+      /*output_type=*/TensorType_UINT8, /*token_separators=*/" ",
+      /*normalize_repetition=*/false, /*add_first_cap=*/0.5);
+  op.Invoke("hello");
+  auto output1 = op.GetOutput<uint8_t>();
+  op.Invoke("Hello");
+  auto output2 = op.GetOutput<uint8_t>();
+  EXPECT_NE(output1[1], output2[1]);
+}
+
+TEST(SequenceStringProjectionTest, AllCaps) {
+  SequenceStringProjectionModel op(
+      /*split_on_space=*/true, /*max_splits=*/-1, /*word_novelty_bits=*/0,
+      /*doc_size_levels=*/0, /*add_eos_tag=*/false,
+      /*output_type=*/TensorType_UINT8, /*token_separators=*/" ",
+      /*normalize_repetition=*/false, /*add_first_cap=*/0.0,
+      /*add_all_caps=*/0.5);
+  op.Invoke("hello");
+  auto output1 = op.GetOutput<uint8_t>();
+  op.Invoke("HELLO");
+  auto output2 = op.GetOutput<uint8_t>();
+  EXPECT_NE(output1[0], output2[0]);
+}
+
 TEST(SequenceStringProjectionTest, NormalizeRepetition) {
   // Normalize the repeated special tokens. Used for the emotion models.
   SequenceStringProjectionModel m1(true, -1, 0, 0, true, TensorType_UINT8, "",
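Note that the new tests pass 0.5 for the capitalization attributes; per the validation added in Init() above, such out-of-range values are reported and coerced to 1.0f, so these tests exercise both the reset path and the feature itself.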
@@ -691,6 +735,62 @@ std::vector<OpEquivTestCase> SequenceStringProjectionTestCases() {
     test_cases.push_back(test_case);
   }
+  {
+    OpEquivTestCase test_case;
+    test_case.test_name = "CapBaseline";
+    test_case.attributes["vocabulary"] = AttrValue("");
+    test_case.attributes["split_on_space"] = AttrValue(true);
+    test_case.attributes["feature_size"] = AttrValue(8);
+    test_case.attributes["add_eos_tag"] = AttrValue(false);
+    test_case.attributes["add_bos_tag"] = AttrValue(false);
+    test_case.input_tensors.push_back(StringTensor({1}, {"Hello hello HELLO"}));
+    test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
+    test_cases.push_back(test_case);
+  }
+  {
+    OpEquivTestCase test_case;
+    test_case.test_name = "FirstCap";
+    test_case.attributes["vocabulary"] = AttrValue("");
+    test_case.attributes["split_on_space"] = AttrValue(true);
+    test_case.attributes["feature_size"] = AttrValue(8);
+    test_case.attributes["add_eos_tag"] = AttrValue(false);
+    test_case.attributes["add_bos_tag"] = AttrValue(false);
+    test_case.attributes["add_first_cap_feature"] = AttrValue(1.0);
+    test_case.input_tensors.push_back(StringTensor({1}, {"Hello hello HELLO"}));
+    test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
+    test_cases.push_back(test_case);
+  }
+  {
+    OpEquivTestCase test_case;
+    test_case.test_name = "AllCaps";
+    test_case.attributes["vocabulary"] = AttrValue("");
+    test_case.attributes["split_on_space"] = AttrValue(true);
+    test_case.attributes["feature_size"] = AttrValue(8);
+    test_case.attributes["add_eos_tag"] = AttrValue(false);
+    test_case.attributes["add_bos_tag"] = AttrValue(false);
+    test_case.attributes["add_all_caps_feature"] = AttrValue(1.0);
+    test_case.input_tensors.push_back(StringTensor({1}, {"Hello hello HELLO"}));
+    test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
+    test_cases.push_back(test_case);
+  }
+  {
+    OpEquivTestCase test_case;
+    test_case.test_name = "FirstCapAllCaps";
+    test_case.attributes["vocabulary"] = AttrValue("");
+    test_case.attributes["split_on_space"] = AttrValue(true);
+    test_case.attributes["feature_size"] = AttrValue(8);
+    test_case.attributes["add_eos_tag"] = AttrValue(false);
+    test_case.attributes["add_bos_tag"] = AttrValue(false);
+    test_case.attributes["add_first_cap_feature"] = AttrValue(1.0);
+    test_case.attributes["add_all_caps_feature"] = AttrValue(1.0);
+    test_case.input_tensors.push_back(StringTensor({1}, {"Hello hello HELLO"}));
+    test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
+    test_cases.push_back(test_case);
+  }
   return test_cases;
 }
@@ -701,9 +801,13 @@ INSTANTIATE_TEST_SUITE_P(
 class SequenceStringProjectionV2Model : public SingleOpModel {
  public:
   explicit SequenceStringProjectionV2Model(
-      std::vector<std::vector<int>> input_shapes) {
+      std::vector<std::vector<int>> input_shapes,
+      const string& hashtype = kMurmurHash) {
     flexbuffers::Builder fbb;
-    fbb.Map([&] { fbb.Int("feature_size", 4); });
+    fbb.Map([&] {
+      fbb.Int("feature_size", 4);
+      fbb.String("hashtype", hashtype);
+    });
     fbb.Finish();
     input_ = AddInput(TensorType_STRING);
     output_ = AddOutput({TensorType_UINT8, {}});
@@ -715,7 +819,13 @@ class SequenceStringProjectionV2Model : public SingleOpModel {
     PopulateStringTensor(input_, input);
     CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
         << "Cannot allocate tensors";
-    ASSERT_EQ(interpreter_->Invoke(), expected);
+    ASSERT_EQ(SingleOpModel::InvokeUnchecked(), expected);
+  }
+
+  TfLiteStatus InvokeFailable(const std::string& input) {
+    PopulateStringTensor(input_, {input});
+    CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
+        << "Cannot allocate tensors";
+    return SingleOpModel::InvokeUnchecked();
   }
 
  private:
@@ -723,6 +833,11 @@ class SequenceStringProjectionV2Model : public SingleOpModel {
   int output_;
 };
 
+TEST(SequenceStringProjectionV2Test, IncorrectHashtype) {
+  SequenceStringProjectionV2Model m({{1, 0}}, "unsupported");
+  EXPECT_EQ(m.InvokeFailable(" "), kTfLiteError);
+}
+
 TEST(SequenceStringProjectionV2Test, RegularInputUint8EmptyNotSupported) {
   // TFLite test infratructure currently does not let the error message to be
   // extracted on failure. As a result just the return error code is tested
research/seq_flow_lite/trainer.py
@@ -52,19 +52,22 @@ def create_model(model, model_config, features, mode):
"""Creates a sequence labeling model."""
keras_model
=
model
.
Encoder
(
model_config
,
mode
)
logits
=
keras_model
(
features
[
"projection"
],
features
[
"seq_length"
])
if
not
model_config
[
"multilabel"
]:
loss
=
tf
.
nn
.
sparse_softmax_cross_entropy_with_logits
(
labels
=
features
[
"label"
],
logits
=
logits
)
if
mode
!=
tf
.
estimator
.
ModeKeys
.
PREDICT
:
if
not
model_config
[
"multilabel"
]:
loss
=
tf
.
nn
.
sparse_softmax_cross_entropy_with_logits
(
labels
=
features
[
"label"
],
logits
=
logits
)
else
:
loss
=
tf
.
nn
.
sigmoid_cross_entropy_with_logits
(
labels
=
features
[
"label"
],
logits
=
logits
)
loss
=
tf
.
reduce_mean
(
loss
)
loss
+=
tf
.
add_n
(
keras_model
.
losses
)
else
:
loss
=
tf
.
nn
.
sigmoid_cross_entropy_with_logits
(
labels
=
features
[
"label"
],
logits
=
logits
)
loss
=
tf
.
reduce_mean
(
loss
)
loss
+=
tf
.
add_n
(
keras_model
.
losses
)
loss
=
None
return
(
loss
,
logits
)
def
create_optimizer
(
loss
,
runner_config
):
def
create_optimizer
(
loss
,
runner_config
,
params
):
"""Returns a train_op using Adam optimizer."""
learning_rate
=
tf
.
train
.
exponential_decay
(
learning_rate
=
runner_config
[
"learning_rate"
],
...
...
@@ -73,7 +76,7 @@ def create_optimizer(loss, runner_config):
       decay_rate=runner_config["learning_rate_decay_rate"],
       staircase=True)
   optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
-  if FLAGS.use_tpu:
+  if params["use_tpu"]:
     optimizer = tf.tpu.CrossShardOptimizer(optimizer)
   return optimizer.minimize(loss, global_step=tf.train.get_global_step())
@@ -87,7 +90,6 @@ def model_fn_builder(runner_config):
   def model_fn(features, mode, params):
     """The `model_fn` for TPUEstimator."""
-    del params
     label_ids = None
     if mode != tf.estimator.ModeKeys.PREDICT:
       label_ids = features["label"]
@@ -96,7 +98,7 @@ def model_fn_builder(runner_config):
     loss, logits = create_model(model, model_config, features, mode)
     if mode == tf.estimator.ModeKeys.TRAIN:
-      train_op = create_optimizer(loss, runner_config)
+      train_op = create_optimizer(loss, runner_config, params)
       return tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
           mode=mode, loss=loss, train_op=train_op)
     elif mode == tf.estimator.ModeKeys.EVAL:
@@ -108,8 +110,16 @@ def model_fn_builder(runner_config):
       eval_metrics = (metric_fn, [loss, label_ids, logits])
       return tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
           mode=mode, loss=loss, eval_metrics=eval_metrics)
+    elif mode == tf.estimator.ModeKeys.PREDICT:
+      predictions = {"logits": logits}
+      if not runner_config["model_config"]["multilabel"]:
+        predictions["predictions"] = tf.nn.softmax(logits)
+      else:
+        predictions["predictions"] = tf.math.sigmoid(logits)
+      return tf.compat.v1.estimator.EstimatorSpec(mode=mode,
+                                                  predictions=predictions)
     else:
-      assert False, "Expected to be called in TRAIN or EVAL mode."
+      assert False, "Expected to be called in TRAIN, EVAL, or PREDICT mode."
   return model_fn
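Note: the trainer now supports PREDICT mode end to end: create_model returns loss=None for it, and model_fn emits softmax (single-label) or sigmoid (multilabel) probabilities. A hedged usage sketch; the estimator configuration objects (run_config, batch sizes, input_fn) are assumptions, not from the diff:

    estimator = tf.compat.v1.estimator.tpu.TPUEstimator(
        model_fn=model_fn_builder(runner_config),
        use_tpu=False,
        config=run_config,              # assumed tpu.RunConfig
        predict_batch_size=batch_size)  # plus train/eval batch sizes as needed
    for output in estimator.predict(input_fn=predict_input_fn):  # assumed input_fn
        probs = output["predictions"]   # softmax or sigmoid of output["logits"]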
research/seq_flow_lite/utils/tflite_utils.py
@@ -42,13 +42,20 @@ def set_output_quantized_for_custom_ops(graph_def):
       'SequenceStringProjectionV2': [tf.float32.as_datatype_enum],
       'PoolingOp': [tf.float32.as_datatype_enum],
       'ExpectedValueOp': [tf.float32.as_datatype_enum],
-      'LayerNormV2': [tf.float32.as_datatype_enum],
+      'LayerNorm': [tf.float32.as_datatype_enum],
+      'UniformCausalAttn': [tf.float32.as_datatype_enum],
   }
+  custom_op_renames = {
+      'SequenceStringProjection': 'SEQUENCE_STRING_PROJECTION',
+      'SequenceStringProjectionV2': 'SEQUENCE_STRING_PROJECTION_V2',
+  }
   for node in graph_def.node:
     if node.op in quantized_custom_ops:
       node.attr['_output_quantized'].b = True
       node.attr['_output_types'].list.type[:] = quantized_custom_ops[node.op]
+    if node.op in custom_op_renames:
+      node.op = custom_op_renames[node.op]
 
 
 def generate_tflite(session, graph, input_tensors, output_tensors):
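Note: the rename map rewrites graph nodes to the registered TFLite custom-op names before conversion, so the converter emits custom ops that match the TFLite-side registrations. Minimal call sketch (the session/graph objects are assumed to exist):

    graph_def = session.graph_def  # assumed frozen GraphDef of the trained model
    set_output_quantized_for_custom_ops(graph_def)
    # Nodes named 'SequenceStringProjectionV2' are now
    # 'SEQUENCE_STRING_PROJECTION_V2' and marked _output_quantized.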
@@ -59,16 +66,16 @@ def generate_tflite(session, graph, input_tensors, output_tensors):
   set_output_quantized_for_custom_ops(graph_def)
   # TODO(b/171063452): Bug needs to be fixed to handle this correctly.
   # def _node_name(tensor):
   #   return tensor.name.split(':')[0]
   # input_arrays_with_shape = [
   #     (_node_name(tensor), None) for tensor in input_tensors
   # ]
   # output_arrays = [_node_name(tensor) for tensor in output_tensors]
   # converter = tf.lite.TFLiteConverter(graph_def, None, None,
   #                                     input_arrays_with_shape, output_arrays)
   converter = tf.lite.TFLiteConverter(graph_def, input_tensors, output_tensors)
   converter.inference_type = tf.uint8
   converter.default_ranges_stats = (127.5, 127.5)