Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
20cc2190
Unverified
Commit
20cc2190
authored
Aug 24, 2022
by
pyoung2778
Committed by
GitHub
Aug 24, 2022
Browse files
Check in seq_flow_lite (#10750)
parent
fdecf385
Changes
62
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
166 additions
and
112 deletions
+166
-112
research/seq_flow_lite/tflite_ops/beam_search.cc
research/seq_flow_lite/tflite_ops/beam_search.cc
+8
-7
research/seq_flow_lite/tflite_ops/beam_search.h
research/seq_flow_lite/tflite_ops/beam_search.h
+4
-4
research/seq_flow_lite/tflite_ops/beam_search_test.cc
research/seq_flow_lite/tflite_ops/beam_search_test.cc
+13
-13
research/seq_flow_lite/tflite_ops/expected_value.h
research/seq_flow_lite/tflite_ops/expected_value.h
+3
-3
research/seq_flow_lite/tflite_ops/layer_norm.h
research/seq_flow_lite/tflite_ops/layer_norm.h
+3
-3
research/seq_flow_lite/tflite_ops/layer_norm_test.cc
research/seq_flow_lite/tflite_ops/layer_norm_test.cc
+7
-7
research/seq_flow_lite/tflite_ops/quantization_util.h
research/seq_flow_lite/tflite_ops/quantization_util.h
+3
-3
research/seq_flow_lite/tflite_ops/sequence_string_projection.cc
...ch/seq_flow_lite/tflite_ops/sequence_string_projection.cc
+8
-5
research/seq_flow_lite/tflite_ops/sequence_string_projection.h
...rch/seq_flow_lite/tflite_ops/sequence_string_projection.h
+4
-3
research/seq_flow_lite/tflite_ops/sequence_string_projection_test.cc
...q_flow_lite/tflite_ops/sequence_string_projection_test.cc
+49
-5
research/seq_flow_lite/tflite_ops/tf_tflite_diff_test_util.cc
...arch/seq_flow_lite/tflite_ops/tf_tflite_diff_test_util.cc
+1
-1
research/seq_flow_lite/tflite_ops/tf_tflite_diff_test_util.h
research/seq_flow_lite/tflite_ops/tf_tflite_diff_test_util.h
+3
-3
research/seq_flow_lite/tflite_ops/tflite_decoder_cache.h
research/seq_flow_lite/tflite_ops/tflite_decoder_cache.h
+5
-4
research/seq_flow_lite/tflite_ops/tflite_decoder_handler.cc
research/seq_flow_lite/tflite_ops/tflite_decoder_handler.cc
+7
-6
research/seq_flow_lite/tflite_ops/tflite_decoder_handler.h
research/seq_flow_lite/tflite_ops/tflite_decoder_handler.h
+6
-5
research/seq_flow_lite/tflite_ops/tflite_decoder_handler_test.cc
...h/seq_flow_lite/tflite_ops/tflite_decoder_handler_test.cc
+6
-6
research/seq_flow_lite/tflite_ops/tflite_qrnn_pooling.cc
research/seq_flow_lite/tflite_ops/tflite_qrnn_pooling.cc
+6
-4
research/seq_flow_lite/tflite_ops/tflite_qrnn_pooling.h
research/seq_flow_lite/tflite_ops/tflite_qrnn_pooling.h
+7
-7
research/seq_flow_lite/trainer.py
research/seq_flow_lite/trainer.py
+18
-18
research/seq_flow_lite/trainer_v2.py
research/seq_flow_lite/trainer_v2.py
+5
-5
No files found.
research/seq_flow_lite/tflite_ops/beam_search.cc
View file @
20cc2190
...
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
...
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
==============================================================================*/
==============================================================================*/
#include "
third_party/tensorflow_models/seq_flow_lite/
tflite_ops/beam_search.h"
#include "tflite_ops/beam_search.h"
// seq_flow_lite
#include <algorithm>
#include <algorithm>
#include <cstdint>
#include <cstdint>
...
@@ -21,10 +21,10 @@ limitations under the License.
...
@@ -21,10 +21,10 @@ limitations under the License.
#include <vector>
#include <vector>
#include "base/logging.h"
#include "base/logging.h"
#include "
third_party/
absl/strings/str_join.h"
#include "absl/strings/str_join.h"
#include "
third_party/
tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "
third_party/
tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "
third_party/tensorflow_models/seq_flow_lite/
tflite_ops/quantization_util.h"
#include "tflite_ops/quantization_util.h"
// seq_flow_lite
namespace
seq_flow_lite
{
namespace
seq_flow_lite
{
namespace
ops
{
namespace
ops
{
...
@@ -86,6 +86,7 @@ void SequenceTracker::AddSequence(const int32_t *begin, const int32_t *end,
...
@@ -86,6 +86,7 @@ void SequenceTracker::AddSequence(const int32_t *begin, const int32_t *end,
std
::
vector
<
std
::
vector
<
int32_t
>>
SequenceTracker
::
GetTopBeams
()
{
std
::
vector
<
std
::
vector
<
int32_t
>>
SequenceTracker
::
GetTopBeams
()
{
std
::
vector
<
std
::
vector
<
int32_t
>>
return_value
;
std
::
vector
<
std
::
vector
<
int32_t
>>
return_value
;
return_value
.
reserve
(
terminated_topk_
.
size
());
for
(
const
auto
&
v
:
terminated_topk_
)
{
for
(
const
auto
&
v
:
terminated_topk_
)
{
return_value
.
push_back
(
v
.
second
);
return_value
.
push_back
(
v
.
second
);
}
}
...
@@ -255,8 +256,8 @@ void BeamSearch::FindTopKQuantizedFromLogitsV1(const TfLiteTensor &tensor,
...
@@ -255,8 +256,8 @@ void BeamSearch::FindTopKQuantizedFromLogitsV1(const TfLiteTensor &tensor,
}
}
}
}
// Updating topk across all beams.
// Updating topk across all beams.
for
(
uint32_t
k
=
0
;
k
<
std
::
min
(
topk_k
,
num_classes_
);
++
k
)
{
for
(
uint32_t
curr_beam
:
curr_beam_top
k
)
{
const
uint32_t
curr_beam_index
=
curr_beam
_topk
[
k
]
&
kClassIndexMask
;
const
uint32_t
curr_beam_index
=
curr_beam
&
kClassIndexMask
;
const
uint32_t
index
=
j
*
num_classes_
+
curr_beam_index
;
const
uint32_t
index
=
j
*
num_classes_
+
curr_beam_index
;
const
float
log_prob
=
const
float
log_prob
=
tensor
.
params
.
scale
*
beam_logits
[
curr_beam_index
]
-
precomputed
;
tensor
.
params
.
scale
*
beam_logits
[
curr_beam_index
]
-
precomputed
;
...
...
research/seq_flow_lite/tflite_ops/beam_search.h
View file @
20cc2190
...
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
...
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
==============================================================================*/
==============================================================================*/
#ifndef
THIRD_PARTY_
TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
#define
THIRD_PARTY_
TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
#include <cstdint>
#include <cstdint>
#include <functional>
#include <functional>
...
@@ -23,7 +23,7 @@ limitations under the License.
...
@@ -23,7 +23,7 @@ limitations under the License.
#include <set>
#include <set>
#include <vector>
#include <vector>
#include "
third_party/
tensorflow/lite/c/common.h"
#include "tensorflow/lite/c/common.h"
namespace
seq_flow_lite
{
namespace
seq_flow_lite
{
namespace
ops
{
namespace
ops
{
...
@@ -110,4 +110,4 @@ class BeamSearch {
...
@@ -110,4 +110,4 @@ class BeamSearch {
}
// namespace custom
}
// namespace custom
}
// namespace ops
}
// namespace ops
}
// namespace seq_flow_lite
}
// namespace seq_flow_lite
#endif //
THIRD_PARTY_
TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
research/seq_flow_lite/tflite_ops/beam_search_test.cc
View file @
20cc2190
...
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
...
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
==============================================================================*/
==============================================================================*/
#include "
third_party/tensorflow_models/seq_flow_lite/
tflite_ops/beam_search.h"
#include "tflite_ops/beam_search.h"
// seq_flow_lite
#include <cstdint>
#include <cstdint>
#include <functional>
#include <functional>
...
@@ -21,17 +21,17 @@ limitations under the License.
...
@@ -21,17 +21,17 @@ limitations under the License.
#include <memory>
#include <memory>
#include <vector>
#include <vector>
#include
"testing/base/public
/gmock.h
"
#include
<gmock
/gmock.h
>
#include
"
test
ing/base/public/guni
t.h
"
#include
<g
test
/gtes
t.h
>
#include "
third_party/
absl/strings/str_join.h"
#include "absl/strings/str_join.h"
#include "
third_party/
tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/c/c_api_types.h"
#include "
third_party/
tensorflow/lite/c/common.h"
#include "tensorflow/lite/c/common.h"
#include "
third_party/
tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "
third_party/
tensorflow/lite/kernels/internal/reference/dequantize.h"
#include "tensorflow/lite/kernels/internal/reference/dequantize.h"
#include "
third_party/
tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "
third_party/
tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "
third_party/
tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "
third_party/tensorflow_models/seq_flow_lite/
tflite_ops/quantization_util.h"
#include "tflite_ops/quantization_util.h"
// seq_flow_lite
namespace
seq_flow_lite
{
namespace
seq_flow_lite
{
namespace
ops
{
namespace
ops
{
...
@@ -76,7 +76,7 @@ class BeamSearchImpl : public BeamSearch {
...
@@ -76,7 +76,7 @@ class BeamSearchImpl : public BeamSearch {
cur_cache
+
(
selected_beams
[
beam
]
*
NumClasses
());
cur_cache
+
(
selected_beams
[
beam
]
*
NumClasses
());
for
(
int
j
=
0
;
j
<
NumClasses
();
++
j
,
index
++
)
{
for
(
int
j
=
0
;
j
<
NumClasses
();
++
j
,
index
++
)
{
next_cache
[
index
]
=
(
selected
[
j
]
+
next_cache
[
index
])
/
2
;
next_cache
[
index
]
=
(
selected
[
j
]
+
next_cache
[
index
])
/
2
;
data_ptr
[
index
]
=
::
seq_flow_lite
::
PodQuantize
(
data_ptr
[
index
]
=
PodQuantize
(
next_cache
[
index
],
decoder_output_
->
params
.
zero_point
,
next_cache
[
index
],
decoder_output_
->
params
.
zero_point
,
1.0
f
/
decoder_output_
->
params
.
scale
);
1.0
f
/
decoder_output_
->
params
.
scale
);
}
}
...
...
research/seq_flow_lite/tflite_ops/expected_value.h
View file @
20cc2190
...
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
==============================================================================*/
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_EXPECTED_VALUE_H_
#ifndef TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_EXPECTED_VALUE_H_
#define TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_EXPECTED_VALUE_H_
#define TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_EXPECTED_VALUE_H_
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/register.h"
...
@@ -27,4 +27,4 @@ TfLiteRegistration* Register_EXPECTED_VALUE();
...
@@ -27,4 +27,4 @@ TfLiteRegistration* Register_EXPECTED_VALUE();
}
// namespace ops
}
// namespace ops
}
// namespace seq_flow_lite
}
// namespace seq_flow_lite
#endif // TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_EXPECTED_VALUE_H_
#endif // TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_EXPECTED_VALUE_H_
research/seq_flow_lite/tflite_ops/layer_norm.h
View file @
20cc2190
...
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
==============================================================================*/
==============================================================================*/
#ifndef
LEARNING_EXPANDER_POD_DEEP_POD
_TFLITE_
HANDLER
S_LAYER_NORM_H_
#ifndef
TENSORFLOW_MODELS_SEQ_FLOW_LITE
_TFLITE_
OP
S_LAYER_NORM_H_
#define
LEARNING_EXPANDER_POD_DEEP_POD
_TFLITE_
HANDLER
S_LAYER_NORM_H_
#define
TENSORFLOW_MODELS_SEQ_FLOW_LITE
_TFLITE_
OP
S_LAYER_NORM_H_
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/register.h"
...
@@ -27,4 +27,4 @@ TfLiteRegistration* Register_LAYER_NORM();
...
@@ -27,4 +27,4 @@ TfLiteRegistration* Register_LAYER_NORM();
}
// namespace ops
}
// namespace ops
}
// namespace seq_flow_lite
}
// namespace seq_flow_lite
#endif //
LEARNING_EXPANDER_POD_DEEP_POD
_TFLITE_
HANDLER
S_LAYER_NORM_H_
#endif //
TENSORFLOW_MODELS_SEQ_FLOW_LITE
_TFLITE_
OP
S_LAYER_NORM_H_
research/seq_flow_lite/tflite_ops/layer_norm_test.cc
View file @
20cc2190
...
@@ -87,7 +87,7 @@ TEST(LayerNormModelTest, RegularInput) {
...
@@ -87,7 +87,7 @@ TEST(LayerNormModelTest, RegularInput) {
/*input_max=*/
10
,
/*output_min=*/
-
10
,
/*output_max=*/
10
,
/*input_max=*/
10
,
/*output_min=*/
-
10
,
/*output_max=*/
10
,
/*scale=*/
1.0
,
/*offset=*/
0.0
,
/*axes=*/
{
2
});
/*scale=*/
1.0
,
/*offset=*/
0.0
,
/*axes=*/
{
2
});
m
.
SetInput
(
input
);
m
.
SetInput
(
input
);
m
.
Invoke
(
);
ASSERT_EQ
(
m
.
Invoke
(),
kTfLiteOk
);
EXPECT_THAT
(
EXPECT_THAT
(
m
.
GetDequantizedOutput
(),
m
.
GetDequantizedOutput
(),
ElementsAreArray
(
ArrayFloatNear
(
expected_output
,
kQuantizedTolerance
)));
ElementsAreArray
(
ArrayFloatNear
(
expected_output
,
kQuantizedTolerance
)));
...
@@ -106,7 +106,7 @@ TEST(LayerNormModelTest, NegativeScale) {
...
@@ -106,7 +106,7 @@ TEST(LayerNormModelTest, NegativeScale) {
/*input_max=*/
10
,
/*output_min=*/
-
10
,
/*output_max=*/
10
,
/*input_max=*/
10
,
/*output_min=*/
-
10
,
/*output_max=*/
10
,
/*scale=*/
-
1.0
,
/*offset=*/
0.0
,
/*axes=*/
{
2
});
/*scale=*/
-
1.0
,
/*offset=*/
0.0
,
/*axes=*/
{
2
});
m
.
SetInput
(
input
);
m
.
SetInput
(
input
);
m
.
Invoke
(
);
ASSERT_EQ
(
m
.
Invoke
(),
kTfLiteOk
);
EXPECT_THAT
(
EXPECT_THAT
(
m
.
GetDequantizedOutput
(),
m
.
GetDequantizedOutput
(),
ElementsAreArray
(
ArrayFloatNear
(
expected_output
,
kQuantizedTolerance
)));
ElementsAreArray
(
ArrayFloatNear
(
expected_output
,
kQuantizedTolerance
)));
...
@@ -125,7 +125,7 @@ TEST(LayerNormModelTest, NegativeOffset) {
...
@@ -125,7 +125,7 @@ TEST(LayerNormModelTest, NegativeOffset) {
/*input_max=*/
10
,
/*output_min=*/
-
10
,
/*output_max=*/
10
,
/*input_max=*/
10
,
/*output_min=*/
-
10
,
/*output_max=*/
10
,
/*scale=*/
1.0
,
/*offset=*/
-
1.0
,
/*axes=*/
{
2
});
/*scale=*/
1.0
,
/*offset=*/
-
1.0
,
/*axes=*/
{
2
});
m
.
SetInput
(
input
);
m
.
SetInput
(
input
);
m
.
Invoke
(
);
ASSERT_EQ
(
m
.
Invoke
(),
kTfLiteOk
);
EXPECT_THAT
(
EXPECT_THAT
(
m
.
GetDequantizedOutput
(),
m
.
GetDequantizedOutput
(),
ElementsAreArray
(
ArrayFloatNear
(
expected_output
,
kQuantizedTolerance
)));
ElementsAreArray
(
ArrayFloatNear
(
expected_output
,
kQuantizedTolerance
)));
...
@@ -144,7 +144,7 @@ TEST(LayerNormModelTest, NegativeScaleAndOffset) {
...
@@ -144,7 +144,7 @@ TEST(LayerNormModelTest, NegativeScaleAndOffset) {
/*input_max=*/
10
,
/*output_min=*/
-
10
,
/*output_max=*/
10
,
/*input_max=*/
10
,
/*output_min=*/
-
10
,
/*output_max=*/
10
,
/*scale=*/
-
1.0
,
/*offset=*/
-
1.0
,
/*axes=*/
{
2
});
/*scale=*/
-
1.0
,
/*offset=*/
-
1.0
,
/*axes=*/
{
2
});
m
.
SetInput
(
input
);
m
.
SetInput
(
input
);
m
.
Invoke
(
);
ASSERT_EQ
(
m
.
Invoke
(),
kTfLiteOk
);
EXPECT_THAT
(
EXPECT_THAT
(
m
.
GetDequantizedOutput
(),
m
.
GetDequantizedOutput
(),
ElementsAreArray
(
ArrayFloatNear
(
expected_output
,
kQuantizedTolerance
)));
ElementsAreArray
(
ArrayFloatNear
(
expected_output
,
kQuantizedTolerance
)));
...
@@ -163,7 +163,7 @@ TEST(LayerNormModelTest, MultipleAxis) {
...
@@ -163,7 +163,7 @@ TEST(LayerNormModelTest, MultipleAxis) {
/*input_max=*/
3
,
/*output_min=*/
-
3
,
/*output_max=*/
3
,
/*input_max=*/
3
,
/*output_min=*/
-
3
,
/*output_max=*/
3
,
/*scale=*/
1.0
,
/*offset=*/
0.0
,
/*axes=*/
{
1
,
3
});
/*scale=*/
1.0
,
/*offset=*/
0.0
,
/*axes=*/
{
1
,
3
});
m
.
SetInput
(
input
);
m
.
SetInput
(
input
);
m
.
Invoke
(
);
ASSERT_EQ
(
m
.
Invoke
(),
kTfLiteOk
);
EXPECT_THAT
(
EXPECT_THAT
(
m
.
GetDequantizedOutput
(),
m
.
GetDequantizedOutput
(),
ElementsAreArray
(
ArrayFloatNear
(
expected_output
,
kQuantizedTolerance
)));
ElementsAreArray
(
ArrayFloatNear
(
expected_output
,
kQuantizedTolerance
)));
...
@@ -182,7 +182,7 @@ TEST(LayerNormModelTest, MultipleNegativeAxis) {
...
@@ -182,7 +182,7 @@ TEST(LayerNormModelTest, MultipleNegativeAxis) {
/*input_max=*/
3
,
/*output_min=*/
-
3
,
/*output_max=*/
3
,
/*input_max=*/
3
,
/*output_min=*/
-
3
,
/*output_max=*/
3
,
/*scale=*/
1.0
,
/*offset=*/
0.0
,
/*axes=*/
{
-
3
,
-
1
});
/*scale=*/
1.0
,
/*offset=*/
0.0
,
/*axes=*/
{
-
3
,
-
1
});
m
.
SetInput
(
input
);
m
.
SetInput
(
input
);
m
.
Invoke
(
);
ASSERT_EQ
(
m
.
Invoke
(),
kTfLiteOk
);
EXPECT_THAT
(
EXPECT_THAT
(
m
.
GetDequantizedOutput
(),
m
.
GetDequantizedOutput
(),
ElementsAreArray
(
ArrayFloatNear
(
expected_output
,
kQuantizedTolerance
)));
ElementsAreArray
(
ArrayFloatNear
(
expected_output
,
kQuantizedTolerance
)));
...
@@ -204,7 +204,7 @@ TEST(LayerNormModelTest, MultipleAxisWithLargeDepth) {
...
@@ -204,7 +204,7 @@ TEST(LayerNormModelTest, MultipleAxisWithLargeDepth) {
/*input_max=*/
1.0
,
/*output_min=*/
-
3.0
,
/*output_max=*/
3.0
,
/*input_max=*/
1.0
,
/*output_min=*/
-
3.0
,
/*output_max=*/
3.0
,
/*scale=*/
1.0
,
/*offset=*/
0.0
,
/*axes=*/
{
1
,
3
});
/*scale=*/
1.0
,
/*offset=*/
0.0
,
/*axes=*/
{
1
,
3
});
m
.
SetInput
(
input
);
m
.
SetInput
(
input
);
m
.
Invoke
(
);
ASSERT_EQ
(
m
.
Invoke
(),
kTfLiteOk
);
EXPECT_THAT
(
EXPECT_THAT
(
m
.
GetDequantizedOutput
(),
m
.
GetDequantizedOutput
(),
ElementsAreArray
(
ArrayFloatNear
(
expected_output
,
kQuantizedTolerance
)));
ElementsAreArray
(
ArrayFloatNear
(
expected_output
,
kQuantizedTolerance
)));
...
...
research/seq_flow_lite/tflite_ops/quantization_util.h
View file @
20cc2190
...
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
==============================================================================*/
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_QUANTIZATION_UTIL_H_
#ifndef TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_QUANTIZATION_UTIL_H_
#define TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_QUANTIZATION_UTIL_H_
#define TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_QUANTIZATION_UTIL_H_
#include <algorithm>
#include <algorithm>
#include <cmath>
#include <cmath>
...
@@ -50,4 +50,4 @@ inline uint8_t PodQuantize(float value, int32_t zero_point,
...
@@ -50,4 +50,4 @@ inline uint8_t PodQuantize(float value, int32_t zero_point,
}
// namespace seq_flow_lite
}
// namespace seq_flow_lite
#endif // TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_QUANTIZATION_UTIL_H_
#endif // TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_QUANTIZATION_UTIL_H_
research/seq_flow_lite/tflite_ops/sequence_string_projection.cc
View file @
20cc2190
...
@@ -101,7 +101,7 @@ class ProjectionParams {
...
@@ -101,7 +101,7 @@ class ProjectionParams {
bool
exclude_nonalphaspace_unicodes
,
bool
exclude_nonalphaspace_unicodes
,
const
std
::
string
&
token_separators
,
const
std
::
string
&
token_separators
,
bool
normalize_repetition
,
bool
add_first_cap_feature
,
bool
normalize_repetition
,
bool
add_first_cap_feature
,
bool
add_all_caps_feature
)
bool
add_all_caps_feature
,
bool
normalize_spaces
)
:
feature_size_
(
feature_size
),
:
feature_size_
(
feature_size
),
unicode_handler_
(
vocabulary
,
exclude_nonalphaspace_unicodes
),
unicode_handler_
(
vocabulary
,
exclude_nonalphaspace_unicodes
),
hasher_
(
Hasher
::
CreateHasher
(
feature_size
,
hashtype
)),
hasher_
(
Hasher
::
CreateHasher
(
feature_size
,
hashtype
)),
...
@@ -130,9 +130,9 @@ class ProjectionParams {
...
@@ -130,9 +130,9 @@ class ProjectionParams {
}
}
word_novelty_offset_
=
2.0
f
/
(
1
<<
word_novelty_bits_
);
word_novelty_offset_
=
2.0
f
/
(
1
<<
word_novelty_bits_
);
if
(
!
token_separators
.
empty
()
||
normalize_repetition
)
{
if
(
!
token_separators
.
empty
()
||
normalize_repetition
||
normalize_spaces
)
{
projection_normalizer_
=
std
::
make_unique
<
ProjectionNormalizer
>
(
projection_normalizer_
=
std
::
make_unique
<
ProjectionNormalizer
>
(
token_separators
,
normalize_repetition
);
token_separators
,
normalize_repetition
,
normalize_spaces
);
}
}
}
}
virtual
~
ProjectionParams
()
{}
virtual
~
ProjectionParams
()
{}
...
@@ -242,7 +242,8 @@ class ProjectionParamsV2 : public ProjectionParams {
...
@@ -242,7 +242,8 @@ class ProjectionParamsV2 : public ProjectionParams {
/*exclude_nonalphaspace_unicodes = */
false
,
/*exclude_nonalphaspace_unicodes = */
false
,
/*token_separators = */
""
,
normalize_repetition
,
/*token_separators = */
""
,
normalize_repetition
,
/*add_first_cap_feature = */
false
,
/*add_first_cap_feature = */
false
,
/*add_all_caps_feature = */
false
)
{}
/*add_all_caps_feature = */
false
,
/*normalize_spaces = */
false
)
{}
~
ProjectionParamsV2
()
override
{}
~
ProjectionParamsV2
()
override
{}
TfLiteStatus
PreprocessInput
(
TfLiteTensor
*
input_t
,
TfLiteStatus
PreprocessInput
(
TfLiteTensor
*
input_t
,
...
@@ -341,6 +342,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
...
@@ -341,6 +342,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
const
std
::
string
token_separators
=
const
std
::
string
token_separators
=
m
[
"token_separators"
].
IsNull
()
?
""
:
m
[
"token_separators"
].
ToString
();
m
[
"token_separators"
].
IsNull
()
?
""
:
m
[
"token_separators"
].
ToString
();
const
bool
normalize_repetition
=
m
[
"normalize_repetition"
].
AsBool
();
const
bool
normalize_repetition
=
m
[
"normalize_repetition"
].
AsBool
();
const
bool
normalize_spaces
=
m
[
"normalize_spaces"
].
AsBool
();
if
(
!
Hasher
::
SupportedHashType
(
hashtype
))
{
if
(
!
Hasher
::
SupportedHashType
(
hashtype
))
{
context
->
ReportError
(
context
,
"Unsupported hashtype %s
\n
"
,
context
->
ReportError
(
context
,
"Unsupported hashtype %s
\n
"
,
hashtype
.
c_str
());
hashtype
.
c_str
());
...
@@ -354,7 +356,8 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
...
@@ -354,7 +356,8 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
add_bos_tag
?
BosTag
::
kGenerate
:
BosTag
::
kNone
,
add_bos_tag
?
BosTag
::
kGenerate
:
BosTag
::
kNone
,
add_eos_tag
?
EosTag
::
kGenerate
:
EosTag
::
kNone
,
add_eos_tag
?
EosTag
::
kGenerate
:
EosTag
::
kNone
,
exclude_nonalphaspace_unicodes
,
token_separators
,
normalize_repetition
,
exclude_nonalphaspace_unicodes
,
token_separators
,
normalize_repetition
,
add_first_cap_feature
==
1.0
f
,
add_all_caps_feature
==
1.0
f
);
add_first_cap_feature
==
1.0
f
,
add_all_caps_feature
==
1.0
f
,
normalize_spaces
);
}
}
void
*
InitV2
(
TfLiteContext
*
context
,
const
char
*
buffer
,
size_t
length
)
{
void
*
InitV2
(
TfLiteContext
*
context
,
const
char
*
buffer
,
size_t
length
)
{
...
...
research/seq_flow_lite/tflite_ops/sequence_string_projection.h
View file @
20cc2190
...
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
==============================================================================*/
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
#ifndef TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
#define TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
#define TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/register.h"
namespace
seq_flow_lite
{
namespace
seq_flow_lite
{
...
@@ -27,8 +27,9 @@ TfLiteRegistration* Register_SEQUENCE_STRING_PROJECTION();
...
@@ -27,8 +27,9 @@ TfLiteRegistration* Register_SEQUENCE_STRING_PROJECTION();
extern
const
char
kSequenceStringProjectionV2
[];
extern
const
char
kSequenceStringProjectionV2
[];
TfLiteRegistration
*
Register_SEQUENCE_STRING_PROJECTION_V2
();
TfLiteRegistration
*
Register_SEQUENCE_STRING_PROJECTION_V2
();
}
// namespace custom
}
// namespace custom
}
// namespace ops
}
// namespace ops
}
// namespace seq_flow_lite
}
// namespace seq_flow_lite
#endif // TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
#endif // TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
research/seq_flow_lite/tflite_ops/sequence_string_projection_test.cc
View file @
20cc2190
...
@@ -39,6 +39,7 @@ using ::seq_flow_lite::testing::OpEquivTestCase;
...
@@ -39,6 +39,7 @@ using ::seq_flow_lite::testing::OpEquivTestCase;
using
::
seq_flow_lite
::
testing
::
StringTensor
;
using
::
seq_flow_lite
::
testing
::
StringTensor
;
using
::
seq_flow_lite
::
testing
::
TensorflowTfLiteOpTest
;
using
::
seq_flow_lite
::
testing
::
TensorflowTfLiteOpTest
;
using
::
testing
::
ElementsAreArray
;
using
::
testing
::
ElementsAreArray
;
using
::
testing
::
Not
;
using
::
tflite
::
TensorType_FLOAT32
;
using
::
tflite
::
TensorType_FLOAT32
;
using
::
tflite
::
TensorType_STRING
;
using
::
tflite
::
TensorType_STRING
;
using
::
tflite
::
TensorType_UINT8
;
using
::
tflite
::
TensorType_UINT8
;
...
@@ -50,7 +51,8 @@ class SequenceStringProjectionModel : public ::tflite::SingleOpModel {
...
@@ -50,7 +51,8 @@ class SequenceStringProjectionModel : public ::tflite::SingleOpModel {
int
doc_size_levels
,
bool
add_eos_tag
,
::
tflite
::
TensorType
output_type
,
int
doc_size_levels
,
bool
add_eos_tag
,
::
tflite
::
TensorType
output_type
,
const
std
::
string
&
token_separators
=
""
,
const
std
::
string
&
token_separators
=
""
,
bool
normalize_repetition
=
false
,
float
add_first_cap
=
0.0
,
bool
normalize_repetition
=
false
,
float
add_first_cap
=
0.0
,
float
add_all_caps
=
0.0
,
const
std
::
string
&
hashtype
=
kMurmurHash
)
{
float
add_all_caps
=
0.0
,
const
std
::
string
&
hashtype
=
kMurmurHash
,
bool
normalize_spaces
=
false
)
{
flexbuffers
::
Builder
fbb
;
flexbuffers
::
Builder
fbb
;
fbb
.
Map
([
&
]
{
fbb
.
Map
([
&
]
{
fbb
.
Int
(
"feature_size"
,
4
);
fbb
.
Int
(
"feature_size"
,
4
);
...
@@ -65,6 +67,7 @@ class SequenceStringProjectionModel : public ::tflite::SingleOpModel {
...
@@ -65,6 +67,7 @@ class SequenceStringProjectionModel : public ::tflite::SingleOpModel {
fbb
.
Bool
(
"normalize_repetition"
,
normalize_repetition
);
fbb
.
Bool
(
"normalize_repetition"
,
normalize_repetition
);
fbb
.
Float
(
"add_first_cap_feature"
,
add_first_cap
);
fbb
.
Float
(
"add_first_cap_feature"
,
add_first_cap
);
fbb
.
Float
(
"add_all_caps_feature"
,
add_all_caps
);
fbb
.
Float
(
"add_all_caps_feature"
,
add_all_caps
);
fbb
.
Bool
(
"normalize_spaces"
,
normalize_spaces
);
});
});
fbb
.
Finish
();
fbb
.
Finish
();
output_
=
AddOutput
({
output_type
,
{}});
output_
=
AddOutput
({
output_type
,
{}});
...
@@ -76,13 +79,13 @@ class SequenceStringProjectionModel : public ::tflite::SingleOpModel {
...
@@ -76,13 +79,13 @@ class SequenceStringProjectionModel : public ::tflite::SingleOpModel {
PopulateStringTensor
(
input_
,
{
input
});
PopulateStringTensor
(
input_
,
{
input
});
CHECK
(
interpreter_
->
AllocateTensors
()
==
kTfLiteOk
)
CHECK
(
interpreter_
->
AllocateTensors
()
==
kTfLiteOk
)
<<
"Cannot allocate tensors"
;
<<
"Cannot allocate tensors"
;
SingleOpModel
::
Invoke
();
CHECK_EQ
(
SingleOpModel
::
Invoke
()
,
kTfLiteOk
)
;
}
}
TfLiteStatus
InvokeFailable
(
const
std
::
string
&
input
)
{
TfLiteStatus
InvokeFailable
(
const
std
::
string
&
input
)
{
PopulateStringTensor
(
input_
,
{
input
});
PopulateStringTensor
(
input_
,
{
input
});
CHECK
(
interpreter_
->
AllocateTensors
()
==
kTfLiteOk
)
CHECK
(
interpreter_
->
AllocateTensors
()
==
kTfLiteOk
)
<<
"Cannot allocate tensors"
;
<<
"Cannot allocate tensors"
;
return
SingleOpModel
::
Invoke
Unchecked
();
return
SingleOpModel
::
Invoke
();
}
}
template
<
typename
T
>
template
<
typename
T
>
...
@@ -335,6 +338,32 @@ TEST(SequenceStringProjectionTest, NormalizeRepetition) {
...
@@ -335,6 +338,32 @@ TEST(SequenceStringProjectionTest, NormalizeRepetition) {
EXPECT_THAT
(
output1
,
ElementsAreArray
(
output2
));
EXPECT_THAT
(
output1
,
ElementsAreArray
(
output2
));
}
}
TEST
(
SequenceStringProjectionTest
,
NormalizeSpaces
)
{
SequenceStringProjectionModel
model_nonormalize
(
false
,
-
1
,
0
,
0
,
false
,
TensorType_UINT8
,
""
,
false
,
0.0
,
0.0
,
kMurmurHash
,
false
);
SequenceStringProjectionModel
model_normalize
(
false
,
-
1
,
0
,
0
,
false
,
TensorType_UINT8
,
""
,
false
,
0.0
,
0.0
,
kMurmurHash
,
true
);
const
char
kNoExtraSpaces
[]
=
"Hello there."
;
const
char
kExtraSpaces
[]
=
" Hello there. "
;
model_nonormalize
.
Invoke
(
kNoExtraSpaces
);
auto
output_noextra_nonorm
=
model_nonormalize
.
GetOutput
<
uint8_t
>
();
model_nonormalize
.
Invoke
(
kExtraSpaces
);
auto
output_extra_nonorm
=
model_nonormalize
.
GetOutput
<
uint8_t
>
();
model_normalize
.
Invoke
(
kNoExtraSpaces
);
auto
output_noextra_norm
=
model_normalize
.
GetOutput
<
uint8_t
>
();
model_normalize
.
Invoke
(
kExtraSpaces
);
auto
output_extra_norm
=
model_normalize
.
GetOutput
<
uint8_t
>
();
EXPECT_THAT
(
output_noextra_nonorm
,
ElementsAreArray
(
output_noextra_norm
));
EXPECT_THAT
(
output_noextra_nonorm
,
ElementsAreArray
(
output_extra_norm
));
EXPECT_THAT
(
output_noextra_nonorm
,
Not
(
ElementsAreArray
(
output_extra_nonorm
)));
}
class
SequenceStringProjectionTest
:
public
TensorflowTfLiteOpTest
{
class
SequenceStringProjectionTest
:
public
TensorflowTfLiteOpTest
{
std
::
function
<
TfLiteRegistration
*
()
>
TfLiteOpRegistration
()
override
{
std
::
function
<
TfLiteRegistration
*
()
>
TfLiteOpRegistration
()
override
{
return
ops
::
custom
::
Register_SEQUENCE_STRING_PROJECTION
;
return
ops
::
custom
::
Register_SEQUENCE_STRING_PROJECTION
;
...
@@ -710,6 +739,7 @@ std::vector<OpEquivTestCase> SequenceStringProjectionTestCases() {
...
@@ -710,6 +739,7 @@ std::vector<OpEquivTestCase> SequenceStringProjectionTestCases() {
test_case
.
output_tensors
.
emplace_back
(
FloatTensor
({},
{}),
kScale
,
kZero
);
test_case
.
output_tensors
.
emplace_back
(
FloatTensor
({},
{}),
kScale
,
kZero
);
test_cases
.
push_back
(
test_case
);
test_cases
.
push_back
(
test_case
);
}
}
{
{
OpEquivTestCase
test_case
;
OpEquivTestCase
test_case
;
test_case
.
test_name
=
"NormalizeRepetition"
;
test_case
.
test_name
=
"NormalizeRepetition"
;
...
@@ -794,6 +824,20 @@ std::vector<OpEquivTestCase> SequenceStringProjectionTestCases() {
...
@@ -794,6 +824,20 @@ std::vector<OpEquivTestCase> SequenceStringProjectionTestCases() {
test_cases
.
push_back
(
test_case
);
test_cases
.
push_back
(
test_case
);
}
}
{
OpEquivTestCase
test_case
;
test_case
.
test_name
=
"NormalizeSpaces"
;
test_case
.
attributes
[
"vocabulary"
]
=
AttrValue
(
""
);
test_case
.
attributes
[
"split_on_space"
]
=
AttrValue
(
true
);
test_case
.
attributes
[
"feature_size"
]
=
AttrValue
(
8
);
test_case
.
attributes
[
"add_eos_tag"
]
=
AttrValue
(
false
);
test_case
.
attributes
[
"add_bos_tag"
]
=
AttrValue
(
false
);
test_case
.
attributes
[
"normalize_spaces"
]
=
AttrValue
(
true
);
test_case
.
input_tensors
.
push_back
(
StringTensor
({
1
},
{
" Hello there. "
}));
test_case
.
output_tensors
.
emplace_back
(
FloatTensor
({},
{}),
kScale
,
kZero
);
test_cases
.
push_back
(
test_case
);
}
return
test_cases
;
return
test_cases
;
}
}
...
@@ -822,13 +866,13 @@ class SequenceStringProjectionV2Model : public ::tflite::SingleOpModel {
...
@@ -822,13 +866,13 @@ class SequenceStringProjectionV2Model : public ::tflite::SingleOpModel {
PopulateStringTensor
(
input_
,
input
);
PopulateStringTensor
(
input_
,
input
);
CHECK
(
interpreter_
->
AllocateTensors
()
==
kTfLiteOk
)
CHECK
(
interpreter_
->
AllocateTensors
()
==
kTfLiteOk
)
<<
"Cannot allocate tensors"
;
<<
"Cannot allocate tensors"
;
ASSERT_EQ
(
SingleOpModel
::
Invoke
Unchecked
(),
expected
);
ASSERT_EQ
(
SingleOpModel
::
Invoke
(),
expected
);
}
}
TfLiteStatus
InvokeFailable
(
const
std
::
string
&
input
)
{
TfLiteStatus
InvokeFailable
(
const
std
::
string
&
input
)
{
PopulateStringTensor
(
input_
,
{
input
});
PopulateStringTensor
(
input_
,
{
input
});
CHECK
(
interpreter_
->
AllocateTensors
()
==
kTfLiteOk
)
CHECK
(
interpreter_
->
AllocateTensors
()
==
kTfLiteOk
)
<<
"Cannot allocate tensors"
;
<<
"Cannot allocate tensors"
;
return
SingleOpModel
::
Invoke
Unchecked
();
return
SingleOpModel
::
Invoke
();
}
}
std
::
vector
<
int
>
GetOutputShape
()
{
return
GetTensorShape
(
output_
);
}
std
::
vector
<
int
>
GetOutputShape
()
{
return
GetTensorShape
(
output_
);
}
...
...
research/seq_flow_lite/tflite_ops/tf_tflite_diff_test_util.cc
View file @
20cc2190
...
@@ -309,7 +309,7 @@ void TensorflowTfLiteOpTest::RunTfLiteOp() {
...
@@ -309,7 +309,7 @@ void TensorflowTfLiteOpTest::RunTfLiteOp() {
input_index
++
;
input_index
++
;
}
}
tflite_op_
.
Invoke
();
ASSERT_EQ
(
tflite_op_
.
Invoke
()
,
kTfLiteOk
)
;
}
}
void
TensorflowTfLiteOpTest
::
CompareOpOutput
()
{
void
TensorflowTfLiteOpTest
::
CompareOpOutput
()
{
...
...
research/seq_flow_lite/tflite_ops/tf_tflite_diff_test_util.h
View file @
20cc2190
...
@@ -14,8 +14,8 @@ limitations under the License.
...
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
==============================================================================*/
// Tests equivalence between TF and TFLite versions of an op.
// Tests equivalence between TF and TFLite versions of an op.
#ifndef TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
#ifndef TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
#define TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
#define TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
#include <string>
#include <string>
#include <vector>
#include <vector>
...
@@ -146,4 +146,4 @@ class TensorflowTfLiteOpTest
...
@@ -146,4 +146,4 @@ class TensorflowTfLiteOpTest
}
// namespace testing
}
// namespace testing
}
// namespace seq_flow_lite
}
// namespace seq_flow_lite
#endif // TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
#endif // TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
research/seq_flow_lite/tflite_ops/tflite_decoder_cache.h
View file @
20cc2190
...
@@ -13,12 +13,13 @@ See the License for the specific language governing permissions and
...
@@ -13,12 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
==============================================================================*/
==============================================================================*/
#ifndef
THIRD_PARTY_
TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
#define
THIRD_PARTY_
TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
#include <memory>
#include <memory>
#include "third_party/tensorflow/lite/c/common.h"
#include "tensorflow/lite/c/common.h"
namespace
seq_flow_lite
{
namespace
seq_flow_lite
{
namespace
ops
{
namespace
ops
{
namespace
custom
{
namespace
custom
{
...
@@ -113,4 +114,4 @@ class DynamicCacheOp {
...
@@ -113,4 +114,4 @@ class DynamicCacheOp {
}
// namespace custom
}
// namespace custom
}
// namespace ops
}
// namespace ops
}
// namespace seq_flow_lite
}
// namespace seq_flow_lite
#endif //
THIRD_PARTY_
TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
research/seq_flow_lite/tflite_ops/tflite_decoder_handler.cc
View file @
20cc2190
...
@@ -13,15 +13,16 @@ See the License for the specific language governing permissions and
...
@@ -13,15 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
==============================================================================*/
==============================================================================*/
#include "
third_party/tensorflow_models/seq_flow_lite/
tflite_ops/tflite_decoder_handler.h"
#include "tflite_ops/tflite_decoder_handler.h"
// seq_flow_lite
#include <cstdint>
#include <cstdint>
#include "third_party/flatbuffers/include/flatbuffers/flexbuffers.h"
#include "flatbuffers/flexbuffers.h" // flatbuffer
#include "third_party/tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "third_party/tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/tflite_decoder_cache.h"
#include "tflite_ops/quantization_util.h" // seq_flow_lite
#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h"
#include "tflite_ops/tflite_decoder_cache.h" // seq_flow_lite
namespace
seq_flow_lite
{
namespace
seq_flow_lite
{
namespace
ops
{
namespace
ops
{
...
...
research/seq_flow_lite/tflite_ops/tflite_decoder_handler.h
View file @
20cc2190
...
@@ -13,18 +13,19 @@ See the License for the specific language governing permissions and
...
@@ -13,18 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
==============================================================================*/
==============================================================================*/
#ifndef
THIRD_PARTY_
TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
#define
THIRD_PARTY_
TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
#include "
third_party/
tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/register.h"
namespace
seq_flow_lite
{
namespace
seq_flow_lite
{
namespace
ops
{
namespace
ops
{
namespace
custom
{
namespace
custom
{
TfLiteRegistration
*
Register_UNIFORM_CAUSAL_ATTENTION
();
TfLiteRegistration
*
Register_UNIFORM_CAUSAL_ATTENTION
();
}
}
// namespace custom
}
// namespace ops
}
// namespace ops
}
// namespace seq_flow_lite
}
// namespace seq_flow_lite
#endif //
THIRD_PARTY_
TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
research/seq_flow_lite/tflite_ops/tflite_decoder_handler_test.cc
View file @
20cc2190
...
@@ -13,17 +13,17 @@ See the License for the specific language governing permissions and
...
@@ -13,17 +13,17 @@ See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
==============================================================================*/
==============================================================================*/
#include "
third_party/tensorflow_models/seq_flow_lite/
tflite_ops/tflite_decoder_handler.h"
#include "tflite_ops/tflite_decoder_handler.h"
// seq_flow_lite
#include <cstdint>
#include <cstdint>
#include <cstdlib>
#include <cstdlib>
#include <vector>
#include <vector>
#include
"testing/base/public
/gmock.h
"
#include
<gmock
/gmock.h
>
#include
"
test
ing/base/public/guni
t.h
"
#include
<g
test
/gtes
t.h
>
#include "
third_party/flatbuffers/include/flatbuffers/flex
buffer
s.h"
#include "
flatbuffers/flexbuffers.h" // flat
buffer
#include "
third_party/
tensorflow/lite/c/common.h"
#include "tensorflow/lite/c/common.h"
#include "
third_party/
tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/kernels/test_util.h"
namespace
{
namespace
{
...
...
research/seq_flow_lite/tflite_ops/tflite_qrnn_pooling.cc
View file @
20cc2190
...
@@ -12,11 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,11 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
==============================================================================*/
==============================================================================*/
#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/tflite_qrnn_pooling.h"
#include "tflite_ops/quantization_util.h" // seq_flow_lite
#include "tflite_ops/tflite_qrnn_pooling.h" // seq_flow_lite
#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h"
namespace
seq_flow_lite
{
namespace
seq_flow_lite
{
namespace
ops
{
namespace
custom
{
namespace
{
namespace
{
...
@@ -126,9 +127,9 @@ TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node) {
...
@@ -126,9 +127,9 @@ TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node) {
return
QRNNPooling
(
context
,
multiplier
,
constant
,
outputs
,
final_state
,
return
QRNNPooling
(
context
,
multiplier
,
constant
,
outputs
,
final_state
,
(
direction
->
data
.
uint8
[
0
]
==
kPoolingForward
));
(
direction
->
data
.
uint8
[
0
]
==
kPoolingForward
));
}
}
}
// namespace
}
// namespace
namespace
custom
{
const
char
kPoolingOp
[]
=
"PoolingOp"
;
const
char
kPoolingOp
[]
=
"PoolingOp"
;
void
RegisterQRNNPooling
(
::
tflite
::
ops
::
builtin
::
BuiltinOpResolver
*
resolver
)
{
void
RegisterQRNNPooling
(
::
tflite
::
ops
::
builtin
::
BuiltinOpResolver
*
resolver
)
{
...
@@ -141,4 +142,5 @@ TfLiteRegistration* Register_QRNN_POOLING() {
...
@@ -141,4 +142,5 @@ TfLiteRegistration* Register_QRNN_POOLING() {
}
}
}
// namespace custom
}
// namespace custom
}
// namespace ops
}
// namespace seq_flow_lite
}
// namespace seq_flow_lite
research/seq_flow_lite/tflite_ops/tflite_qrnn_pooling.h
View file @
20cc2190
...
@@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
==============================================================================*/
==============================================================================*/
#ifndef TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
#ifndef TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
#define TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
#define TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
#include "
third_party/
absl/base/macros.h"
#include "absl/base/macros.h"
#include "
third_party/
tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/register.h"
namespace
seq_flow_lite
{
namespace
seq_flow_lite
{
namespace
ops
{
namespace
custom
{
namespace
custom
{
extern
const
char
kPoolingOp
[];
extern
const
char
kPoolingOp
[];
...
@@ -27,7 +27,7 @@ extern const char kPoolingOp[];
...
@@ -27,7 +27,7 @@ extern const char kPoolingOp[];
TfLiteRegistration
*
Register_QRNN_POOLING
();
TfLiteRegistration
*
Register_QRNN_POOLING
();
}
// namespace custom
}
// namespace custom
}
// namespace ops
}
// namespace seq_flow_lite
}
// namespace seq_flow_lite
#endif // TENSORFLOW_MODELS_SEQ
UENCE_PROJECTION
_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
#endif // TENSORFLOW_MODELS_SEQ
_FLOW_LITE
_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
research/seq_flow_lite/trainer.py
View file @
20cc2190
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
# ==============================================================================
# Lint as: python3
"""A utility for PRADO model to do train, eval, inference and model export."""
"""A utility for PRADO model to do train, eval, inference and model export."""
import
importlib
import
importlib
...
@@ -22,6 +21,7 @@ from absl import app
...
@@ -22,6 +21,7 @@ from absl import app
from
absl
import
flags
from
absl
import
flags
from
absl
import
logging
from
absl
import
logging
import
tensorflow.compat.v1
as
tf
import
tensorflow.compat.v1
as
tf
from
tensorflow.compat.v1
import
estimator
as
tf_estimator
import
input_fn_reader
# import root module
import
input_fn_reader
# import root module
import
metric_functions
# import root module
import
metric_functions
# import root module
...
@@ -48,14 +48,14 @@ def load_runner_config():
...
@@ -48,14 +48,14 @@ def load_runner_config():
return
json
.
loads
(
f
.
read
())
return
json
.
loads
(
f
.
read
())
def
create_model
(
model
,
model_config
,
features
,
mode
):
def
create_model
(
model
,
model_config
,
features
,
mode
,
model_name
):
"""Creates a sequence labeling model."""
"""Creates a sequence labeling model."""
keras_model
=
model
.
Encoder
(
model_config
,
mode
)
keras_model
=
model
.
Encoder
(
model_config
,
mode
)
if
"pqrnn"
in
model_name
:
if
"pqrnn"
in
model_name
:
logits
=
keras_model
(
features
[
"projection"
],
features
[
"seq_length"
])
logits
=
keras_model
(
features
[
"projection"
],
features
[
"seq_length"
])
else
:
else
:
logits
=
keras_model
(
features
[
"token_ids"
],
features
[
"token_len"
])
logits
=
keras_model
(
features
[
"token_ids"
],
features
[
"token_len"
])
if
mode
!=
tf
.
estimator
.
ModeKeys
.
PREDICT
:
if
mode
!=
tf
_
estimator
.
ModeKeys
.
PREDICT
:
if
not
model_config
[
"multilabel"
]:
if
not
model_config
[
"multilabel"
]:
loss
=
tf
.
nn
.
sparse_softmax_cross_entropy_with_logits
(
loss
=
tf
.
nn
.
sparse_softmax_cross_entropy_with_logits
(
labels
=
features
[
"label"
],
logits
=
logits
)
labels
=
features
[
"label"
],
logits
=
logits
)
...
@@ -94,33 +94,33 @@ def model_fn_builder(runner_config):
...
@@ -94,33 +94,33 @@ def model_fn_builder(runner_config):
def
model_fn
(
features
,
mode
,
params
):
def
model_fn
(
features
,
mode
,
params
):
"""The `model_fn` for TPUEstimator."""
"""The `model_fn` for TPUEstimator."""
label_ids
=
None
label_ids
=
None
if
mode
!=
tf
.
estimator
.
ModeKeys
.
PREDICT
:
if
mode
!=
tf
_
estimator
.
ModeKeys
.
PREDICT
:
label_ids
=
features
[
"label"
]
label_ids
=
features
[
"label"
]
model_config
=
runner_config
[
"model_config"
]
model_config
=
runner_config
[
"model_config"
]
loss
,
logits
=
create_model
(
model
,
model_config
,
features
,
mode
)
loss
,
logits
=
create_model
(
model
,
model_config
,
features
,
mode
,
runner_config
[
"name"
])
if
mode
==
tf
.
estimator
.
ModeKeys
.
TRAIN
:
if
mode
==
tf
_
estimator
.
ModeKeys
.
TRAIN
:
train_op
=
create_optimizer
(
loss
,
runner_config
,
params
)
train_op
=
create_optimizer
(
loss
,
runner_config
,
params
)
return
tf
.
compat
.
v1
.
estimator
.
tpu
.
TPUEstimatorSpec
(
return
tf
_
estimator
.
tpu
.
TPUEstimatorSpec
(
mode
=
mode
,
loss
=
loss
,
train_op
=
train_op
)
mode
=
mode
,
loss
=
loss
,
train_op
=
train_op
)
elif
mode
==
tf
.
estimator
.
ModeKeys
.
EVAL
:
elif
mode
==
tf
_
estimator
.
ModeKeys
.
EVAL
:
if
not
runner_config
[
"model_config"
][
"multilabel"
]:
if
not
runner_config
[
"model_config"
][
"multilabel"
]:
metric_fn
=
metric_functions
.
classification_metric
metric_fn
=
metric_functions
.
classification_metric
else
:
else
:
metric_fn
=
metric_functions
.
labeling_metric
metric_fn
=
metric_functions
.
labeling_metric
eval_metrics
=
(
metric_fn
,
[
loss
,
label_ids
,
logits
])
eval_metrics
=
(
metric_fn
,
[
loss
,
label_ids
,
logits
])
return
tf
.
compat
.
v1
.
estimator
.
tpu
.
TPUEstimatorSpec
(
return
tf
_
estimator
.
tpu
.
TPUEstimatorSpec
(
mode
=
mode
,
loss
=
loss
,
eval_metrics
=
eval_metrics
)
mode
=
mode
,
loss
=
loss
,
eval_metrics
=
eval_metrics
)
elif
mode
==
tf
.
estimator
.
ModeKeys
.
PREDICT
:
elif
mode
==
tf
_
estimator
.
ModeKeys
.
PREDICT
:
predictions
=
{
"logits"
:
logits
}
predictions
=
{
"logits"
:
logits
}
if
not
runner_config
[
"model_config"
][
"multilabel"
]:
if
not
runner_config
[
"model_config"
][
"multilabel"
]:
predictions
[
"predictions"
]
=
tf
.
nn
.
softmax
(
logits
)
predictions
[
"predictions"
]
=
tf
.
nn
.
softmax
(
logits
)
else
:
else
:
predictions
[
"predictions"
]
=
tf
.
math
.
sigmoid
(
logits
)
predictions
[
"predictions"
]
=
tf
.
math
.
sigmoid
(
logits
)
return
tf
.
compat
.
v1
.
estimator
.
EstimatorSpec
(
return
tf_estimator
.
EstimatorSpec
(
mode
=
mode
,
predictions
=
predictions
)
mode
=
mode
,
predictions
=
predictions
)
else
:
else
:
assert
False
,
"Expected to be called in TRAIN, EVAL, or PREDICT mode."
assert
False
,
"Expected to be called in TRAIN, EVAL, or PREDICT mode."
...
@@ -133,13 +133,13 @@ def main(_):
...
@@ -133,13 +133,13 @@ def main(_):
if
FLAGS
.
output_dir
:
if
FLAGS
.
output_dir
:
tf
.
gfile
.
MakeDirs
(
FLAGS
.
output_dir
)
tf
.
gfile
.
MakeDirs
(
FLAGS
.
output_dir
)
is_per_host
=
tf
.
estimator
.
tpu
.
InputPipelineConfig
.
PER_HOST_V2
is_per_host
=
tf
_
estimator
.
tpu
.
InputPipelineConfig
.
PER_HOST_V2
run_config
=
tf
.
estimator
.
tpu
.
RunConfig
(
run_config
=
tf
_
estimator
.
tpu
.
RunConfig
(
master
=
FLAGS
.
master
,
master
=
FLAGS
.
master
,
model_dir
=
FLAGS
.
output_dir
,
model_dir
=
FLAGS
.
output_dir
,
save_checkpoints_steps
=
runner_config
[
"save_checkpoints_steps"
],
save_checkpoints_steps
=
runner_config
[
"save_checkpoints_steps"
],
keep_checkpoint_max
=
20
,
keep_checkpoint_max
=
20
,
tpu_config
=
tf
.
estimator
.
tpu
.
TPUConfig
(
tpu_config
=
tf
_
estimator
.
tpu
.
TPUConfig
(
iterations_per_loop
=
runner_config
[
"iterations_per_loop"
],
iterations_per_loop
=
runner_config
[
"iterations_per_loop"
],
num_shards
=
FLAGS
.
num_tpu_cores
,
num_shards
=
FLAGS
.
num_tpu_cores
,
per_host_input_for_training
=
is_per_host
))
per_host_input_for_training
=
is_per_host
))
...
@@ -149,7 +149,7 @@ def main(_):
...
@@ -149,7 +149,7 @@ def main(_):
# If TPU is not available, this will fall back to normal Estimator on CPU
# If TPU is not available, this will fall back to normal Estimator on CPU
# or GPU.
# or GPU.
batch_size
=
runner_config
[
"batch_size"
]
batch_size
=
runner_config
[
"batch_size"
]
estimator
=
tf
.
estimator
.
tpu
.
TPUEstimator
(
estimator
=
tf
_
estimator
.
tpu
.
TPUEstimator
(
use_tpu
=
FLAGS
.
use_tpu
,
use_tpu
=
FLAGS
.
use_tpu
,
model_fn
=
model_fn
,
model_fn
=
model_fn
,
config
=
run_config
,
config
=
run_config
,
...
@@ -160,7 +160,7 @@ def main(_):
...
@@ -160,7 +160,7 @@ def main(_):
if
FLAGS
.
runner_mode
==
"train"
:
if
FLAGS
.
runner_mode
==
"train"
:
train_input_fn
=
input_fn_reader
.
create_input_fn
(
train_input_fn
=
input_fn_reader
.
create_input_fn
(
runner_config
=
runner_config
,
runner_config
=
runner_config
,
mode
=
tf
.
estimator
.
ModeKeys
.
TRAIN
,
mode
=
tf
_
estimator
.
ModeKeys
.
TRAIN
,
drop_remainder
=
True
)
drop_remainder
=
True
)
estimator
.
train
(
estimator
.
train
(
input_fn
=
train_input_fn
,
max_steps
=
runner_config
[
"train_steps"
])
input_fn
=
train_input_fn
,
max_steps
=
runner_config
[
"train_steps"
])
...
@@ -168,7 +168,7 @@ def main(_):
...
@@ -168,7 +168,7 @@ def main(_):
# TPU needs fixed shapes, so if the last batch is smaller, we drop it.
# TPU needs fixed shapes, so if the last batch is smaller, we drop it.
eval_input_fn
=
input_fn_reader
.
create_input_fn
(
eval_input_fn
=
input_fn_reader
.
create_input_fn
(
runner_config
=
runner_config
,
runner_config
=
runner_config
,
mode
=
tf
.
estimator
.
ModeKeys
.
EVAL
,
mode
=
tf
_
estimator
.
ModeKeys
.
EVAL
,
drop_remainder
=
True
)
drop_remainder
=
True
)
for
_
in
tf
.
train
.
checkpoints_iterator
(
FLAGS
.
output_dir
,
timeout
=
600
):
for
_
in
tf
.
train
.
checkpoints_iterator
(
FLAGS
.
output_dir
,
timeout
=
600
):
...
...
research/seq_flow_lite/trainer_v2.py
View file @
20cc2190
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
# ==============================================================================
# Lint as: python3
"""Binary to train PRADO model with TF 2.0."""
"""Binary to train PRADO model with TF 2.0."""
import
importlib
import
importlib
...
@@ -23,6 +22,7 @@ from absl import flags
...
@@ -23,6 +22,7 @@ from absl import flags
from
absl
import
logging
from
absl
import
logging
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorflow
import
estimator
as
tf_estimator
import
input_fn_reader
# import root module
import
input_fn_reader
# import root module
...
@@ -48,7 +48,7 @@ def load_runner_config():
...
@@ -48,7 +48,7 @@ def load_runner_config():
def
compute_loss
(
logits
,
labels
,
model_config
,
mode
):
def
compute_loss
(
logits
,
labels
,
model_config
,
mode
):
"""Creates a sequence labeling model."""
"""Creates a sequence labeling model."""
if
mode
!=
tf
.
estimator
.
ModeKeys
.
PREDICT
:
if
mode
!=
tf
_
estimator
.
ModeKeys
.
PREDICT
:
if
not
model_config
[
"multilabel"
]:
if
not
model_config
[
"multilabel"
]:
loss
=
tf
.
nn
.
sparse_softmax_cross_entropy_with_logits
(
loss
=
tf
.
nn
.
sparse_softmax_cross_entropy_with_logits
(
labels
=
labels
,
logits
=
logits
)
labels
=
labels
,
logits
=
logits
)
...
@@ -77,11 +77,11 @@ def main(_):
...
@@ -77,11 +77,11 @@ def main(_):
if
FLAGS
.
output_dir
:
if
FLAGS
.
output_dir
:
tf
.
io
.
gfile
.
makedirs
(
FLAGS
.
output_dir
)
tf
.
io
.
gfile
.
makedirs
(
FLAGS
.
output_dir
)
train_model
=
model_fn_builder
(
runner_config
,
tf
.
estimator
.
ModeKeys
.
TRAIN
)
train_model
=
model_fn_builder
(
runner_config
,
tf
_
estimator
.
ModeKeys
.
TRAIN
)
optimizer
=
tf
.
keras
.
optimizers
.
Adam
()
optimizer
=
tf
.
keras
.
optimizers
.
Adam
()
train_input_fn
=
input_fn_reader
.
create_input_fn
(
train_input_fn
=
input_fn_reader
.
create_input_fn
(
runner_config
=
runner_config
,
runner_config
=
runner_config
,
mode
=
tf
.
estimator
.
ModeKeys
.
TRAIN
,
mode
=
tf
_
estimator
.
ModeKeys
.
TRAIN
,
drop_remainder
=
True
)
drop_remainder
=
True
)
params
=
{
"batch_size"
:
runner_config
[
"batch_size"
]}
params
=
{
"batch_size"
:
runner_config
[
"batch_size"
]}
train_ds
=
train_input_fn
(
params
)
train_ds
=
train_input_fn
(
params
)
...
@@ -93,7 +93,7 @@ def main(_):
...
@@ -93,7 +93,7 @@ def main(_):
logits
=
train_model
(
features
[
"projection"
],
features
[
"seq_length"
])
logits
=
train_model
(
features
[
"projection"
],
features
[
"seq_length"
])
loss
=
compute_loss
(
logits
,
features
[
"label"
],
loss
=
compute_loss
(
logits
,
features
[
"label"
],
runner_config
[
"model_config"
],
runner_config
[
"model_config"
],
tf
.
estimator
.
ModeKeys
.
TRAIN
)
tf
_
estimator
.
ModeKeys
.
TRAIN
)
gradients
=
tape
.
gradient
(
loss
,
train_model
.
trainable_variables
)
gradients
=
tape
.
gradient
(
loss
,
train_model
.
trainable_variables
)
optimizer
.
apply_gradients
(
zip
(
gradients
,
train_model
.
trainable_variables
))
optimizer
.
apply_gradients
(
zip
(
gradients
,
train_model
.
trainable_variables
))
train_loss
(
loss
)
train_loss
(
loss
)
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment