Commit fec0338f (unverified): Checkin seq_flow_lite (#10219)
Authored Sep 14, 2021 by pyoung2778; committed by GitHub on Sep 14, 2021.
Parent: c6d7d57d

Showing 18 changed files with 434 additions and 146 deletions (+434, -146).

research/seq_flow_lite/tf_ops/sequence_string_projection_op_v2.cc    +15  -1
research/seq_flow_lite/tf_ops/text_distorter.h                        +0   -1
research/seq_flow_lite/tf_ops/tf_custom_ops.cc                        +0   -54
research/seq_flow_lite/tflite_ops/BUILD                               +14  -0
research/seq_flow_lite/tflite_ops/expected_value.cc                   +2   -2
research/seq_flow_lite/tflite_ops/expected_value.h                    +2   -2
research/seq_flow_lite/tflite_ops/layer_norm.cc                       +100 -4
research/seq_flow_lite/tflite_ops/layer_norm.h                        +2   -2
research/seq_flow_lite/tflite_ops/layer_norm_test.cc                  +37  -35
research/seq_flow_lite/tflite_ops/quantization_util.h                 +2   -2
research/seq_flow_lite/tflite_ops/registerer.cc                       +49  -0
research/seq_flow_lite/tflite_ops/sequence_string_projection.cc       +6   -4
research/seq_flow_lite/tflite_ops/sequence_string_projection.h        +2   -2
research/seq_flow_lite/tflite_ops/sequence_string_projection_test.cc  +27  -14
research/seq_flow_lite/tflite_ops/tf_tflite_diff_test_util.cc         +7   -2
research/seq_flow_lite/tflite_ops/tf_tflite_diff_test_util.h          +2   -2
research/seq_flow_lite/trainer_v2.py                                  +113 -0
research/seq_flow_lite/utils/tflite_utils.py                          +54  -19

research/seq_flow_lite/tf_ops/sequence_string_projection_op_v2.cc

@@ -40,6 +40,8 @@ constexpr char kEndTokenTSP[] = "<EOS>";
 constexpr float kMappingTable[4] = {0, 1, -1, 0};
 constexpr int kIncrement = 32;

+// Version 2 OpKernel for the sequence string projection op.
+// Template T can be int32 or int64.
 template <typename T>
 class SequenceStringProjectionOpV2 : public OpKernel {
  public:
@@ -136,7 +138,7 @@ class SequenceStringProjectionOpV2 : public OpKernel {
       } else {
         word = kEndTokenTSP;
       }
-      hasher_->GetHashCodes(word, &hash_codes);
+      hasher_->GetHashCodes(word, hash_codes);
       for (int hindex = 0, k = 0; hindex < hash_codes.size(); hindex++) {
         auto hash = hash_codes[hindex];
         for (int kmax = std::min(k + kIncrement, feature_size_); k < kmax;) {
@@ -153,13 +155,25 @@ class SequenceStringProjectionOpV2 : public OpKernel {
   }

  private:
+  // Dimensionality of the ternary vector for each token in the text.
   int32 feature_size_;
+  // An object used to hash tokens in the text.
   std::unique_ptr<Hasher> hasher_;
+  // An object used for distorting text before projection.
   std::unique_ptr<TextDistorter> text_distorter_;
+  // An object used for manipulating unicode in the text. It performs tasks such
+  // as retaining only whitelisted unicodes in the text tokens and lowercasing
+  // them.
   std::unique_ptr<ProjectionUnicodeHandler> unicode_handler_;
+  // An object used for normalizing tokens in the text. This performs tasks
+  // such as identifying repeated characters and replace them with a single
+  // instance.
   std::unique_ptr<ProjectionNormalizer> projection_normalizer_;
+  // Character whitelist used by the projection operator.
   std::string vocabulary_;
+  // When true include an end of sentence token in the projection.
   int eos_tag_;
+  // When true include a begin of sentence token in the projection.
   int bos_tag_;
 };

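For readers skimming the hunks above: kMappingTable = {0, 1, -1, 0} together with kIncrement = 32 suggests that each 64-bit hash code is consumed two bits at a time, each 2-bit value indexing the table to yield one ternary feature. The Python sketch below illustrates that reading; it is an editorial illustration rather than code from this commit, and the bit order is an assumption.

# Hypothetical decoding of one hash code into ternary features; not from the commit.
K_MAPPING_TABLE = [0.0, 1.0, -1.0, 0.0]   # mirrors kMappingTable in the C++ op
K_INCREMENT = 32                          # features drawn from one 64-bit hash

def ternary_features(hash_code, feature_size):
  """Maps two bits of the hash per output feature onto {-1, 0, +1}."""
  features = []
  for k in range(min(K_INCREMENT, feature_size)):
    two_bits = (hash_code >> (2 * k)) & 0x3   # low bits first (assumed order)
    features.append(K_MAPPING_TABLE[two_bits])
  return features

print(ternary_features(0x1B, feature_size=4))  # [0.0, -1.0, 1.0, 0.0]
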
research/seq_flow_lite/tf_ops/text_distorter.h

@@ -32,7 +32,6 @@ class TextDistorter {
     assert(distortion_probability_ <= 1.0);
   }
   std::string DistortText(icu::UnicodeString* uword);
-  bool BernouilleSample(float p) { return (generator_.RandFloat() <= p); }

  private:
   tensorflow::random::PhiloxRandom philox_;

research/seq_flow_lite/tf_ops/tf_custom_ops.cc

@@ -20,30 +20,6 @@ limitations under the License.
 using ::tensorflow::int32;

-class PoolingOp : public tensorflow::OpKernel {
- public:
-  explicit PoolingOp(tensorflow::OpKernelConstruction* context)
-      : tensorflow::OpKernel(context) {}
-  void Compute(tensorflow::OpKernelContext* ctx) override {}
-};
-
-REGISTER_KERNEL_BUILDER(Name("PoolingOp").Device(::tensorflow::DEVICE_CPU),
-                        PoolingOp);
-REGISTER_OP("PoolingOp")
-    .Input("multiplier: float32")
-    .Input("constant: float32")
-    .Input("forward: float32")
-    .Output("state: float32")
-    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
-      c->set_output(0, c->input(0));
-      return tensorflow::Status::OK();
-    })
-    .Doc(R"doc(
-Dummy pooling op.
-)doc");
-
 class ExpectedValueOp : public tensorflow::OpKernel {
  public:
  explicit ExpectedValueOp(tensorflow::OpKernelConstruction* context)
@@ -93,33 +69,3 @@ REGISTER_OP("LayerNorm")
     .Doc(R"doc(
 Dummy layer norm op.
 )doc");
-
-class UniformCausalAttnOp : public tensorflow::OpKernel {
- public:
-  explicit UniformCausalAttnOp(tensorflow::OpKernelConstruction* context)
-      : tensorflow::OpKernel(context) {}
-  void Compute(tensorflow::OpKernelContext* ctx) override {}
-};
-
-REGISTER_KERNEL_BUILDER(
-    Name("UniformCausalAttn").Device(::tensorflow::DEVICE_CPU),
-    UniformCausalAttnOp);
-REGISTER_OP("UniformCausalAttn")
-    .Input("input: float32")
-    .Input("time_step: int32")
-    .Input("selected_beams: int32")
-    .Attr("feature_size: int")
-    .Attr("beam_size: int")
-    .Output("output: float32")
-    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
-      auto batch_size = c->Dim(c->input(0), 0);
-      int32 feature_size;
-      TF_RETURN_IF_ERROR(c->GetAttr("feature_size", &feature_size));
-      c->set_output(0, c->MakeShape({batch_size, 1, feature_size}));
-      return tensorflow::Status::OK();
-    })
-    .Doc(R"doc(
-Dummy uniform causal attn op.
-)doc");

research/seq_flow_lite/tflite_ops/BUILD

 # TFLite ops for sequence string projection.
+load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension")
 load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_copts")

 licenses(["notice"])
@@ -100,3 +101,16 @@ cc_test(
         "@flatbuffers",
     ],
 )
+
+pybind_extension(
+    name = "registerer",
+    srcs = ["registerer.cc"],
+    module_name = "registerer",
+    deps = [
+        ":expected_value",
+        ":layer_norm",
+        ":sequence_string_projection",
+        "@org_tensorflow//tensorflow/lite:framework",
+        "@pybind11",
+    ],
+)

research/seq_flow_lite/tflite_ops/expected_value.cc

@@ -18,7 +18,7 @@ limitations under the License.
 #include "tflite_ops/quantization_util.h"  // seq_flow_lite

-namespace tflite {
+namespace seq_flow_lite {
 namespace ops {
 namespace custom {
@@ -156,4 +156,4 @@ TfLiteRegistration* Register_EXPECTED_VALUE() {
 }  // namespace custom
 }  // namespace ops
-}  // namespace tflite
+}  // namespace seq_flow_lite

research/seq_flow_lite/tflite_ops/expected_value.h

@@ -17,7 +17,7 @@ limitations under the License.
 #include "tensorflow/lite/kernels/register.h"

-namespace tflite {
+namespace seq_flow_lite {
 namespace ops {
 namespace custom {
@@ -25,6 +25,6 @@ TfLiteRegistration* Register_EXPECTED_VALUE();
 }  // namespace custom
 }  // namespace ops
-}  // namespace tflite
+}  // namespace seq_flow_lite

 #endif  // TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_EXPECTED_VALUE_H_

research/seq_flow_lite/tflite_ops/layer_norm.cc

@@ -17,10 +17,10 @@ limitations under the License.
 #include <unordered_set>
 #include <vector>

-#include "tflite_ops/quantization_util.h"  // seq_flow_lite
 #include "tensorflow/lite/kernels/kernel_util.h"
+#include "tflite_ops/quantization_util.h"  // seq_flow_lite

-namespace tflite {
+namespace seq_flow_lite {
 namespace ops {
 namespace custom {
@@ -213,6 +213,102 @@ TfLiteStatus FlexibleLayerNorm(const TfLiteTensor* input, const float scale,
  return kTfLiteOk;
}

/*
* Layer normalization is optimized as follows in integer arithmetic
*
* Algorithm
* *********
* Subscript i \in {1, ..., N}, Inputs q_i, Outputs oq_i.
*
* x_i = (q_i - input_zero_point) * input_scale
* mean = sum_i x_i / N
* var = sum_i (x_i * x_i / N) - mean * mean
* std = sqrt(var + tolerance)
* xni = (xi - mean) / std
* yi = xni * scale + offset
* o_i = round(y_i / output_scale + output_zero_point)
* oq_i = clamp(o_i, 0, 255)
*
* Optimizations
* *************
* Applying linear expansion
* x_i = q_i * input_scale - input_zero_point * input_scale
* or x_i = m * qi + c
* mean = m * mean_q + c
* Variance is not affected by a constant shift to input
* var = m^2 * var_q
* std = m * sqrt(var_q + tolerance)
* Expanding xi, mean, std in equation for xni
* xni = (m * qi + c - m * mean_q - c) / m * sqrt(var_q + tolerance)
* Simplifying
* xni = (qi - mean_q) / sqrt(var_q + tolerance)
* Setting inv_std_qi = 1 / sqrt(var_q + tolerance)
* xni = qi * inv_std_qi - mean_q * inv_std_qi
* yi = qi * inv_std_qi * scale - mean_q * inv_std_qi * scale + offset
* o_i = round(qi * inv_std_qi * scale / output_scale
* - mean_q * inv_std_qi * scale / output_scale
* + offset / output_scale
* + output_zero_point)
* Setting
* static_bias = offset / output_scale + output_zero_point
* static_scale = scale / output_scale
* o_i = round(qi * inv_std_qi * static_scale
* - mean_q * inv_std_qi * static_scale
* + static_bias)
* Setting
* dynamic_scale = inv_std_qi * static_scale
* dynamic_bias = static_bias - mean_q * dynamic_scale
* o_i = round(qi * dynamic_scale + dynamic_bias)
* oq_i = clamp(round(qi * dynamic_scale + dynamic_bias), 0, 255)
*
* This results in the below optimized implementation. The strategy is to first
* compute first and second order summary statistics for qi in a loop,
* then compute mean_q, var_q and then dynamic_scale/dynamic_bias. This
* allows one to compute oqi quickly in a tight loop.
* */
TfLiteStatus IntegerLayerNorm(const TfLiteTensor* input, const float scale,
                              const float offset, TfLiteTensor* output) {
  const int input_rank = input->dims->size;
  const int num_features = input->dims->data[input_rank - 1];
  const int time_steps =
      static_cast<int>(GetNumberOfSteps(input) / num_features);
  const float out_inverse_scale = 1.0f / output->params.scale;
  const float static_scale = scale * out_inverse_scale;
  const float static_bias = static_cast<float>(output->params.zero_point) +
                            offset * out_inverse_scale;
  const float inverse_num_features = 1.0f / num_features;
  const uint8_t* const in_ptr = input->data.uint8;
  uint8_t* out_ptr = output->data.uint8;
  for (int i = 0; i < time_steps; ++i) {
    int32_t i32_sum_q = 0;
    int32_t i32_sum_qq = 0;
    const int32_t index = i * num_features;
    for (int j = index; j < index + num_features; ++j) {
      const int32_t q_i = static_cast<int32_t>(in_ptr[j]);
      // Compute first and second order statistics for qi.
      i32_sum_q += q_i;
      i32_sum_qq += q_i * q_i;
    }
    const float second_moment_qq = i32_sum_qq * inverse_num_features;
    const float mean_q = i32_sum_q * inverse_num_features;
    const float var_q = second_moment_qq - mean_q * mean_q;
    const float inv_std_q = 1.0f / sqrt(var_q + 1e-6);
    const float dynamic_scale = inv_std_q * static_scale;
    const float dynamic_bias = static_bias - mean_q * dynamic_scale;
    for (int j = index; j < index + num_features; ++j) {
      const int32_t invalue = static_cast<int32_t>(in_ptr[j]);
      const float value = invalue * dynamic_scale + dynamic_bias;
      // Use an offseted cast to perform float round.
      const int32_t i32value =
          static_cast<int32_t>(value + ((value >= 0.0) ? 0.5f : -0.5f));
      // Clamp the result.
      out_ptr[j] = static_cast<uint8_t>(std::max(std::min(255, i32value), 0));
    }
  }
  return kTfLiteOk;
}

TfLiteStatus DefaultLayerNormFloat(const TfLiteTensor* input, const float scale,
                                   const float offset, TfLiteTensor* output) {
  const int input_rank = input->dims->size;
@@ -298,7 +394,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   if (num_axis == 1 && (axis->data.i32[0] == -1 ||
                         axis->data.i32[0] == (input->dims->size - 1))) {
     if (input->type == kTfLiteUInt8) {
-      return DefaultLayerNorm(input, scale, offset, output);
+      return IntegerLayerNorm(input, scale, offset, output);
     } else if (input->type == kTfLiteFloat32) {
       return DefaultLayerNormFloat(input, scale, offset, output);
     } else {
@@ -328,4 +424,4 @@ TfLiteRegistration* Register_LAYER_NORM() {
 }  // namespace custom
 }  // namespace ops
-}  // namespace tflite
+}  // namespace seq_flow_lite

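The optimization derived in the comment above can be checked numerically: computing first and second order statistics on the quantized values q_i and folding everything into dynamic_scale and dynamic_bias must reproduce the naive dequantize, normalize, requantize pipeline. The NumPy sketch below is an editorial check, not part of the commit; the tolerance term is rescaled so the two paths agree exactly, whereas the kernel simply adds 1e-6 to var_q.

# Editorial numeric check of the derivation above; not from the commit.
import numpy as np

rng = np.random.default_rng(0)
q = rng.integers(0, 256, size=64).astype(np.float64)   # quantized inputs q_i
in_scale, in_zp = 0.05, 127                            # input quantization params
out_scale, out_zp = 0.04, 128                          # output quantization params
scale, offset, tol = 1.0, 0.0, 1e-6

# Naive path: dequantize, normalize, requantize, as in the "Algorithm" section.
x = (q - in_zp) * in_scale
xn = (x - x.mean()) / np.sqrt(x.var() + tol * in_scale**2)   # tolerance rescaled to match
naive = np.clip(np.round((xn * scale + offset) / out_scale + out_zp), 0, 255)

# Optimized path: statistics on q only, folded into dynamic_scale / dynamic_bias.
mean_q, var_q = q.mean(), q.var()
inv_std_q = 1.0 / np.sqrt(var_q + tol)
static_scale = scale / out_scale
static_bias = offset / out_scale + out_zp
dynamic_scale = inv_std_q * static_scale
dynamic_bias = static_bias - mean_q * dynamic_scale
optimized = np.clip(np.round(q * dynamic_scale + dynamic_bias), 0, 255)

print(np.abs(naive - optimized).max())   # 0.0: both paths produce the same bytes
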
research/seq_flow_lite/tflite_ops/layer_norm.h

@@ -17,7 +17,7 @@ limitations under the License.
 #include "tensorflow/lite/kernels/register.h"

-namespace tflite {
+namespace seq_flow_lite {
 namespace ops {
 namespace custom {
@@ -25,6 +25,6 @@ TfLiteRegistration* Register_LAYER_NORM();
 }  // namespace custom
 }  // namespace ops
-}  // namespace tflite
+}  // namespace seq_flow_lite

 #endif  // LEARNING_EXPANDER_POD_DEEP_POD_TFLITE_HANDLERS_LAYER_NORM_H_

research/seq_flow_lite/tflite_ops/layer_norm_test.cc

@@ -20,40 +20,35 @@ limitations under the License.
 #include "flatbuffers/flexbuffers.h"  // flatbuffer
 #include "tensorflow/lite/kernels/test_util.h"

-namespace tflite {
+namespace seq_flow_lite {
 namespace ops {
 namespace custom {
 namespace {

-class LayerNormModel : public SingleOpModel {
+using ::testing::ElementsAreArray;
+using ::tflite::ArrayFloatNear;
+using ::tflite::Dequantize;
+using ::tflite::TensorType_INT32;
+using ::tflite::TensorType_UINT8;
+
+class LayerNormModel : public ::tflite::SingleOpModel {
  public:
-  explicit LayerNormModel(const TensorData& input, float output_min,
+  explicit LayerNormModel(std::initializer_list<int> input_shape,
+                          float input_min, float input_max, float output_min,
                           float output_max, float scale, float offset,
-                          std::initializer_list<int> axis_shape,
-                          std::initializer_list<int> axis)
+                          std::initializer_list<int> axes)
       : scale_value_(scale), offset_value_(offset) {
-    input_ = AddInput(input);
+    const int num_axes = axes.size();
+    input_ = AddInput({TensorType_UINT8, input_shape, input_min, input_max});
     scale_ = AddInput(
         {TensorType_UINT8, {1}, std::min(scale, 0.0f), std::max(scale, 0.0f)});
     offset_ = AddInput(
         {TensorType_UINT8, {1}, std::min(offset, 0.0f), std::max(offset, 0.0f)});
-    axis_ = AddConstInput(TensorType_INT32, axis, axis_shape);
+    axis_ = AddConstInput(TensorType_INT32, axes, {num_axes});
     output_ = AddOutput({TensorType_UINT8, {}, output_min, output_max});
-    flexbuffers::Builder fbb;
-    fbb.Map([&] {
-      {
-        size_t start = fbb.StartVector("axes");
-        for (const int& aval : axis) {
-          fbb.Int(aval);
-        }
-        fbb.EndVector(start, /*typed=*/true, /*fixed=*/false);
-      }
-    });
-    fbb.Finish();
-    SetCustomOp("LayerNorm", fbb.GetBuffer(), Register_LAYER_NORM);
+    SetCustomOp("LayerNorm", {}, Register_LAYER_NORM);
     BuildInterpreter({GetShape(input_)});
   }
@@ -88,8 +83,9 @@ TEST(LayerNormModelTest, RegularInput) {
   const std::vector<float> expected_output = {0.0, -1.6,  0.53, 1.07,
                                               0.0, -1.13, 1.59, -0.45};
-  LayerNormModel m({TensorType_UINT8, {1, 2, 4}, -10, 10}, -10, 10, 1.0, 0.0,
-                   {1}, {2});
+  LayerNormModel m(/*input_shape=*/{1, 2, 4}, /*input_min=*/-10,
+                   /*input_max=*/10, /*output_min=*/-10, /*output_max=*/10,
+                   /*scale=*/1.0, /*offset=*/0.0, /*axes=*/{2});
   m.SetInput(input);
   m.Invoke();
   EXPECT_THAT(
@@ -106,8 +102,9 @@ TEST(LayerNormModelTest, NegativeScale) {
   // Standard deviation values are 3.74, 4.41
   const std::vector<float> expected_output = {0.0, 1.6,  -0.53, -1.07,
                                               0.0, 1.13, -1.59, 0.45};
-  LayerNormModel m({TensorType_UINT8, {1, 2, 4}, -10, 10}, -10, 10, -1.0, 0.0,
-                   {1}, {2});
+  LayerNormModel m(/*input_shape=*/{1, 2, 4}, /*input_min=*/-10,
+                   /*input_max=*/10, /*output_min=*/-10, /*output_max=*/10,
+                   /*scale=*/-1.0, /*offset=*/0.0, /*axes=*/{2});
   m.SetInput(input);
   m.Invoke();
   EXPECT_THAT(
@@ -124,8 +121,9 @@ TEST(LayerNormModelTest, NegativeOffset) {
   // Standard deviation values are 3.74, 4.41
   const std::vector<float> expected_output = {-1.0, -2.6,  -0.53, 0.07,
                                               -1.0, -2.13, 0.59,  -1.45};
-  LayerNormModel m({TensorType_UINT8, {1, 2, 4}, -10, 10}, -10, 10, 1.0, -1.0,
-                   {1}, {2});
+  LayerNormModel m(/*input_shape=*/{1, 2, 4}, /*input_min=*/-10,
+                   /*input_max=*/10, /*output_min=*/-10, /*output_max=*/10,
+                   /*scale=*/1.0, /*offset=*/-1.0, /*axes=*/{2});
   m.SetInput(input);
   m.Invoke();
   EXPECT_THAT(
@@ -142,8 +140,9 @@ TEST(LayerNormModelTest, NegativeScaleAndOffset) {
   // Standard deviation values are 3.74, 4.41
   const std::vector<float> expected_output = {-1.0, 0.6,  -1.53, -2.07,
                                               -1.0, 0.13, -2.59, -0.55};
-  LayerNormModel m({TensorType_UINT8, {1, 2, 4}, -10, 10}, -10, 10, -1.0, -1.0,
-                   {1}, {2});
+  LayerNormModel m(/*input_shape=*/{1, 2, 4}, /*input_min=*/-10,
+                   /*input_max=*/10, /*output_min=*/-10, /*output_max=*/10,
+                   /*scale=*/-1.0, /*offset=*/-1.0, /*axes=*/{2});
   m.SetInput(input);
   m.Invoke();
   EXPECT_THAT(
@@ -160,8 +159,9 @@ TEST(LayerNormModelTest, MultipleAxis) {
       1.12,  -2.08, 0.48,  -0.16, -0.95, -1.46, -0.95, 0.06,
       -0.69, -0.23, -1.60, -1.15, -0.80, -0.16, 0.48,  1.12};
-  LayerNormModel m({TensorType_UINT8, {1, 2, 3, 4}, -3, 3}, -3, 3, 1.0, 0.0,
-                   {2}, {1, 3});
+  LayerNormModel m(/*input_shape=*/{1, 2, 3, 4}, /*input_min=*/-3,
+                   /*input_max=*/3, /*output_min=*/-3, /*output_max=*/3,
+                   /*scale=*/1.0, /*offset=*/0.0, /*axes=*/{1, 3});
   m.SetInput(input);
   m.Invoke();
   EXPECT_THAT(
@@ -178,8 +178,9 @@ TEST(LayerNormModelTest, MultipleNegativeAxis) {
       1.12,  -2.08, 0.48,  -0.16, -0.95, -1.46, -0.95, 0.06,
       -0.69, -0.23, -1.60, -1.15, -0.80, -0.16, 0.48,  1.12};
-  LayerNormModel m({TensorType_UINT8, {1, 2, 3, 4}, -3, 3}, -3, 3, 1.0, 0.0,
-                   {2}, {-3, -1});
+  LayerNormModel m(/*input_shape=*/{1, 2, 3, 4}, /*input_min=*/-3,
+                   /*input_max=*/3, /*output_min=*/-3, /*output_max=*/3,
+                   /*scale=*/1.0, /*offset=*/0.0, /*axes=*/{-3, -1});
   m.SetInput(input);
   m.Invoke();
   EXPECT_THAT(
@@ -199,8 +200,9 @@ TEST(LayerNormModelTest, MultipleAxisWithLargeDepth) {
      2.05,  2.05,  -0.67, -0.28, 1.27,  1.27,  -1.06, -1.06, -0.28,
      0.,    -0.85, -0.42, 0.,    0.42,  -0.85, -0.42, 0.,    0.42};
-  LayerNormModel m({TensorType_UINT8, {1, 2, 2, 9}, -1.0, 1.0}, -3.0, 3.0, 1.0,
-                   0.0, {2}, {1, 3});
+  LayerNormModel m(/*input_shape=*/{1, 2, 2, 9}, /*input_min=*/-1.0,
+                   /*input_max=*/1.0, /*output_min=*/-3.0, /*output_max=*/3.0,
+                   /*scale=*/1.0, /*offset=*/0.0, /*axes=*/{1, 3});
   m.SetInput(input);
   m.Invoke();
   EXPECT_THAT(
@@ -211,4 +213,4 @@ TEST(LayerNormModelTest, MultipleAxisWithLargeDepth) {
 }  // namespace
 }  // namespace custom
 }  // namespace ops
-}  // namespace tflite
+}  // namespace seq_flow_lite

research/seq_flow_lite/tflite_ops/quantization_util.h

@@ -20,7 +20,7 @@ limitations under the License.
 #include "tensorflow/lite/context.h"

-namespace tflite {
+namespace seq_flow_lite {

 // Returns the original (dequantized) value of 8bit value.
 inline float PodDequantizeValue(const TfLiteTensor& tensor, uint8_t value) {
@@ -48,6 +48,6 @@ inline uint8_t PodQuantize(float value, int32_t zero_point,
   return static_cast<uint8_t>(std::max(std::min(255, integer_value), 0));
 }

-}  // namespace tflite
+}  // namespace seq_flow_lite

 #endif  // TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_QUANTIZATION_UTIL_H_
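PodDequantizeValue and PodQuantize implement the usual affine uint8 quantization: a stored byte maps to (value - zero_point) * scale, and quantization rounds, shifts by the zero point and clamps to [0, 255]. The sketch below restates that arithmetic in Python as an editorial aid; the parameter order of PodQuantize beyond the visible (value, zero_point, ...) prefix is an assumption.

# Editorial restatement of the affine quantization; parameter order is assumed.
def pod_dequantize_value(value, scale, zero_point):
  """Maps a uint8 value back to float: (value - zero_point) * scale."""
  return (value - zero_point) * scale

def pod_quantize(value, zero_point, inverse_scale):
  """Rounds value / scale, shifts by zero_point and clamps to the uint8 range."""
  integer_value = int(round(value * inverse_scale)) + zero_point
  return max(min(255, integer_value), 0)

# Round trip: a byte survives dequantize -> quantize with matching parameters.
assert pod_quantize(pod_dequantize_value(200, 0.1, 127), 127, 10.0) == 200
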
research/seq_flow_lite/tflite_ops/registerer.cc (new file)

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "tensorflow/lite/mutable_op_resolver.h"
#include "tflite_ops/expected_value.h" // seq_flow_lite
#include "tflite_ops/layer_norm.h" // seq_flow_lite
#include "tflite_ops/sequence_string_projection.h" // seq_flow_lite
PYBIND11_MODULE(registerer, m) {
  m.doc() = "Module that provides a registerer from the seq flow lite custom ops";
  m.def(
      "RegisterCustomOps",
      [](uintptr_t ptr) {
        ::tflite::MutableOpResolver* resolver =
            reinterpret_cast<::tflite::MutableOpResolver*>(ptr);
        resolver->AddCustom(
            "ExpectedValueOp",
            ::seq_flow_lite::ops::custom::Register_EXPECTED_VALUE());
        resolver->AddCustom(
            "LayerNorm", ::seq_flow_lite::ops::custom::Register_LAYER_NORM());
        resolver->AddCustom(
            "SEQUENCE_STRING_PROJECTION",
            ::seq_flow_lite::ops::custom::Register_SEQUENCE_STRING_PROJECTION());
        resolver->AddCustom(
            "SequenceStringProjection",
            ::seq_flow_lite::ops::custom::Register_SEQUENCE_STRING_PROJECTION());
        resolver->AddCustom(
            "SEQUENCE_STRING_PROJECTION_V2",
            ::seq_flow_lite::ops::custom::Register_SEQUENCE_STRING_PROJECTION());
        resolver->AddCustom(
            "SequenceStringProjectionV2",
            ::seq_flow_lite::ops::custom::Register_SEQUENCE_STRING_PROJECTION_V2());
      },
      "Register custom ops used by seq flow lite layers");
}
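RegisterCustomOps takes the raw address of a ::tflite::MutableOpResolver, which is how TFLite's Python interpreter hands its resolver to external registerers. A plausible usage, combined with the InterpreterWithCustomOps helper updated later in this commit, looks like the sketch below; import paths and the model file are illustrative, not taken from the repo.

# Editorial usage sketch; module paths and the model file are illustrative.
from tflite_ops import registerer      # the pybind_extension declared in BUILD above
from utils import tflite_utils

with open('/tmp/seq_flow_lite_model.tflite', 'rb') as f:   # placeholder path
  model_content = f.read()

interpreter = tflite_utils.InterpreterWithCustomOps(
    model_content=model_content,
    custom_op_registerers=[registerer.RegisterCustomOps])
interpreter.allocate_tensors()
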
research/seq_flow_lite/tflite_ops/sequence_string_projection.cc

@@ -31,7 +31,7 @@ limitations under the License.
 #include "tf_ops/projection_util.h"  // seq_flow_lite
 #include "tflite_ops/quantization_util.h"  // seq_flow_lite

-namespace tflite {
+namespace seq_flow_lite {
 namespace ops {
 namespace custom {
@@ -163,7 +163,7 @@ class ProjectionParams {
     DocSizeFeature(&doc_size_feature, num_tokens);
     *data = PodQuantize(doc_size_feature, 127.0f, 127);
   }
-  void Hash(const std::string& word, std::vector<uint64_t>* hash_codes) {
+  void Hash(const std::string& word, std::vector<uint64_t>& hash_codes) {
     hasher_->GetHashCodes(word, hash_codes);
   }
   // Lower cases the input text and eliminates all unsupported
@@ -269,6 +269,8 @@ class ProjectionParamsV2 : public ProjectionParams {
                            num_tokens, dims->data[1]);
       return kTfLiteError;
     }
+    tokens_.clear();
+    tokens_.reserve(num_tokens);
     for (int i = 0; i < num_tokens; ++i) {
       const tflite::StringRef strref = tflite::GetString(input_t, i);
       tokens_.push_back(
           std::pair<const char*, size_t>(strref.str, strref.len));
@@ -412,7 +414,7 @@ void TypedEval(const T* mapping_table, ProjectionParams* params, T* data) {
     } else {
       word = kEndToken;
     }
-    params->Hash(word, &hash_codes);
+    params->Hash(word, hash_codes);
     for (int hindex = 0, k = 0; hindex < hash_codes.size(); hindex++) {
       auto hash = hash_codes[hindex];
       for (int kmax = std::min(k + kIncrement, params->FeatureSize());
@@ -505,4 +507,4 @@ TfLiteRegistration* Register_SEQUENCE_STRING_PROJECTION_V2() {
 }  // namespace custom
 }  // namespace ops
-}  // namespace tflite
+}  // namespace seq_flow_lite

research/seq_flow_lite/tflite_ops/sequence_string_projection.h

@@ -16,7 +16,7 @@ limitations under the License.
 #define TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
 #include "tensorflow/lite/kernels/register.h"

-namespace tflite {
+namespace seq_flow_lite {
 namespace ops {
 namespace custom {
@@ -29,6 +29,6 @@ extern const char kSequenceStringProjectionV2[];
 TfLiteRegistration* Register_SEQUENCE_STRING_PROJECTION_V2();

 }  // namespace custom
 }  // namespace ops
-}  // namespace tflite
+}  // namespace seq_flow_lite

 #endif  // TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_

research/seq_flow_lite/tflite_ops/sequence_string_projection_test.cc

@@ -25,29 +25,32 @@ limitations under the License.
 #include "tf_ops/projection_util.h"  // seq_flow_lite
 #include "tflite_ops/tf_tflite_diff_test_util.h"  // seq_flow_lite

-namespace tflite {
+namespace seq_flow_lite {
 namespace ops {
 namespace custom {
 namespace {

+using ::seq_flow_lite::testing::AttrValue;
+using ::seq_flow_lite::testing::FloatTensor;
+using ::seq_flow_lite::testing::IntTensor;
+using ::seq_flow_lite::testing::OpEquivTestCase;
+using ::seq_flow_lite::testing::StringTensor;
+using ::seq_flow_lite::testing::TensorflowTfLiteOpTest;
 using ::testing::ElementsAreArray;
-using ::tflite::testing::AttrValue;
-using ::tflite::testing::FloatTensor;
-using ::tflite::testing::IntTensor;
-using ::tflite::testing::OpEquivTestCase;
-using ::tflite::testing::StringTensor;
-using ::tflite::testing::TensorflowTfLiteOpTest;
+using ::tflite::TensorType_FLOAT32;
+using ::tflite::TensorType_STRING;
+using ::tflite::TensorType_UINT8;

-class SequenceStringProjectionModel : public SingleOpModel {
+class SequenceStringProjectionModel : public ::tflite::SingleOpModel {
  public:
   explicit SequenceStringProjectionModel(
       bool split_on_space, int max_splits, int word_novelty_bits,
-      int doc_size_levels, bool add_eos_tag, TensorType output_type,
+      int doc_size_levels, bool add_eos_tag, ::tflite::TensorType output_type,
       const std::string& token_separators = "",
       bool normalize_repetition = false, float add_first_cap = 0.0,
-      float add_all_caps = 0.0, const string& hashtype = kMurmurHash) {
+      float add_all_caps = 0.0, const std::string& hashtype = kMurmurHash) {
     flexbuffers::Builder fbb;
     fbb.Map([&] {
       fbb.Int("feature_size", 4);
@@ -798,11 +801,11 @@ INSTANTIATE_TEST_SUITE_P(
     SequenceStringProjectionTests, SequenceStringProjectionTest,
     ::testing::ValuesIn(SequenceStringProjectionTestCases()));

-class SequenceStringProjectionV2Model : public SingleOpModel {
+class SequenceStringProjectionV2Model : public ::tflite::SingleOpModel {
  public:
   explicit SequenceStringProjectionV2Model(
       std::vector<std::vector<int>> input_shapes,
-      const string& hashtype = kMurmurHash) {
+      const std::string& hashtype = kMurmurHash) {
     flexbuffers::Builder fbb;
     fbb.Map([&] {
       fbb.Int("feature_size", 4);
@@ -827,6 +830,7 @@ class SequenceStringProjectionV2Model : public SingleOpModel {
         << "Cannot allocate tensors";
     return SingleOpModel::InvokeUnchecked();
   }
+  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }

  private:
   int input_;
@@ -884,6 +888,15 @@ TEST(SequenceStringProjectionV2Test, RegularInputUint8) {
   m.Invoke({"hello", "world"}, kTfLiteOk);
 }

+TEST(SequenceStringProjectionV2Test, NumberProjectionsForMultipleInputs) {
+  SequenceStringProjectionV2Model m({{1, 2}});
+  std::vector<std::string> input = {"hello", "world"};
+  m.Invoke(input, kTfLiteOk);
+  EXPECT_EQ(m.GetOutputShape()[1], input.size());
+  m.Invoke(input, kTfLiteOk);
+  EXPECT_EQ(m.GetOutputShape()[1], input.size());
+}
+
 class SequenceStringProjectionV2Test : public TensorflowTfLiteOpTest {
   std::function<TfLiteRegistration*()> TfLiteOpRegistration() override {
     return ops::custom::Register_SEQUENCE_STRING_PROJECTION_V2;
@@ -986,7 +999,7 @@ INSTANTIATE_TEST_SUITE_P(
 }  // namespace
 }  // namespace custom
 }  // namespace ops
-}  // namespace tflite
+}  // namespace seq_flow_lite

 int main(int argc, char** argv) {
   // On Linux, add: absl::SetFlag(&FLAGS_logtostderr, true);

research/seq_flow_lite/tflite_ops/tf_tflite_diff_test_util.cc

@@ -19,11 +19,16 @@ limitations under the License.
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/lib/core/status_test_util.h"

-namespace tflite {
+namespace seq_flow_lite {
 namespace testing {

 using ::tensorflow::TensorProto;
 using ::testing::FloatNear;
+using ::tflite::TensorType_STRING;
+using ::tflite::TensorType_UINT8;
+using ::tflite::TensorType_INT32;
+using ::tflite::TensorType_BOOL;
+using ::tflite::TensorType_FLOAT32;

 ::tflite::TensorType TfTypeToTfLiteType(::tensorflow::DataType dtype) {
   switch (dtype) {
@@ -324,7 +329,7 @@ void TensorflowTfLiteOpTest::CompareOpOutput() {
     const auto& quantization_params =
         GetParam().output_tensors[i].quantization_params;
     if (quantization_params.scale != 0.0) {
-      auto tflite_output_values = Dequantize(
+      auto tflite_output_values = tflite::Dequantize(
           tflite_op_.ExtractVector<uint8_t>(tflite_outputs_[i]),
           quantization_params.scale, quantization_params.zero_point);
       for (int i = 0; i < tf_output_values.size(); i++) {

research/seq_flow_lite/tflite_ops/tf_tflite_diff_test_util.h

@@ -25,7 +25,7 @@ limitations under the License.
 #include "tensorflow/core/kernels/ops_testutil.h"
 #include "tensorflow/lite/kernels/test_util.h"

-namespace tflite {
+namespace seq_flow_lite {
 namespace testing {

 // Convenience constructors.
@@ -144,6 +144,6 @@ class TensorflowTfLiteOpTest
 };

 }  // namespace testing
-}  // namespace tflite
+}  // namespace seq_flow_lite

 #endif  // TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_

research/seq_flow_lite/trainer_v2.py (new file)

# Copyright 2021 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Binary to train PRADO model with TF 2.0."""
import importlib
import json

from absl import app
from absl import flags
from absl import logging
import tensorflow as tf

import input_fn_reader  # import root module

FLAGS = flags.FLAGS

flags.DEFINE_string("config_path", None, "Path to a RunnerConfig.")
flags.DEFINE_enum("runner_mode", "train", ["train", "train_and_eval", "eval"],
                  "Runner mode.")
flags.DEFINE_string("master", None, "TensorFlow master URL.")
flags.DEFINE_string(
    "output_dir", "/tmp/testV2",
    "The output directory where the model checkpoints will be written.")
flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
flags.DEFINE_integer(
    "num_tpu_cores", 8,
    "Only used if `use_tpu` is True. Total number of TPU cores to use.")


def load_runner_config():
  with tf.io.gfile.GFile(FLAGS.config_path, "r") as f:
    return json.loads(f.read())


def compute_loss(logits, labels, model_config, mode):
  """Creates a sequence labeling model."""
  if mode != tf.estimator.ModeKeys.PREDICT:
    if not model_config["multilabel"]:
      loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
          labels=labels, logits=logits)
    else:
      loss = tf.nn.sigmoid_cross_entropy_with_logits(
          labels=labels, logits=logits)
    loss = tf.reduce_mean(loss)
  else:
    loss = None
  return loss


def model_fn_builder(runner_config, mode):
  """Returns `model_fn` closure for TPUEstimator."""
  rel_module_path = ""  # empty base dir
  model = importlib.import_module(rel_module_path + runner_config["name"])
  model_config = runner_config["model_config"]
  return model.Encoder(model_config, mode)


def main(_):
  runner_config = load_runner_config()

  if FLAGS.output_dir:
    tf.io.gfile.makedirs(FLAGS.output_dir)

  train_model = model_fn_builder(runner_config, tf.estimator.ModeKeys.TRAIN)
  optimizer = tf.keras.optimizers.Adam()
  train_input_fn = input_fn_reader.create_input_fn(
      runner_config=runner_config,
      mode=tf.estimator.ModeKeys.TRAIN,
      drop_remainder=True)
  params = {"batch_size": runner_config["batch_size"]}
  train_ds = train_input_fn(params)
  train_loss = tf.keras.metrics.Mean(name="train_loss")

  @tf.function
  def train_step(features):
    with tf.GradientTape() as tape:
      logits = train_model(features["projection"], features["seq_length"])
      loss = compute_loss(logits, features["label"],
                          runner_config["model_config"],
                          tf.estimator.ModeKeys.TRAIN)
    gradients = tape.gradient(loss, train_model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, train_model.trainable_variables))
    train_loss(loss)

  for epoch in range(1):
    train_loss.reset_states()
    for features in train_ds:
      train_step(features)
      step = optimizer.iterations.numpy()
      if step % 100 == 0:
        logging.info("Running step %s in epoch %s", step, epoch)
    logging.info("Training loss: %s, epoch: %s, step: %s",
                 round(train_loss.result().numpy(), 4), epoch, step)


if __name__ == "__main__":
  app.run(main)
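The trainer only touches a handful of RunnerConfig fields directly: the module name it imports, model_config (including the multilabel flag used by compute_loss), and batch_size; everything else is passed through to input_fn_reader and the model's Encoder. An illustrative minimal config, with assumed values, could be written like this:

# Illustrative RunnerConfig; field names follow what trainer_v2.py reads, values are assumed.
import json

runner_config = {
    "name": "prado",          # module imported via importlib; assumed model name
    "batch_size": 32,
    "model_config": {
        "multilabel": False,  # selects sparse softmax vs. sigmoid loss in compute_loss
        # ... model hyperparameters consumed by <name>.Encoder and input_fn_reader ...
    },
}
with open("/tmp/runner_config.json", "w") as f:
  json.dump(runner_config, f)
# python trainer_v2.py --config_path=/tmp/runner_config.json --runner_mode=train
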
research/seq_flow_lite/utils/tflite_utils.py

@@ -29,13 +29,49 @@ def _dump_graph_in_text_format(filename, graph_def):

 class InterpreterWithCustomOps(tf.lite.Interpreter):
   """Extended tf.lite.Interpreter."""

-  def __init__(self, model_content, custom_op_registerers):
-    self._custom_op_registerers = custom_op_registerers
+  def __init__(self, model_content, custom_op_registerers=None):
+    self._custom_op_registerers = custom_op_registerers or []
     super(InterpreterWithCustomOps, self).__init__(model_content=model_content)

+  def op_details(self):
+    op_details = {}
+    try:
+      op_details = self._get_ops_details()  # Accessing experimental method.
+    except AttributeError:
+      print('Unable to access op details')
+    return op_details
+
+  def op_histogram(self):
+    op_hist = {}
+    op_list = self.op_details()
+    for op in op_list:
+      if op['op_name'] in op_hist:
+        op_hist[op['op_name']] += 1
+      else:
+        op_hist[op['op_name']] = 1
+    return op_hist
+
+  def check_op_histogram(self, expected):
+    passed = True
+    for k, v in self.op_histogram().items():
+      if k not in expected:
+        print('Unexpected key {} found {} times.'.format(k, v))
+        passed = False
+        continue
+      elif expected[k] != v:
+        print('Expected {} counts of key {} found {}.'.format(expected[k], k, v))
+        passed = False
+      del expected[k]
+    for k, v in expected.items():
+      print('Missing expected key {} value {}.'.format(k, v))
+      passed = False
+    return passed
+

-def set_output_quantized_for_custom_ops(graph_def):
+def set_output_quantized_for_custom_ops(graph_def, use_mlir=True):
   """Set output types/quantized flag for custom/unsupported ops."""
   quantized_custom_ops = {
       'SequenceStringProjection': [tf.float32.as_datatype_enum],
@@ -44,6 +80,8 @@ def set_output_quantized_for_custom_ops(graph_def):
       'ExpectedValueOp': [tf.float32.as_datatype_enum],
       'LayerNorm': [tf.float32.as_datatype_enum],
       'UniformCausalAttn': [tf.float32.as_datatype_enum],
+      'RnnDecoderReadState': [tf.float32.as_datatype_enum],
+      'RnnDecoderWriteState': [tf.float32.as_datatype_enum],
   }
   custom_op_renames = {
       'SequenceStringProjection': 'SEQUENCE_STRING_PROJECTION',
@@ -52,30 +90,27 @@ def set_output_quantized_for_custom_ops(graph_def):
   for node in graph_def.node:
     if node.op in quantized_custom_ops:
-      node.attr['_output_quantized'].b = True
-      node.attr['_output_types'].list.type[:] = quantized_custom_ops[node.op]
-    if node.op in custom_op_renames:
+      if use_mlir:
+        node.attr['_tfl_quant_trait'].s = str.encode('fully_quantizable')
+      else:
+        node.attr['_output_quantized'].b = True
+        node.attr['_output_types'].list.type[:] = quantized_custom_ops[node.op]
+    if not use_mlir and node.op in custom_op_renames:
       node.op = custom_op_renames[node.op]


-def generate_tflite(session, graph, input_tensors, output_tensors):
+def generate_tflite(session, graph, input_tensors, output_tensors,
+                    use_mlir=True):
   """Generate TFLite model from a session, graph and input/output tensors."""
   output_nodes = [tensor.name.split(':')[0] for tensor in output_tensors]
   graph_def = tf.graph_util.convert_variables_to_constants(
       session, graph.as_graph_def(), output_nodes)
-  set_output_quantized_for_custom_ops(graph_def)
+  # TODO(b/171063452): Bug needs to be fixed to handle this correctly.
+  # def _node_name(tensor):
+  #   return tensor.name.split(':')[0]
+  set_output_quantized_for_custom_ops(graph_def, use_mlir)
+  # input_arrays_with_shape = [
+  #     (_node_name(tensor), None) for tensor in input_tensors
+  # ]
+  # output_arrays = [_node_name(tensor) for tensor in output_tensors]
+  # converter = tf.lite.TFLiteConverter(graph_def, None, None,
+  #                                     input_arrays_with_shape, output_arrays)
   converter = tf.lite.TFLiteConverter(graph_def, input_tensors, output_tensors)
   converter.inference_type = tf.uint8
   converter.default_ranges_stats = (127.5, 127.5)
@@ -83,5 +118,5 @@ def generate_tflite(session, graph, input_tensors, output_tensors):
       tensor.op.name: (127.5, 127.5) for tensor in input_tensors
   }
   converter.allow_custom_ops = True
-  converter.experimental_new_converter = False
+  converter.experimental_new_converter = use_mlir
   return converter.convert()
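With the new use_mlir flag, set_output_quantized_for_custom_ops tags custom ops either with the _tfl_quant_trait attribute (MLIR converter) or with the legacy _output_quantized/_output_types attributes, and generate_tflite forwards the same flag to experimental_new_converter. A schematic end-to-end use might look like the sketch below; the placeholder graph stands in for a real seq_flow_lite model and the import paths are assumptions, so treat it as an outline rather than working repo code.

# Schematic sketch only; a real model graph with seq_flow_lite custom ops is assumed.
import tensorflow.compat.v1 as tf

from tflite_ops import registerer        # illustrative import paths
from utils import tflite_utils

graph = tf.Graph()
with graph.as_default(), tf.Session(graph=graph) as session:
  inp = tf.placeholder(tf.float32, shape=[1, 4], name='input')
  out = tf.identity(inp, name='output')  # stand-in for the real network
  tflite_model = tflite_utils.generate_tflite(
      session, graph, [inp], [out], use_mlir=True)

interp = tflite_utils.InterpreterWithCustomOps(
    model_content=tflite_model,
    custom_op_registerers=[registerer.RegisterCustomOps])
interp.allocate_tensors()
print(interp.op_histogram())             # counts of ops in the converted model
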