Unverified Commit 20cc2190 authored by pyoung2778, committed by GitHub

Check in seq_flow_lite (#10750)

parent fdecf385
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/beam_search.h"
+#include "tflite_ops/beam_search.h"  // seq_flow_lite
 #include <algorithm>
 #include <cstdint>
@@ -21,10 +21,10 @@ limitations under the License.
 #include <vector>
 #include "base/logging.h"
-#include "third_party/absl/strings/str_join.h"
-#include "third_party/tensorflow/lite/kernels/internal/tensor_ctypes.h"
-#include "third_party/tensorflow/lite/kernels/internal/types.h"
-#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h"
+#include "absl/strings/str_join.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+#include "tflite_ops/quantization_util.h"  // seq_flow_lite
 namespace seq_flow_lite {
 namespace ops {
@@ -86,6 +86,7 @@ void SequenceTracker::AddSequence(const int32_t *begin, const int32_t *end,
 std::vector<std::vector<int32_t>> SequenceTracker::GetTopBeams() {
   std::vector<std::vector<int32_t>> return_value;
+  return_value.reserve(terminated_topk_.size());
   for (const auto &v : terminated_topk_) {
     return_value.push_back(v.second);
   }
@@ -255,8 +256,8 @@ void BeamSearch::FindTopKQuantizedFromLogitsV1(const TfLiteTensor &tensor,
       }
     }
     // Updating topk across all beams.
-    for (uint32_t k = 0; k < std::min(topk_k, num_classes_); ++k) {
-      const uint32_t curr_beam_index = curr_beam_topk[k] & kClassIndexMask;
+    for (uint32_t curr_beam : curr_beam_topk) {
+      const uint32_t curr_beam_index = curr_beam & kClassIndexMask;
       const uint32_t index = j * num_classes_ + curr_beam_index;
       const float log_prob =
           tensor.params.scale * beam_logits[curr_beam_index] - precomputed;
...
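The masking in the loop above suggests each `curr_beam_topk` entry packs a quantized score and a class index into a single `uint32_t`, so that ordering the packed integers orders candidates by score. A minimal sketch of that idea, with an assumed 22/10 bit split and hypothetical helper names (not the actual seq_flow_lite layout):

```cpp
#include <algorithm>
#include <cstdint>
#include <functional>
#include <vector>

// Assumed layout for illustration: score in the high bits, class index in
// the low kClassBits bits; masking with kClassIndexMask recovers the index.
constexpr uint32_t kClassBits = 22;  // assumption
constexpr uint32_t kClassIndexMask = (1u << kClassBits) - 1;

inline uint32_t Pack(uint32_t quantized_score, uint32_t class_index) {
  return (quantized_score << kClassBits) | (class_index & kClassIndexMask);
}

// Returns the top-k packed entries, highest score first; sorting the packed
// values sorts by score because the score occupies the high bits.
std::vector<uint32_t> TopK(const std::vector<uint32_t>& packed, size_t k) {
  std::vector<uint32_t> result(packed);
  k = std::min(k, result.size());
  std::partial_sort(result.begin(), result.begin() + k, result.end(),
                    std::greater<uint32_t>());
  result.resize(k);
  return result;
}
```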
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
-#define THIRD_PARTY_TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
+#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
+#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
 #include <cstdint>
 #include <functional>
@@ -23,7 +23,7 @@ limitations under the License.
 #include <set>
 #include <vector>
-#include "third_party/tensorflow/lite/c/common.h"
+#include "tensorflow/lite/c/common.h"
 namespace seq_flow_lite {
 namespace ops {
@@ -110,4 +110,4 @@ class BeamSearch {
 }  // namespace custom
 }  // namespace ops
 }  // namespace seq_flow_lite
-#endif  // THIRD_PARTY_TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
+#endif  // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/beam_search.h"
+#include "tflite_ops/beam_search.h"  // seq_flow_lite
 #include <cstdint>
 #include <functional>
@@ -21,17 +21,17 @@ limitations under the License.
 #include <memory>
 #include <vector>
-#include "testing/base/public/gmock.h"
-#include "testing/base/public/gunit.h"
-#include "third_party/absl/strings/str_join.h"
-#include "third_party/tensorflow/lite/c/c_api_types.h"
-#include "third_party/tensorflow/lite/c/common.h"
-#include "third_party/tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
-#include "third_party/tensorflow/lite/kernels/internal/reference/dequantize.h"
-#include "third_party/tensorflow/lite/kernels/internal/reference/reference_ops.h"
-#include "third_party/tensorflow/lite/kernels/internal/tensor_ctypes.h"
-#include "third_party/tensorflow/lite/kernels/internal/types.h"
-#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/strings/str_join.h"
+#include "tensorflow/lite/c/c_api_types.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/lite/kernels/internal/reference/dequantize.h"
+#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+#include "tflite_ops/quantization_util.h"  // seq_flow_lite
 namespace seq_flow_lite {
 namespace ops {
@@ -76,7 +76,7 @@ class BeamSearchImpl : public BeamSearch {
           cur_cache + (selected_beams[beam] * NumClasses());
       for (int j = 0; j < NumClasses(); ++j, index++) {
         next_cache[index] = (selected[j] + next_cache[index]) / 2;
-        data_ptr[index] = ::seq_flow_lite::PodQuantize(
+        data_ptr[index] = PodQuantize(
             next_cache[index], decoder_output_->params.zero_point,
             1.0f / decoder_output_->params.scale);
       }
...
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_EXPECTED_VALUE_H_
-#define TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_EXPECTED_VALUE_H_
+#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_EXPECTED_VALUE_H_
+#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_EXPECTED_VALUE_H_
 #include "tensorflow/lite/kernels/register.h"
@@ -27,4 +27,4 @@ TfLiteRegistration* Register_EXPECTED_VALUE();
 }  // namespace ops
 }  // namespace seq_flow_lite
-#endif  // TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_EXPECTED_VALUE_H_
+#endif  // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_EXPECTED_VALUE_H_
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef LEARNING_EXPANDER_POD_DEEP_POD_TFLITE_HANDLERS_LAYER_NORM_H_
-#define LEARNING_EXPANDER_POD_DEEP_POD_TFLITE_HANDLERS_LAYER_NORM_H_
+#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_LAYER_NORM_H_
+#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_LAYER_NORM_H_
 #include "tensorflow/lite/kernels/register.h"
@@ -27,4 +27,4 @@ TfLiteRegistration* Register_LAYER_NORM();
 }  // namespace ops
 }  // namespace seq_flow_lite
-#endif  // LEARNING_EXPANDER_POD_DEEP_POD_TFLITE_HANDLERS_LAYER_NORM_H_
+#endif  // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_LAYER_NORM_H_
@@ -87,7 +87,7 @@ TEST(LayerNormModelTest, RegularInput) {
       /*input_max=*/10, /*output_min=*/-10, /*output_max=*/10,
       /*scale=*/1.0, /*offset=*/0.0, /*axes=*/{2});
   m.SetInput(input);
-  m.Invoke();
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
   EXPECT_THAT(
       m.GetDequantizedOutput(),
       ElementsAreArray(ArrayFloatNear(expected_output, kQuantizedTolerance)));
@@ -106,7 +106,7 @@ TEST(LayerNormModelTest, NegativeScale) {
       /*input_max=*/10, /*output_min=*/-10, /*output_max=*/10,
       /*scale=*/-1.0, /*offset=*/0.0, /*axes=*/{2});
   m.SetInput(input);
-  m.Invoke();
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
   EXPECT_THAT(
       m.GetDequantizedOutput(),
       ElementsAreArray(ArrayFloatNear(expected_output, kQuantizedTolerance)));
@@ -125,7 +125,7 @@ TEST(LayerNormModelTest, NegativeOffset) {
       /*input_max=*/10, /*output_min=*/-10, /*output_max=*/10,
       /*scale=*/1.0, /*offset=*/-1.0, /*axes=*/{2});
   m.SetInput(input);
-  m.Invoke();
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
   EXPECT_THAT(
       m.GetDequantizedOutput(),
       ElementsAreArray(ArrayFloatNear(expected_output, kQuantizedTolerance)));
@@ -144,7 +144,7 @@ TEST(LayerNormModelTest, NegativeScaleAndOffset) {
       /*input_max=*/10, /*output_min=*/-10, /*output_max=*/10,
       /*scale=*/-1.0, /*offset=*/-1.0, /*axes=*/{2});
   m.SetInput(input);
-  m.Invoke();
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
   EXPECT_THAT(
       m.GetDequantizedOutput(),
       ElementsAreArray(ArrayFloatNear(expected_output, kQuantizedTolerance)));
@@ -163,7 +163,7 @@ TEST(LayerNormModelTest, MultipleAxis) {
       /*input_max=*/3, /*output_min=*/-3, /*output_max=*/3,
       /*scale=*/1.0, /*offset=*/0.0, /*axes=*/{1, 3});
   m.SetInput(input);
-  m.Invoke();
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
   EXPECT_THAT(
       m.GetDequantizedOutput(),
       ElementsAreArray(ArrayFloatNear(expected_output, kQuantizedTolerance)));
@@ -182,7 +182,7 @@ TEST(LayerNormModelTest, MultipleNegativeAxis) {
       /*input_max=*/3, /*output_min=*/-3, /*output_max=*/3,
       /*scale=*/1.0, /*offset=*/0.0, /*axes=*/{-3, -1});
   m.SetInput(input);
-  m.Invoke();
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
   EXPECT_THAT(
       m.GetDequantizedOutput(),
       ElementsAreArray(ArrayFloatNear(expected_output, kQuantizedTolerance)));
@@ -204,7 +204,7 @@ TEST(LayerNormModelTest, MultipleAxisWithLargeDepth) {
       /*input_max=*/1.0, /*output_min=*/-3.0, /*output_max=*/3.0,
       /*scale=*/1.0, /*offset=*/0.0, /*axes=*/{1, 3});
   m.SetInput(input);
-  m.Invoke();
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
   EXPECT_THAT(
       m.GetDequantizedOutput(),
       ElementsAreArray(ArrayFloatNear(expected_output, kQuantizedTolerance)));
...
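Every test in this file picks up the same edit: `m.Invoke();` becomes `ASSERT_EQ(m.Invoke(), kTfLiteOk);`. This tracks `SingleOpModel::Invoke()` now returning a `TfLiteStatus` instead of aborting internally (the companion changes below also swap `InvokeUnchecked()` for `Invoke()`), so tests assert on the status rather than discarding it. A minimal, self-contained sketch of the pattern, with a hypothetical `FakeInvoke` standing in for the model call:

```cpp
#include <gtest/gtest.h>
#include "tensorflow/lite/c/c_api_types.h"

// Hypothetical stand-in for SingleOpModel::Invoke(), which now reports
// success or failure through its TfLiteStatus return value.
TfLiteStatus FakeInvoke() { return kTfLiteOk; }

TEST(InvokeStatusPattern, AssertsOnStatus) {
  // Assert on the returned status so a failing invoke fails the test.
  ASSERT_EQ(FakeInvoke(), kTfLiteOk);
}
```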
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_QUANTIZATION_UTIL_H_
-#define TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_QUANTIZATION_UTIL_H_
+#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_QUANTIZATION_UTIL_H_
+#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_QUANTIZATION_UTIL_H_
 #include <algorithm>
 #include <cmath>
@@ -50,4 +50,4 @@ inline uint8_t PodQuantize(float value, int32_t zero_point,
 }  // namespace seq_flow_lite
-#endif  // TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_QUANTIZATION_UTIL_H_
+#endif  // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_QUANTIZATION_UTIL_H_
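The hunk above keeps `PodQuantize(float value, int32_t zero_point, ...)`, and its callers pass an inverse scale (`1.0f / params.scale`). A minimal sketch of the standard affine uint8 quantization such a helper presumably performs; the rounding and clamping details here are assumptions, not the checked-in code:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Hypothetical stand-in for PodQuantize in tflite_ops/quantization_util.h:
// affine-quantize a float to uint8 given a zero point and 1/scale.
inline uint8_t PodQuantizeSketch(float value, int32_t zero_point,
                                 float inverse_scale) {
  const int32_t q =
      zero_point + static_cast<int32_t>(std::lround(value * inverse_scale));
  // Clamp to the representable uint8 range before narrowing.
  return static_cast<uint8_t>(std::min(255, std::max(0, q)));
}
```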
@@ -101,7 +101,7 @@ class ProjectionParams {
                    bool exclude_nonalphaspace_unicodes,
                    const std::string& token_separators,
                    bool normalize_repetition, bool add_first_cap_feature,
-                   bool add_all_caps_feature)
+                   bool add_all_caps_feature, bool normalize_spaces)
       : feature_size_(feature_size),
         unicode_handler_(vocabulary, exclude_nonalphaspace_unicodes),
         hasher_(Hasher::CreateHasher(feature_size, hashtype)),
@@ -130,9 +130,9 @@ class ProjectionParams {
     }
     word_novelty_offset_ = 2.0f / (1 << word_novelty_bits_);
-    if (!token_separators.empty() || normalize_repetition) {
+    if (!token_separators.empty() || normalize_repetition || normalize_spaces) {
       projection_normalizer_ = std::make_unique<ProjectionNormalizer>(
-          token_separators, normalize_repetition);
+          token_separators, normalize_repetition, normalize_spaces);
     }
   }
   virtual ~ProjectionParams() {}
@@ -242,7 +242,8 @@ class ProjectionParamsV2 : public ProjectionParams {
                          /*exclude_nonalphaspace_unicodes = */ false,
                          /*token_separators = */ "", normalize_repetition,
                          /*add_first_cap_feature = */ false,
-                         /*add_all_caps_feature = */ false) {}
+                         /*add_all_caps_feature = */ false,
+                         /*normalize_spaces = */ false) {}
   ~ProjectionParamsV2() override {}
   TfLiteStatus PreprocessInput(TfLiteTensor* input_t,
@@ -341,6 +342,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   const std::string token_separators =
       m["token_separators"].IsNull() ? "" : m["token_separators"].ToString();
   const bool normalize_repetition = m["normalize_repetition"].AsBool();
+  const bool normalize_spaces = m["normalize_spaces"].AsBool();
   if (!Hasher::SupportedHashType(hashtype)) {
     context->ReportError(context, "Unsupported hashtype %s\n",
                          hashtype.c_str());
@@ -354,7 +356,8 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
       add_bos_tag ? BosTag::kGenerate : BosTag::kNone,
       add_eos_tag ? EosTag::kGenerate : EosTag::kNone,
       exclude_nonalphaspace_unicodes, token_separators, normalize_repetition,
-      add_first_cap_feature == 1.0f, add_all_caps_feature == 1.0f);
+      add_first_cap_feature == 1.0f, add_all_caps_feature == 1.0f,
+      normalize_spaces);
 }
 void* InitV2(TfLiteContext* context, const char* buffer, size_t length) {
...
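For reference, custom-op attributes such as the new `normalize_spaces` flag reach `Init()` through a flexbuffer map (read above via `m["normalize_spaces"].AsBool()`), and the test fixture below builds one with `flexbuffers::Builder`. A trimmed sketch of the same serialization pattern; the attribute set here is illustrative, not the op's full schema:

```cpp
#include <cstdint>
#include <vector>

#include "flatbuffers/flexbuffers.h"

// Serialize a flexbuffer map of op options; the op's Init() reads the keys
// back out of the map by name.
std::vector<uint8_t> BuildProjectionOptions(bool normalize_spaces) {
  flexbuffers::Builder fbb;
  fbb.Map([&] {
    fbb.Int("feature_size", 4);
    fbb.Bool("normalize_repetition", false);
    fbb.Bool("normalize_spaces", normalize_spaces);  // new attribute
  });
  fbb.Finish();
  return fbb.GetBuffer();
}
```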
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
-#define TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
+#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
+#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
 #include "tensorflow/lite/kernels/register.h"
 namespace seq_flow_lite {
@@ -27,8 +27,9 @@ TfLiteRegistration* Register_SEQUENCE_STRING_PROJECTION();
 extern const char kSequenceStringProjectionV2[];
 TfLiteRegistration* Register_SEQUENCE_STRING_PROJECTION_V2();
+}  // namespace custom
 }  // namespace ops
 }  // namespace seq_flow_lite
-#endif  // TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
+#endif  // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
@@ -39,6 +39,7 @@ using ::seq_flow_lite::testing::OpEquivTestCase;
 using ::seq_flow_lite::testing::StringTensor;
 using ::seq_flow_lite::testing::TensorflowTfLiteOpTest;
 using ::testing::ElementsAreArray;
+using ::testing::Not;
 using ::tflite::TensorType_FLOAT32;
 using ::tflite::TensorType_STRING;
 using ::tflite::TensorType_UINT8;
@@ -50,7 +51,8 @@ class SequenceStringProjectionModel : public ::tflite::SingleOpModel {
       int doc_size_levels, bool add_eos_tag, ::tflite::TensorType output_type,
       const std::string& token_separators = "",
       bool normalize_repetition = false, float add_first_cap = 0.0,
-      float add_all_caps = 0.0, const std::string& hashtype = kMurmurHash) {
+      float add_all_caps = 0.0, const std::string& hashtype = kMurmurHash,
+      bool normalize_spaces = false) {
     flexbuffers::Builder fbb;
     fbb.Map([&] {
       fbb.Int("feature_size", 4);
@@ -65,6 +67,7 @@ class SequenceStringProjectionModel : public ::tflite::SingleOpModel {
       fbb.Bool("normalize_repetition", normalize_repetition);
       fbb.Float("add_first_cap_feature", add_first_cap);
       fbb.Float("add_all_caps_feature", add_all_caps);
+      fbb.Bool("normalize_spaces", normalize_spaces);
     });
     fbb.Finish();
     output_ = AddOutput({output_type, {}});
@@ -76,13 +79,13 @@ class SequenceStringProjectionModel : public ::tflite::SingleOpModel {
     PopulateStringTensor(input_, {input});
     CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
         << "Cannot allocate tensors";
-    SingleOpModel::Invoke();
+    CHECK_EQ(SingleOpModel::Invoke(), kTfLiteOk);
   }
   TfLiteStatus InvokeFailable(const std::string& input) {
     PopulateStringTensor(input_, {input});
     CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
        << "Cannot allocate tensors";
-    return SingleOpModel::InvokeUnchecked();
+    return SingleOpModel::Invoke();
   }
   template <typename T>
@@ -335,6 +338,32 @@ TEST(SequenceStringProjectionTest, NormalizeRepetition) {
   EXPECT_THAT(output1, ElementsAreArray(output2));
 }
+
+TEST(SequenceStringProjectionTest, NormalizeSpaces) {
+  SequenceStringProjectionModel model_nonormalize(false, -1, 0, 0, false,
+                                                  TensorType_UINT8, "", false,
+                                                  0.0, 0.0, kMurmurHash, false);
+  SequenceStringProjectionModel model_normalize(false, -1, 0, 0, false,
+                                                TensorType_UINT8, "", false,
+                                                0.0, 0.0, kMurmurHash, true);
+
+  const char kNoExtraSpaces[] = "Hello there.";
+  const char kExtraSpaces[] = " Hello there. ";
+
+  model_nonormalize.Invoke(kNoExtraSpaces);
+  auto output_noextra_nonorm = model_nonormalize.GetOutput<uint8_t>();
+  model_nonormalize.Invoke(kExtraSpaces);
+  auto output_extra_nonorm = model_nonormalize.GetOutput<uint8_t>();
+  model_normalize.Invoke(kNoExtraSpaces);
+  auto output_noextra_norm = model_normalize.GetOutput<uint8_t>();
+  model_normalize.Invoke(kExtraSpaces);
+  auto output_extra_norm = model_normalize.GetOutput<uint8_t>();
+
+  EXPECT_THAT(output_noextra_nonorm, ElementsAreArray(output_noextra_norm));
+  EXPECT_THAT(output_noextra_nonorm, ElementsAreArray(output_extra_norm));
+  EXPECT_THAT(output_noextra_nonorm,
+              Not(ElementsAreArray(output_extra_nonorm)));
+}
+
 class SequenceStringProjectionTest : public TensorflowTfLiteOpTest {
   std::function<TfLiteRegistration*()> TfLiteOpRegistration() override {
     return ops::custom::Register_SEQUENCE_STRING_PROJECTION;
@@ -710,6 +739,7 @@ std::vector<OpEquivTestCase> SequenceStringProjectionTestCases() {
     test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
     test_cases.push_back(test_case);
   }
+
   {
     OpEquivTestCase test_case;
     test_case.test_name = "NormalizeRepetition";
@@ -794,6 +824,20 @@ std::vector<OpEquivTestCase> SequenceStringProjectionTestCases() {
     test_cases.push_back(test_case);
   }
+
+  {
+    OpEquivTestCase test_case;
+    test_case.test_name = "NormalizeSpaces";
+    test_case.attributes["vocabulary"] = AttrValue("");
+    test_case.attributes["split_on_space"] = AttrValue(true);
+    test_case.attributes["feature_size"] = AttrValue(8);
+    test_case.attributes["add_eos_tag"] = AttrValue(false);
+    test_case.attributes["add_bos_tag"] = AttrValue(false);
+    test_case.attributes["normalize_spaces"] = AttrValue(true);
+    test_case.input_tensors.push_back(StringTensor({1}, {" Hello there. "}));
+    test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
+    test_cases.push_back(test_case);
+  }
   return test_cases;
 }
@@ -822,13 +866,13 @@ class SequenceStringProjectionV2Model : public ::tflite::SingleOpModel {
     PopulateStringTensor(input_, input);
     CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
         << "Cannot allocate tensors";
-    ASSERT_EQ(SingleOpModel::InvokeUnchecked(), expected);
+    ASSERT_EQ(SingleOpModel::Invoke(), expected);
   }
   TfLiteStatus InvokeFailable(const std::string& input) {
     PopulateStringTensor(input_, {input});
     CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
        << "Cannot allocate tensors";
-    return SingleOpModel::InvokeUnchecked();
+    return SingleOpModel::Invoke();
   }
   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
...
@@ -309,7 +309,7 @@ void TensorflowTfLiteOpTest::RunTfLiteOp() {
     input_index++;
   }
-  tflite_op_.Invoke();
+  ASSERT_EQ(tflite_op_.Invoke(), kTfLiteOk);
 }
 void TensorflowTfLiteOpTest::CompareOpOutput() {
...
@@ -14,8 +14,8 @@ limitations under the License.
 ==============================================================================*/
 // Tests equivalence between TF and TFLite versions of an op.
-#ifndef TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
-#define TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
+#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
+#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
 #include <string>
 #include <vector>
@@ -146,4 +146,4 @@ class TensorflowTfLiteOpTest
 }  // namespace testing
 }  // namespace seq_flow_lite
-#endif  // TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
+#endif  // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
@@ -13,12 +13,13 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
-#define THIRD_PARTY_TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
+#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
+#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
 #include <memory>
-#include "third_party/tensorflow/lite/c/common.h"
+#include "tensorflow/lite/c/common.h"
 namespace seq_flow_lite {
 namespace ops {
 namespace custom {
@@ -113,4 +114,4 @@ class DynamicCacheOp {
 }  // namespace custom
 }  // namespace ops
 }  // namespace seq_flow_lite
-#endif  // THIRD_PARTY_TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
+#endif  // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
@@ -13,15 +13,16 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/tflite_decoder_handler.h"
+#include "tflite_ops/tflite_decoder_handler.h"  // seq_flow_lite
 #include <cstdint>
-#include "third_party/flatbuffers/include/flatbuffers/flexbuffers.h"
-#include "third_party/tensorflow/lite/kernels/internal/tensor_ctypes.h"
-#include "third_party/tensorflow/lite/kernels/kernel_util.h"
-#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/tflite_decoder_cache.h"
-#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h"
+#include "flatbuffers/flexbuffers.h"  // flatbuffer
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tflite_ops/quantization_util.h"  // seq_flow_lite
+#include "tflite_ops/tflite_decoder_cache.h"  // seq_flow_lite
 namespace seq_flow_lite {
 namespace ops {
...
@@ -13,18 +13,19 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
-#define THIRD_PARTY_TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
+#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
+#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
-#include "third_party/tensorflow/lite/kernels/register.h"
+#include "tensorflow/lite/kernels/register.h"
 namespace seq_flow_lite {
 namespace ops {
 namespace custom {
 TfLiteRegistration* Register_UNIFORM_CAUSAL_ATTENTION();
-}
+}  // namespace custom
 }  // namespace ops
 }  // namespace seq_flow_lite
-#endif  // THIRD_PARTY_TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
+#endif  // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
@@ -13,17 +13,17 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/tflite_decoder_handler.h"
+#include "tflite_ops/tflite_decoder_handler.h"  // seq_flow_lite
 #include <cstdint>
 #include <cstdlib>
 #include <vector>
-#include "testing/base/public/gmock.h"
-#include "testing/base/public/gunit.h"
-#include "third_party/flatbuffers/include/flatbuffers/flexbuffers.h"
-#include "third_party/tensorflow/lite/c/common.h"
-#include "third_party/tensorflow/lite/kernels/test_util.h"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "flatbuffers/flexbuffers.h"  // flatbuffer
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/test_util.h"
 namespace {
...
@@ -12,11 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/tflite_qrnn_pooling.h"
-#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h"
+#include "tflite_ops/quantization_util.h"  // seq_flow_lite
+#include "tflite_ops/tflite_qrnn_pooling.h"  // seq_flow_lite
 namespace seq_flow_lite {
+namespace ops {
+namespace custom {
 namespace {
@@ -126,9 +127,9 @@ TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node) {
   return QRNNPooling(context, multiplier, constant, outputs, final_state,
                      (direction->data.uint8[0] == kPoolingForward));
 }
 }  // namespace
-namespace custom {
 const char kPoolingOp[] = "PoolingOp";
 void RegisterQRNNPooling(::tflite::ops::builtin::BuiltinOpResolver* resolver) {
@@ -141,4 +142,5 @@ TfLiteRegistration* Register_QRNN_POOLING() {
 }
 }  // namespace custom
+}  // namespace ops
 }  // namespace seq_flow_lite
@@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
-#define TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
+#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
+#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
-#include "third_party/absl/base/macros.h"
-#include "third_party/tensorflow/lite/kernels/register.h"
+#include "absl/base/macros.h"
+#include "tensorflow/lite/kernels/register.h"
 namespace seq_flow_lite {
+namespace ops {
 namespace custom {
 extern const char kPoolingOp[];
@@ -27,7 +27,7 @@ extern const char kPoolingOp[];
 TfLiteRegistration* Register_QRNN_POOLING();
 }  // namespace custom
+}  // namespace ops
 }  // namespace seq_flow_lite
-#endif  // TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
+#endif  // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# Lint as: python3
 """A utility for PRADO model to do train, eval, inference and model export."""
 import importlib
@@ -22,6 +21,7 @@ from absl import app
 from absl import flags
 from absl import logging
 import tensorflow.compat.v1 as tf
+from tensorflow.compat.v1 import estimator as tf_estimator
 import input_fn_reader  # import root module
 import metric_functions  # import root module
@@ -48,14 +48,14 @@ def load_runner_config():
     return json.loads(f.read())
-def create_model(model, model_config, features, mode):
+def create_model(model, model_config, features, mode, model_name):
   """Creates a sequence labeling model."""
   keras_model = model.Encoder(model_config, mode)
   if "pqrnn" in model_name:
     logits = keras_model(features["projection"], features["seq_length"])
   else:
     logits = keras_model(features["token_ids"], features["token_len"])
-  if mode != tf.estimator.ModeKeys.PREDICT:
+  if mode != tf_estimator.ModeKeys.PREDICT:
     if not model_config["multilabel"]:
       loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
           labels=features["label"], logits=logits)
@@ -94,33 +94,33 @@ def model_fn_builder(runner_config):
   def model_fn(features, mode, params):
     """The `model_fn` for TPUEstimator."""
     label_ids = None
-    if mode != tf.estimator.ModeKeys.PREDICT:
+    if mode != tf_estimator.ModeKeys.PREDICT:
       label_ids = features["label"]
     model_config = runner_config["model_config"]
-    loss, logits = create_model(model, model_config, features, mode)
+    loss, logits = create_model(model, model_config, features, mode,
+                                runner_config["name"])
-    if mode == tf.estimator.ModeKeys.TRAIN:
+    if mode == tf_estimator.ModeKeys.TRAIN:
       train_op = create_optimizer(loss, runner_config, params)
-      return tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
+      return tf_estimator.tpu.TPUEstimatorSpec(
          mode=mode, loss=loss, train_op=train_op)
-    elif mode == tf.estimator.ModeKeys.EVAL:
+    elif mode == tf_estimator.ModeKeys.EVAL:
      if not runner_config["model_config"]["multilabel"]:
        metric_fn = metric_functions.classification_metric
      else:
        metric_fn = metric_functions.labeling_metric
      eval_metrics = (metric_fn, [loss, label_ids, logits])
-      return tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
+      return tf_estimator.tpu.TPUEstimatorSpec(
          mode=mode, loss=loss, eval_metrics=eval_metrics)
-    elif mode == tf.estimator.ModeKeys.PREDICT:
+    elif mode == tf_estimator.ModeKeys.PREDICT:
      predictions = {"logits": logits}
      if not runner_config["model_config"]["multilabel"]:
        predictions["predictions"] = tf.nn.softmax(logits)
      else:
        predictions["predictions"] = tf.math.sigmoid(logits)
-      return tf.compat.v1.estimator.EstimatorSpec(
-          mode=mode, predictions=predictions)
+      return tf_estimator.EstimatorSpec(mode=mode, predictions=predictions)
    else:
      assert False, "Expected to be called in TRAIN, EVAL, or PREDICT mode."
@@ -133,13 +133,13 @@ def main(_):
   if FLAGS.output_dir:
     tf.gfile.MakeDirs(FLAGS.output_dir)
-  is_per_host = tf.estimator.tpu.InputPipelineConfig.PER_HOST_V2
-  run_config = tf.estimator.tpu.RunConfig(
+  is_per_host = tf_estimator.tpu.InputPipelineConfig.PER_HOST_V2
+  run_config = tf_estimator.tpu.RunConfig(
      master=FLAGS.master,
      model_dir=FLAGS.output_dir,
      save_checkpoints_steps=runner_config["save_checkpoints_steps"],
      keep_checkpoint_max=20,
-      tpu_config=tf.estimator.tpu.TPUConfig(
+      tpu_config=tf_estimator.tpu.TPUConfig(
          iterations_per_loop=runner_config["iterations_per_loop"],
          num_shards=FLAGS.num_tpu_cores,
          per_host_input_for_training=is_per_host))
@@ -149,7 +149,7 @@ def main(_):
   # If TPU is not available, this will fall back to normal Estimator on CPU
   # or GPU.
   batch_size = runner_config["batch_size"]
-  estimator = tf.estimator.tpu.TPUEstimator(
+  estimator = tf_estimator.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
@@ -160,7 +160,7 @@ def main(_):
   if FLAGS.runner_mode == "train":
     train_input_fn = input_fn_reader.create_input_fn(
        runner_config=runner_config,
-        mode=tf.estimator.ModeKeys.TRAIN,
+        mode=tf_estimator.ModeKeys.TRAIN,
        drop_remainder=True)
    estimator.train(
        input_fn=train_input_fn, max_steps=runner_config["train_steps"])
@@ -168,7 +168,7 @@ def main(_):
   # TPU needs fixed shapes, so if the last batch is smaller, we drop it.
   eval_input_fn = input_fn_reader.create_input_fn(
      runner_config=runner_config,
-      mode=tf.estimator.ModeKeys.EVAL,
+      mode=tf_estimator.ModeKeys.EVAL,
      drop_remainder=True)
   for _ in tf.train.checkpoints_iterator(FLAGS.output_dir, timeout=600):
...
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# Lint as: python3
 """Binary to train PRADO model with TF 2.0."""
 import importlib
@@ -23,6 +22,7 @@ from absl import flags
 from absl import logging
 import tensorflow as tf
+from tensorflow import estimator as tf_estimator
 import input_fn_reader  # import root module
@@ -48,7 +48,7 @@ def load_runner_config():
 def compute_loss(logits, labels, model_config, mode):
   """Creates a sequence labeling model."""
-  if mode != tf.estimator.ModeKeys.PREDICT:
+  if mode != tf_estimator.ModeKeys.PREDICT:
    if not model_config["multilabel"]:
      loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
          labels=labels, logits=logits)
@@ -77,11 +77,11 @@ def main(_):
   if FLAGS.output_dir:
     tf.io.gfile.makedirs(FLAGS.output_dir)
-  train_model = model_fn_builder(runner_config, tf.estimator.ModeKeys.TRAIN)
+  train_model = model_fn_builder(runner_config, tf_estimator.ModeKeys.TRAIN)
   optimizer = tf.keras.optimizers.Adam()
   train_input_fn = input_fn_reader.create_input_fn(
      runner_config=runner_config,
-      mode=tf.estimator.ModeKeys.TRAIN,
+      mode=tf_estimator.ModeKeys.TRAIN,
      drop_remainder=True)
   params = {"batch_size": runner_config["batch_size"]}
   train_ds = train_input_fn(params)
@@ -93,7 +93,7 @@ def main(_):
     logits = train_model(features["projection"], features["seq_length"])
     loss = compute_loss(logits, features["label"],
                         runner_config["model_config"],
-                        tf.estimator.ModeKeys.TRAIN)
+                        tf_estimator.ModeKeys.TRAIN)
     gradients = tape.gradient(loss, train_model.trainable_variables)
     optimizer.apply_gradients(zip(gradients, train_model.trainable_variables))
     train_loss(loss)
...