Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
20cc2190
Unverified
Commit
20cc2190
authored
Aug 24, 2022
by
pyoung2778
Committed by
GitHub
Aug 24, 2022
Browse files
Check in seq_flow_lite (#10750)
parent
fdecf385
Changes
62
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
166 additions
and
112 deletions
+166
-112
research/seq_flow_lite/tflite_ops/beam_search.cc
research/seq_flow_lite/tflite_ops/beam_search.cc
+8
-7
research/seq_flow_lite/tflite_ops/beam_search.h
research/seq_flow_lite/tflite_ops/beam_search.h
+4
-4
research/seq_flow_lite/tflite_ops/beam_search_test.cc
research/seq_flow_lite/tflite_ops/beam_search_test.cc
+13
-13
research/seq_flow_lite/tflite_ops/expected_value.h
research/seq_flow_lite/tflite_ops/expected_value.h
+3
-3
research/seq_flow_lite/tflite_ops/layer_norm.h
research/seq_flow_lite/tflite_ops/layer_norm.h
+3
-3
research/seq_flow_lite/tflite_ops/layer_norm_test.cc
research/seq_flow_lite/tflite_ops/layer_norm_test.cc
+7
-7
research/seq_flow_lite/tflite_ops/quantization_util.h
research/seq_flow_lite/tflite_ops/quantization_util.h
+3
-3
research/seq_flow_lite/tflite_ops/sequence_string_projection.cc
...ch/seq_flow_lite/tflite_ops/sequence_string_projection.cc
+8
-5
research/seq_flow_lite/tflite_ops/sequence_string_projection.h
...rch/seq_flow_lite/tflite_ops/sequence_string_projection.h
+4
-3
research/seq_flow_lite/tflite_ops/sequence_string_projection_test.cc
...q_flow_lite/tflite_ops/sequence_string_projection_test.cc
+49
-5
research/seq_flow_lite/tflite_ops/tf_tflite_diff_test_util.cc
...arch/seq_flow_lite/tflite_ops/tf_tflite_diff_test_util.cc
+1
-1
research/seq_flow_lite/tflite_ops/tf_tflite_diff_test_util.h
research/seq_flow_lite/tflite_ops/tf_tflite_diff_test_util.h
+3
-3
research/seq_flow_lite/tflite_ops/tflite_decoder_cache.h
research/seq_flow_lite/tflite_ops/tflite_decoder_cache.h
+5
-4
research/seq_flow_lite/tflite_ops/tflite_decoder_handler.cc
research/seq_flow_lite/tflite_ops/tflite_decoder_handler.cc
+7
-6
research/seq_flow_lite/tflite_ops/tflite_decoder_handler.h
research/seq_flow_lite/tflite_ops/tflite_decoder_handler.h
+6
-5
research/seq_flow_lite/tflite_ops/tflite_decoder_handler_test.cc
...h/seq_flow_lite/tflite_ops/tflite_decoder_handler_test.cc
+6
-6
research/seq_flow_lite/tflite_ops/tflite_qrnn_pooling.cc
research/seq_flow_lite/tflite_ops/tflite_qrnn_pooling.cc
+6
-4
research/seq_flow_lite/tflite_ops/tflite_qrnn_pooling.h
research/seq_flow_lite/tflite_ops/tflite_qrnn_pooling.h
+7
-7
research/seq_flow_lite/trainer.py
research/seq_flow_lite/trainer.py
+18
-18
research/seq_flow_lite/trainer_v2.py
research/seq_flow_lite/trainer_v2.py
+5
-5
No files found.
research/seq_flow_lite/tflite_ops/beam_search.cc
View file @
20cc2190
...
...
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "
third_party/tensorflow_models/seq_flow_lite/
tflite_ops/beam_search.h"
#include "tflite_ops/beam_search.h"
// seq_flow_lite
#include <algorithm>
#include <cstdint>
...
...
@@ -21,10 +21,10 @@ limitations under the License.
#include <vector>
#include "base/logging.h"
#include "
third_party/
absl/strings/str_join.h"
#include "
third_party/
tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "
third_party/
tensorflow/lite/kernels/internal/types.h"
#include "
third_party/tensorflow_models/seq_flow_lite/
tflite_ops/quantization_util.h"
#include "absl/strings/str_join.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tflite_ops/quantization_util.h"
// seq_flow_lite
namespace
seq_flow_lite
{
namespace
ops
{
...
...
@@ -86,6 +86,7 @@ void SequenceTracker::AddSequence(const int32_t *begin, const int32_t *end,
std
::
vector
<
std
::
vector
<
int32_t
>>
SequenceTracker
::
GetTopBeams
()
{
std
::
vector
<
std
::
vector
<
int32_t
>>
return_value
;
return_value
.
reserve
(
terminated_topk_
.
size
());
for
(
const
auto
&
v
:
terminated_topk_
)
{
return_value
.
push_back
(
v
.
second
);
}
...
...
@@ -255,8 +256,8 @@ void BeamSearch::FindTopKQuantizedFromLogitsV1(const TfLiteTensor &tensor,
}
}
// Updating topk across all beams.
for
(
uint32_t
k
=
0
;
k
<
std
::
min
(
topk_k
,
num_classes_
);
++
k
)
{
const
uint32_t
curr_beam_index
=
curr_beam
_topk
[
k
]
&
kClassIndexMask
;
for
(
uint32_t
curr_beam
:
curr_beam_top
k
)
{
const
uint32_t
curr_beam_index
=
curr_beam
&
kClassIndexMask
;
const
uint32_t
index
=
j
*
num_classes_
+
curr_beam_index
;
const
float
log_prob
=
tensor
.
params
.
scale
*
beam_logits
[
curr_beam_index
]
-
precomputed
;
...
...
research/seq_flow_lite/tflite_ops/beam_search.h
View file @
20cc2190
...
...
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef
THIRD_PARTY_
TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
#define
THIRD_PARTY_
TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
#include <cstdint>
#include <functional>
...
...
@@ -23,7 +23,7 @@ limitations under the License.
#include <set>
#include <vector>
#include "
third_party/
tensorflow/lite/c/common.h"
#include "tensorflow/lite/c/common.h"
namespace
seq_flow_lite
{
namespace
ops
{
...
...
@@ -110,4 +110,4 @@ class BeamSearch {
}
// namespace custom
}
// namespace ops
}
// namespace seq_flow_lite
#endif //
THIRD_PARTY_
TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
research/seq_flow_lite/tflite_ops/beam_search_test.cc
View file @
20cc2190
...
...
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "
third_party/tensorflow_models/seq_flow_lite/
tflite_ops/beam_search.h"
#include "tflite_ops/beam_search.h"
// seq_flow_lite
#include <cstdint>
#include <functional>
...
...
@@ -21,17 +21,17 @@ limitations under the License.
#include <memory>
#include <vector>
#include
"testing/base/public
/gmock.h
"
#include
"
test
ing/base/public/guni
t.h
"
#include "
third_party/
absl/strings/str_join.h"
#include "
third_party/
tensorflow/lite/c/c_api_types.h"
#include "
third_party/
tensorflow/lite/c/common.h"
#include "
third_party/
tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "
third_party/
tensorflow/lite/kernels/internal/reference/dequantize.h"
#include "
third_party/
tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "
third_party/
tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "
third_party/
tensorflow/lite/kernels/internal/types.h"
#include "
third_party/tensorflow_models/seq_flow_lite/
tflite_ops/quantization_util.h"
#include
<gmock
/gmock.h
>
#include
<g
test
/gtes
t.h
>
#include "absl/strings/str_join.h"
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/dequantize.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tflite_ops/quantization_util.h"
// seq_flow_lite
namespace
seq_flow_lite
{
namespace
ops
{
...
...
@@ -76,7 +76,7 @@ class BeamSearchImpl : public BeamSearch {
cur_cache
+
(
selected_beams
[
beam
]
*
NumClasses
());
for
(
int
j
=
0
;
j
<
NumClasses
();
++
j
,
index
++
)
{
next_cache
[
index
]
=
(
selected
[
j
]
+
next_cache
[
index
])
/
2
;
data_ptr
[
index
]
=
::
seq_flow_lite
::
PodQuantize
(
data_ptr
[
index
]
=
PodQuantize
(
next_cache
[
index
],
decoder_output_
->
params
.
zero_point
,
1.0
f
/
decoder_output_
->
params
.
scale
);
}
...
...
research/seq_flow_lite/tflite_ops/expected_value.h
View file @
20cc2190
...
...
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_EXPECTED_VALUE_H_
#define TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_EXPECTED_VALUE_H_
#ifndef TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_EXPECTED_VALUE_H_
#define TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_EXPECTED_VALUE_H_
#include "tensorflow/lite/kernels/register.h"
...
...
@@ -27,4 +27,4 @@ TfLiteRegistration* Register_EXPECTED_VALUE();
}
// namespace ops
}
// namespace seq_flow_lite
#endif // TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_EXPECTED_VALUE_H_
#endif // TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_EXPECTED_VALUE_H_
research/seq_flow_lite/tflite_ops/layer_norm.h
View file @
20cc2190
...
...
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef
LEARNING_EXPANDER_POD_DEEP_POD
_TFLITE_
HANDLER
S_LAYER_NORM_H_
#define
LEARNING_EXPANDER_POD_DEEP_POD
_TFLITE_
HANDLER
S_LAYER_NORM_H_
#ifndef
TENSORFLOW_MODELS_SEQ_FLOW_LITE
_TFLITE_
OP
S_LAYER_NORM_H_
#define
TENSORFLOW_MODELS_SEQ_FLOW_LITE
_TFLITE_
OP
S_LAYER_NORM_H_
#include "tensorflow/lite/kernels/register.h"
...
...
@@ -27,4 +27,4 @@ TfLiteRegistration* Register_LAYER_NORM();
}
// namespace ops
}
// namespace seq_flow_lite
#endif //
LEARNING_EXPANDER_POD_DEEP_POD
_TFLITE_
HANDLER
S_LAYER_NORM_H_
#endif //
TENSORFLOW_MODELS_SEQ_FLOW_LITE
_TFLITE_
OP
S_LAYER_NORM_H_
research/seq_flow_lite/tflite_ops/layer_norm_test.cc
View file @
20cc2190
...
...
@@ -87,7 +87,7 @@ TEST(LayerNormModelTest, RegularInput) {
/*input_max=*/
10
,
/*output_min=*/
-
10
,
/*output_max=*/
10
,
/*scale=*/
1.0
,
/*offset=*/
0.0
,
/*axes=*/
{
2
});
m
.
SetInput
(
input
);
m
.
Invoke
(
);
ASSERT_EQ
(
m
.
Invoke
(),
kTfLiteOk
);
EXPECT_THAT
(
m
.
GetDequantizedOutput
(),
ElementsAreArray
(
ArrayFloatNear
(
expected_output
,
kQuantizedTolerance
)));
...
...
@@ -106,7 +106,7 @@ TEST(LayerNormModelTest, NegativeScale) {
/*input_max=*/
10
,
/*output_min=*/
-
10
,
/*output_max=*/
10
,
/*scale=*/
-
1.0
,
/*offset=*/
0.0
,
/*axes=*/
{
2
});
m
.
SetInput
(
input
);
m
.
Invoke
(
);
ASSERT_EQ
(
m
.
Invoke
(),
kTfLiteOk
);
EXPECT_THAT
(
m
.
GetDequantizedOutput
(),
ElementsAreArray
(
ArrayFloatNear
(
expected_output
,
kQuantizedTolerance
)));
...
...
@@ -125,7 +125,7 @@ TEST(LayerNormModelTest, NegativeOffset) {
/*input_max=*/
10
,
/*output_min=*/
-
10
,
/*output_max=*/
10
,
/*scale=*/
1.0
,
/*offset=*/
-
1.0
,
/*axes=*/
{
2
});
m
.
SetInput
(
input
);
m
.
Invoke
(
);
ASSERT_EQ
(
m
.
Invoke
(),
kTfLiteOk
);
EXPECT_THAT
(
m
.
GetDequantizedOutput
(),
ElementsAreArray
(
ArrayFloatNear
(
expected_output
,
kQuantizedTolerance
)));
...
...
@@ -144,7 +144,7 @@ TEST(LayerNormModelTest, NegativeScaleAndOffset) {
/*input_max=*/
10
,
/*output_min=*/
-
10
,
/*output_max=*/
10
,
/*scale=*/
-
1.0
,
/*offset=*/
-
1.0
,
/*axes=*/
{
2
});
m
.
SetInput
(
input
);
m
.
Invoke
(
);
ASSERT_EQ
(
m
.
Invoke
(),
kTfLiteOk
);
EXPECT_THAT
(
m
.
GetDequantizedOutput
(),
ElementsAreArray
(
ArrayFloatNear
(
expected_output
,
kQuantizedTolerance
)));
...
...
@@ -163,7 +163,7 @@ TEST(LayerNormModelTest, MultipleAxis) {
/*input_max=*/
3
,
/*output_min=*/
-
3
,
/*output_max=*/
3
,
/*scale=*/
1.0
,
/*offset=*/
0.0
,
/*axes=*/
{
1
,
3
});
m
.
SetInput
(
input
);
m
.
Invoke
(
);
ASSERT_EQ
(
m
.
Invoke
(),
kTfLiteOk
);
EXPECT_THAT
(
m
.
GetDequantizedOutput
(),
ElementsAreArray
(
ArrayFloatNear
(
expected_output
,
kQuantizedTolerance
)));
...
...
@@ -182,7 +182,7 @@ TEST(LayerNormModelTest, MultipleNegativeAxis) {
/*input_max=*/
3
,
/*output_min=*/
-
3
,
/*output_max=*/
3
,
/*scale=*/
1.0
,
/*offset=*/
0.0
,
/*axes=*/
{
-
3
,
-
1
});
m
.
SetInput
(
input
);
m
.
Invoke
(
);
ASSERT_EQ
(
m
.
Invoke
(),
kTfLiteOk
);
EXPECT_THAT
(
m
.
GetDequantizedOutput
(),
ElementsAreArray
(
ArrayFloatNear
(
expected_output
,
kQuantizedTolerance
)));
...
...
@@ -204,7 +204,7 @@ TEST(LayerNormModelTest, MultipleAxisWithLargeDepth) {
/*input_max=*/
1.0
,
/*output_min=*/
-
3.0
,
/*output_max=*/
3.0
,
/*scale=*/
1.0
,
/*offset=*/
0.0
,
/*axes=*/
{
1
,
3
});
m
.
SetInput
(
input
);
m
.
Invoke
(
);
ASSERT_EQ
(
m
.
Invoke
(),
kTfLiteOk
);
EXPECT_THAT
(
m
.
GetDequantizedOutput
(),
ElementsAreArray
(
ArrayFloatNear
(
expected_output
,
kQuantizedTolerance
)));
...
...
research/seq_flow_lite/tflite_ops/quantization_util.h
View file @
20cc2190
...
...
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_QUANTIZATION_UTIL_H_
#define TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_QUANTIZATION_UTIL_H_
#ifndef TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_QUANTIZATION_UTIL_H_
#define TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_QUANTIZATION_UTIL_H_
#include <algorithm>
#include <cmath>
...
...
@@ -50,4 +50,4 @@ inline uint8_t PodQuantize(float value, int32_t zero_point,
}
// namespace seq_flow_lite
#endif // TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_QUANTIZATION_UTIL_H_
#endif // TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_QUANTIZATION_UTIL_H_
research/seq_flow_lite/tflite_ops/sequence_string_projection.cc
View file @
20cc2190
...
...
@@ -101,7 +101,7 @@ class ProjectionParams {
bool
exclude_nonalphaspace_unicodes
,
const
std
::
string
&
token_separators
,
bool
normalize_repetition
,
bool
add_first_cap_feature
,
bool
add_all_caps_feature
)
bool
add_all_caps_feature
,
bool
normalize_spaces
)
:
feature_size_
(
feature_size
),
unicode_handler_
(
vocabulary
,
exclude_nonalphaspace_unicodes
),
hasher_
(
Hasher
::
CreateHasher
(
feature_size
,
hashtype
)),
...
...
@@ -130,9 +130,9 @@ class ProjectionParams {
}
word_novelty_offset_
=
2.0
f
/
(
1
<<
word_novelty_bits_
);
if
(
!
token_separators
.
empty
()
||
normalize_repetition
)
{
if
(
!
token_separators
.
empty
()
||
normalize_repetition
||
normalize_spaces
)
{
projection_normalizer_
=
std
::
make_unique
<
ProjectionNormalizer
>
(
token_separators
,
normalize_repetition
);
token_separators
,
normalize_repetition
,
normalize_spaces
);
}
}
virtual
~
ProjectionParams
()
{}
...
...
@@ -242,7 +242,8 @@ class ProjectionParamsV2 : public ProjectionParams {
/*exclude_nonalphaspace_unicodes = */
false
,
/*token_separators = */
""
,
normalize_repetition
,
/*add_first_cap_feature = */
false
,
/*add_all_caps_feature = */
false
)
{}
/*add_all_caps_feature = */
false
,
/*normalize_spaces = */
false
)
{}
~
ProjectionParamsV2
()
override
{}
TfLiteStatus
PreprocessInput
(
TfLiteTensor
*
input_t
,
...
...
@@ -341,6 +342,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
const
std
::
string
token_separators
=
m
[
"token_separators"
].
IsNull
()
?
""
:
m
[
"token_separators"
].
ToString
();
const
bool
normalize_repetition
=
m
[
"normalize_repetition"
].
AsBool
();
const
bool
normalize_spaces
=
m
[
"normalize_spaces"
].
AsBool
();
if
(
!
Hasher
::
SupportedHashType
(
hashtype
))
{
context
->
ReportError
(
context
,
"Unsupported hashtype %s
\n
"
,
hashtype
.
c_str
());
...
...
@@ -354,7 +356,8 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
add_bos_tag
?
BosTag
::
kGenerate
:
BosTag
::
kNone
,
add_eos_tag
?
EosTag
::
kGenerate
:
EosTag
::
kNone
,
exclude_nonalphaspace_unicodes
,
token_separators
,
normalize_repetition
,
add_first_cap_feature
==
1.0
f
,
add_all_caps_feature
==
1.0
f
);
add_first_cap_feature
==
1.0
f
,
add_all_caps_feature
==
1.0
f
,
normalize_spaces
);
}
void
*
InitV2
(
TfLiteContext
*
context
,
const
char
*
buffer
,
size_t
length
)
{
...
...
research/seq_flow_lite/tflite_ops/sequence_string_projection.h
View file @
20cc2190
...
...
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
#define TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
#ifndef TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
#define TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
#include "tensorflow/lite/kernels/register.h"
namespace
seq_flow_lite
{
...
...
@@ -27,8 +27,9 @@ TfLiteRegistration* Register_SEQUENCE_STRING_PROJECTION();
extern
const
char
kSequenceStringProjectionV2
[];
TfLiteRegistration
*
Register_SEQUENCE_STRING_PROJECTION_V2
();
}
// namespace custom
}
// namespace ops
}
// namespace seq_flow_lite
#endif // TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
#endif // TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
research/seq_flow_lite/tflite_ops/sequence_string_projection_test.cc
View file @
20cc2190
...
...
@@ -39,6 +39,7 @@ using ::seq_flow_lite::testing::OpEquivTestCase;
using
::
seq_flow_lite
::
testing
::
StringTensor
;
using
::
seq_flow_lite
::
testing
::
TensorflowTfLiteOpTest
;
using
::
testing
::
ElementsAreArray
;
using
::
testing
::
Not
;
using
::
tflite
::
TensorType_FLOAT32
;
using
::
tflite
::
TensorType_STRING
;
using
::
tflite
::
TensorType_UINT8
;
...
...
@@ -50,7 +51,8 @@ class SequenceStringProjectionModel : public ::tflite::SingleOpModel {
int
doc_size_levels
,
bool
add_eos_tag
,
::
tflite
::
TensorType
output_type
,
const
std
::
string
&
token_separators
=
""
,
bool
normalize_repetition
=
false
,
float
add_first_cap
=
0.0
,
float
add_all_caps
=
0.0
,
const
std
::
string
&
hashtype
=
kMurmurHash
)
{
float
add_all_caps
=
0.0
,
const
std
::
string
&
hashtype
=
kMurmurHash
,
bool
normalize_spaces
=
false
)
{
flexbuffers
::
Builder
fbb
;
fbb
.
Map
([
&
]
{
fbb
.
Int
(
"feature_size"
,
4
);
...
...
@@ -65,6 +67,7 @@ class SequenceStringProjectionModel : public ::tflite::SingleOpModel {
fbb
.
Bool
(
"normalize_repetition"
,
normalize_repetition
);
fbb
.
Float
(
"add_first_cap_feature"
,
add_first_cap
);
fbb
.
Float
(
"add_all_caps_feature"
,
add_all_caps
);
fbb
.
Bool
(
"normalize_spaces"
,
normalize_spaces
);
});
fbb
.
Finish
();
output_
=
AddOutput
({
output_type
,
{}});
...
...
@@ -76,13 +79,13 @@ class SequenceStringProjectionModel : public ::tflite::SingleOpModel {
PopulateStringTensor
(
input_
,
{
input
});
CHECK
(
interpreter_
->
AllocateTensors
()
==
kTfLiteOk
)
<<
"Cannot allocate tensors"
;
SingleOpModel
::
Invoke
();
CHECK_EQ
(
SingleOpModel
::
Invoke
()
,
kTfLiteOk
)
;
}
TfLiteStatus
InvokeFailable
(
const
std
::
string
&
input
)
{
PopulateStringTensor
(
input_
,
{
input
});
CHECK
(
interpreter_
->
AllocateTensors
()
==
kTfLiteOk
)
<<
"Cannot allocate tensors"
;
return
SingleOpModel
::
Invoke
Unchecked
();
return
SingleOpModel
::
Invoke
();
}
template
<
typename
T
>
...
...
@@ -335,6 +338,32 @@ TEST(SequenceStringProjectionTest, NormalizeRepetition) {
EXPECT_THAT
(
output1
,
ElementsAreArray
(
output2
));
}
TEST
(
SequenceStringProjectionTest
,
NormalizeSpaces
)
{
SequenceStringProjectionModel
model_nonormalize
(
false
,
-
1
,
0
,
0
,
false
,
TensorType_UINT8
,
""
,
false
,
0.0
,
0.0
,
kMurmurHash
,
false
);
SequenceStringProjectionModel
model_normalize
(
false
,
-
1
,
0
,
0
,
false
,
TensorType_UINT8
,
""
,
false
,
0.0
,
0.0
,
kMurmurHash
,
true
);
const
char
kNoExtraSpaces
[]
=
"Hello there."
;
const
char
kExtraSpaces
[]
=
" Hello there. "
;
model_nonormalize
.
Invoke
(
kNoExtraSpaces
);
auto
output_noextra_nonorm
=
model_nonormalize
.
GetOutput
<
uint8_t
>
();
model_nonormalize
.
Invoke
(
kExtraSpaces
);
auto
output_extra_nonorm
=
model_nonormalize
.
GetOutput
<
uint8_t
>
();
model_normalize
.
Invoke
(
kNoExtraSpaces
);
auto
output_noextra_norm
=
model_normalize
.
GetOutput
<
uint8_t
>
();
model_normalize
.
Invoke
(
kExtraSpaces
);
auto
output_extra_norm
=
model_normalize
.
GetOutput
<
uint8_t
>
();
EXPECT_THAT
(
output_noextra_nonorm
,
ElementsAreArray
(
output_noextra_norm
));
EXPECT_THAT
(
output_noextra_nonorm
,
ElementsAreArray
(
output_extra_norm
));
EXPECT_THAT
(
output_noextra_nonorm
,
Not
(
ElementsAreArray
(
output_extra_nonorm
)));
}
class
SequenceStringProjectionTest
:
public
TensorflowTfLiteOpTest
{
std
::
function
<
TfLiteRegistration
*
()
>
TfLiteOpRegistration
()
override
{
return
ops
::
custom
::
Register_SEQUENCE_STRING_PROJECTION
;
...
...
@@ -710,6 +739,7 @@ std::vector<OpEquivTestCase> SequenceStringProjectionTestCases() {
test_case
.
output_tensors
.
emplace_back
(
FloatTensor
({},
{}),
kScale
,
kZero
);
test_cases
.
push_back
(
test_case
);
}
{
OpEquivTestCase
test_case
;
test_case
.
test_name
=
"NormalizeRepetition"
;
...
...
@@ -794,6 +824,20 @@ std::vector<OpEquivTestCase> SequenceStringProjectionTestCases() {
test_cases
.
push_back
(
test_case
);
}
{
OpEquivTestCase
test_case
;
test_case
.
test_name
=
"NormalizeSpaces"
;
test_case
.
attributes
[
"vocabulary"
]
=
AttrValue
(
""
);
test_case
.
attributes
[
"split_on_space"
]
=
AttrValue
(
true
);
test_case
.
attributes
[
"feature_size"
]
=
AttrValue
(
8
);
test_case
.
attributes
[
"add_eos_tag"
]
=
AttrValue
(
false
);
test_case
.
attributes
[
"add_bos_tag"
]
=
AttrValue
(
false
);
test_case
.
attributes
[
"normalize_spaces"
]
=
AttrValue
(
true
);
test_case
.
input_tensors
.
push_back
(
StringTensor
({
1
},
{
" Hello there. "
}));
test_case
.
output_tensors
.
emplace_back
(
FloatTensor
({},
{}),
kScale
,
kZero
);
test_cases
.
push_back
(
test_case
);
}
return
test_cases
;
}
...
...
@@ -822,13 +866,13 @@ class SequenceStringProjectionV2Model : public ::tflite::SingleOpModel {
PopulateStringTensor
(
input_
,
input
);
CHECK
(
interpreter_
->
AllocateTensors
()
==
kTfLiteOk
)
<<
"Cannot allocate tensors"
;
ASSERT_EQ
(
SingleOpModel
::
Invoke
Unchecked
(),
expected
);
ASSERT_EQ
(
SingleOpModel
::
Invoke
(),
expected
);
}
TfLiteStatus
InvokeFailable
(
const
std
::
string
&
input
)
{
PopulateStringTensor
(
input_
,
{
input
});
CHECK
(
interpreter_
->
AllocateTensors
()
==
kTfLiteOk
)
<<
"Cannot allocate tensors"
;
return
SingleOpModel
::
Invoke
Unchecked
();
return
SingleOpModel
::
Invoke
();
}
std
::
vector
<
int
>
GetOutputShape
()
{
return
GetTensorShape
(
output_
);
}
...
...
research/seq_flow_lite/tflite_ops/tf_tflite_diff_test_util.cc
View file @
20cc2190
...
...
@@ -309,7 +309,7 @@ void TensorflowTfLiteOpTest::RunTfLiteOp() {
input_index
++
;
}
tflite_op_
.
Invoke
();
ASSERT_EQ
(
tflite_op_
.
Invoke
()
,
kTfLiteOk
)
;
}
void
TensorflowTfLiteOpTest
::
CompareOpOutput
()
{
...
...
research/seq_flow_lite/tflite_ops/tf_tflite_diff_test_util.h
View file @
20cc2190
...
...
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
// Tests equivalence between TF and TFLite versions of an op.
#ifndef TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
#define TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
#ifndef TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
#define TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
#include <string>
#include <vector>
...
...
@@ -146,4 +146,4 @@ class TensorflowTfLiteOpTest
}
// namespace testing
}
// namespace seq_flow_lite
#endif // TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
#endif // TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
research/seq_flow_lite/tflite_ops/tflite_decoder_cache.h
View file @
20cc2190
...
...
@@ -13,12 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef
THIRD_PARTY_
TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
#define
THIRD_PARTY_
TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
#include <memory>
#include "third_party/tensorflow/lite/c/common.h"
#include "tensorflow/lite/c/common.h"
namespace
seq_flow_lite
{
namespace
ops
{
namespace
custom
{
...
...
@@ -113,4 +114,4 @@ class DynamicCacheOp {
}
// namespace custom
}
// namespace ops
}
// namespace seq_flow_lite
#endif //
THIRD_PARTY_
TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
research/seq_flow_lite/tflite_ops/tflite_decoder_handler.cc
View file @
20cc2190
...
...
@@ -13,15 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "
third_party/tensorflow_models/seq_flow_lite/
tflite_ops/tflite_decoder_handler.h"
#include "tflite_ops/tflite_decoder_handler.h"
// seq_flow_lite
#include <cstdint>
#include "third_party/flatbuffers/include/flatbuffers/flexbuffers.h"
#include "third_party/tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "third_party/tensorflow/lite/kernels/kernel_util.h"
#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/tflite_decoder_cache.h"
#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h"
#include "flatbuffers/flexbuffers.h" // flatbuffer
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tflite_ops/quantization_util.h" // seq_flow_lite
#include "tflite_ops/tflite_decoder_cache.h" // seq_flow_lite
namespace
seq_flow_lite
{
namespace
ops
{
...
...
research/seq_flow_lite/tflite_ops/tflite_decoder_handler.h
View file @
20cc2190
...
...
@@ -13,18 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef
THIRD_PARTY_
TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
#define
THIRD_PARTY_
TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
#include "
third_party/
tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/register.h"
namespace
seq_flow_lite
{
namespace
ops
{
namespace
custom
{
TfLiteRegistration
*
Register_UNIFORM_CAUSAL_ATTENTION
();
}
}
// namespace custom
}
// namespace ops
}
// namespace seq_flow_lite
#endif //
THIRD_PARTY_
TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
research/seq_flow_lite/tflite_ops/tflite_decoder_handler_test.cc
View file @
20cc2190
...
...
@@ -13,17 +13,17 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "
third_party/tensorflow_models/seq_flow_lite/
tflite_ops/tflite_decoder_handler.h"
#include "tflite_ops/tflite_decoder_handler.h"
// seq_flow_lite
#include <cstdint>
#include <cstdlib>
#include <vector>
#include
"testing/base/public
/gmock.h
"
#include
"
test
ing/base/public/guni
t.h
"
#include "
third_party/flatbuffers/include/flatbuffers/flex
buffer
s.h"
#include "
third_party/
tensorflow/lite/c/common.h"
#include "
third_party/
tensorflow/lite/kernels/test_util.h"
#include
<gmock
/gmock.h
>
#include
<g
test
/gtes
t.h
>
#include "
flatbuffers/flexbuffers.h" // flat
buffer
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/test_util.h"
namespace
{
...
...
research/seq_flow_lite/tflite_ops/tflite_qrnn_pooling.cc
View file @
20cc2190
...
...
@@ -12,11 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/tflite_qrnn_pooling.h"
#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h"
#include "tflite_ops/quantization_util.h" // seq_flow_lite
#include "tflite_ops/tflite_qrnn_pooling.h" // seq_flow_lite
namespace
seq_flow_lite
{
namespace
ops
{
namespace
custom
{
namespace
{
...
...
@@ -126,9 +127,9 @@ TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node) {
return
QRNNPooling
(
context
,
multiplier
,
constant
,
outputs
,
final_state
,
(
direction
->
data
.
uint8
[
0
]
==
kPoolingForward
));
}
}
// namespace
namespace
custom
{
const
char
kPoolingOp
[]
=
"PoolingOp"
;
void
RegisterQRNNPooling
(
::
tflite
::
ops
::
builtin
::
BuiltinOpResolver
*
resolver
)
{
...
...
@@ -141,4 +142,5 @@ TfLiteRegistration* Register_QRNN_POOLING() {
}
}
// namespace custom
}
// namespace ops
}
// namespace seq_flow_lite
research/seq_flow_lite/tflite_ops/tflite_qrnn_pooling.h
View file @
20cc2190
...
...
@@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
#define TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
#ifndef TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
#define TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
#include "
third_party/
absl/base/macros.h"
#include "
third_party/
tensorflow/lite/kernels/register.h"
#include "absl/base/macros.h"
#include "tensorflow/lite/kernels/register.h"
namespace
seq_flow_lite
{
namespace
ops
{
namespace
custom
{
extern
const
char
kPoolingOp
[];
...
...
@@ -27,7 +27,7 @@ extern const char kPoolingOp[];
TfLiteRegistration
*
Register_QRNN_POOLING
();
}
// namespace custom
}
// namespace ops
}
// namespace seq_flow_lite
#endif // TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
#endif // TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
research/seq_flow_lite/trainer.py
View file @
20cc2190
...
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""A utility for PRADO model to do train, eval, inference and model export."""
import
importlib
...
...
@@ -22,6 +21,7 @@ from absl import app
from
absl
import
flags
from
absl
import
logging
import
tensorflow.compat.v1
as
tf
from
tensorflow.compat.v1
import
estimator
as
tf_estimator
import
input_fn_reader
# import root module
import
metric_functions
# import root module
...
...
@@ -48,14 +48,14 @@ def load_runner_config():
return
json
.
loads
(
f
.
read
())
def
create_model
(
model
,
model_config
,
features
,
mode
):
def
create_model
(
model
,
model_config
,
features
,
mode
,
model_name
):
"""Creates a sequence labeling model."""
keras_model
=
model
.
Encoder
(
model_config
,
mode
)
if
"pqrnn"
in
model_name
:
logits
=
keras_model
(
features
[
"projection"
],
features
[
"seq_length"
])
else
:
logits
=
keras_model
(
features
[
"token_ids"
],
features
[
"token_len"
])
if
mode
!=
tf
.
estimator
.
ModeKeys
.
PREDICT
:
if
mode
!=
tf
_
estimator
.
ModeKeys
.
PREDICT
:
if
not
model_config
[
"multilabel"
]:
loss
=
tf
.
nn
.
sparse_softmax_cross_entropy_with_logits
(
labels
=
features
[
"label"
],
logits
=
logits
)
...
...
@@ -94,33 +94,33 @@ def model_fn_builder(runner_config):
def
model_fn
(
features
,
mode
,
params
):
"""The `model_fn` for TPUEstimator."""
label_ids
=
None
if
mode
!=
tf
.
estimator
.
ModeKeys
.
PREDICT
:
if
mode
!=
tf
_
estimator
.
ModeKeys
.
PREDICT
:
label_ids
=
features
[
"label"
]
model_config
=
runner_config
[
"model_config"
]
loss
,
logits
=
create_model
(
model
,
model_config
,
features
,
mode
)
loss
,
logits
=
create_model
(
model
,
model_config
,
features
,
mode
,
runner_config
[
"name"
])
if
mode
==
tf
.
estimator
.
ModeKeys
.
TRAIN
:
if
mode
==
tf
_
estimator
.
ModeKeys
.
TRAIN
:
train_op
=
create_optimizer
(
loss
,
runner_config
,
params
)
return
tf
.
compat
.
v1
.
estimator
.
tpu
.
TPUEstimatorSpec
(
return
tf
_
estimator
.
tpu
.
TPUEstimatorSpec
(
mode
=
mode
,
loss
=
loss
,
train_op
=
train_op
)
elif
mode
==
tf
.
estimator
.
ModeKeys
.
EVAL
:
elif
mode
==
tf
_
estimator
.
ModeKeys
.
EVAL
:
if
not
runner_config
[
"model_config"
][
"multilabel"
]:
metric_fn
=
metric_functions
.
classification_metric
else
:
metric_fn
=
metric_functions
.
labeling_metric
eval_metrics
=
(
metric_fn
,
[
loss
,
label_ids
,
logits
])
return
tf
.
compat
.
v1
.
estimator
.
tpu
.
TPUEstimatorSpec
(
return
tf
_
estimator
.
tpu
.
TPUEstimatorSpec
(
mode
=
mode
,
loss
=
loss
,
eval_metrics
=
eval_metrics
)
elif
mode
==
tf
.
estimator
.
ModeKeys
.
PREDICT
:
elif
mode
==
tf
_
estimator
.
ModeKeys
.
PREDICT
:
predictions
=
{
"logits"
:
logits
}
if
not
runner_config
[
"model_config"
][
"multilabel"
]:
predictions
[
"predictions"
]
=
tf
.
nn
.
softmax
(
logits
)
else
:
predictions
[
"predictions"
]
=
tf
.
math
.
sigmoid
(
logits
)
return
tf
.
compat
.
v1
.
estimator
.
EstimatorSpec
(
mode
=
mode
,
predictions
=
predictions
)
return
tf_estimator
.
EstimatorSpec
(
mode
=
mode
,
predictions
=
predictions
)
else
:
assert
False
,
"Expected to be called in TRAIN, EVAL, or PREDICT mode."
...
...
@@ -133,13 +133,13 @@ def main(_):
if
FLAGS
.
output_dir
:
tf
.
gfile
.
MakeDirs
(
FLAGS
.
output_dir
)
is_per_host
=
tf
.
estimator
.
tpu
.
InputPipelineConfig
.
PER_HOST_V2
run_config
=
tf
.
estimator
.
tpu
.
RunConfig
(
is_per_host
=
tf
_
estimator
.
tpu
.
InputPipelineConfig
.
PER_HOST_V2
run_config
=
tf
_
estimator
.
tpu
.
RunConfig
(
master
=
FLAGS
.
master
,
model_dir
=
FLAGS
.
output_dir
,
save_checkpoints_steps
=
runner_config
[
"save_checkpoints_steps"
],
keep_checkpoint_max
=
20
,
tpu_config
=
tf
.
estimator
.
tpu
.
TPUConfig
(
tpu_config
=
tf
_
estimator
.
tpu
.
TPUConfig
(
iterations_per_loop
=
runner_config
[
"iterations_per_loop"
],
num_shards
=
FLAGS
.
num_tpu_cores
,
per_host_input_for_training
=
is_per_host
))
...
...
@@ -149,7 +149,7 @@ def main(_):
# If TPU is not available, this will fall back to normal Estimator on CPU
# or GPU.
batch_size
=
runner_config
[
"batch_size"
]
estimator
=
tf
.
estimator
.
tpu
.
TPUEstimator
(
estimator
=
tf
_
estimator
.
tpu
.
TPUEstimator
(
use_tpu
=
FLAGS
.
use_tpu
,
model_fn
=
model_fn
,
config
=
run_config
,
...
...
@@ -160,7 +160,7 @@ def main(_):
if
FLAGS
.
runner_mode
==
"train"
:
train_input_fn
=
input_fn_reader
.
create_input_fn
(
runner_config
=
runner_config
,
mode
=
tf
.
estimator
.
ModeKeys
.
TRAIN
,
mode
=
tf
_
estimator
.
ModeKeys
.
TRAIN
,
drop_remainder
=
True
)
estimator
.
train
(
input_fn
=
train_input_fn
,
max_steps
=
runner_config
[
"train_steps"
])
...
...
@@ -168,7 +168,7 @@ def main(_):
# TPU needs fixed shapes, so if the last batch is smaller, we drop it.
eval_input_fn
=
input_fn_reader
.
create_input_fn
(
runner_config
=
runner_config
,
mode
=
tf
.
estimator
.
ModeKeys
.
EVAL
,
mode
=
tf
_
estimator
.
ModeKeys
.
EVAL
,
drop_remainder
=
True
)
for
_
in
tf
.
train
.
checkpoints_iterator
(
FLAGS
.
output_dir
,
timeout
=
600
):
...
...
research/seq_flow_lite/trainer_v2.py
View file @
20cc2190
...
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Binary to train PRADO model with TF 2.0."""
import
importlib
...
...
@@ -23,6 +22,7 @@ from absl import flags
from
absl
import
logging
import
tensorflow
as
tf
from
tensorflow
import
estimator
as
tf_estimator
import
input_fn_reader
# import root module
...
...
@@ -48,7 +48,7 @@ def load_runner_config():
def
compute_loss
(
logits
,
labels
,
model_config
,
mode
):
"""Creates a sequence labeling model."""
if
mode
!=
tf
.
estimator
.
ModeKeys
.
PREDICT
:
if
mode
!=
tf
_
estimator
.
ModeKeys
.
PREDICT
:
if
not
model_config
[
"multilabel"
]:
loss
=
tf
.
nn
.
sparse_softmax_cross_entropy_with_logits
(
labels
=
labels
,
logits
=
logits
)
...
...
@@ -77,11 +77,11 @@ def main(_):
if
FLAGS
.
output_dir
:
tf
.
io
.
gfile
.
makedirs
(
FLAGS
.
output_dir
)
train_model
=
model_fn_builder
(
runner_config
,
tf
.
estimator
.
ModeKeys
.
TRAIN
)
train_model
=
model_fn_builder
(
runner_config
,
tf
_
estimator
.
ModeKeys
.
TRAIN
)
optimizer
=
tf
.
keras
.
optimizers
.
Adam
()
train_input_fn
=
input_fn_reader
.
create_input_fn
(
runner_config
=
runner_config
,
mode
=
tf
.
estimator
.
ModeKeys
.
TRAIN
,
mode
=
tf
_
estimator
.
ModeKeys
.
TRAIN
,
drop_remainder
=
True
)
params
=
{
"batch_size"
:
runner_config
[
"batch_size"
]}
train_ds
=
train_input_fn
(
params
)
...
...
@@ -93,7 +93,7 @@ def main(_):
logits
=
train_model
(
features
[
"projection"
],
features
[
"seq_length"
])
loss
=
compute_loss
(
logits
,
features
[
"label"
],
runner_config
[
"model_config"
],
tf
.
estimator
.
ModeKeys
.
TRAIN
)
tf
_
estimator
.
ModeKeys
.
TRAIN
)
gradients
=
tape
.
gradient
(
loss
,
train_model
.
trainable_variables
)
optimizer
.
apply_gradients
(
zip
(
gradients
,
train_model
.
trainable_variables
))
train_loss
(
loss
)
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment