Unverified commit 20cc2190, authored by pyoung2778, committed by GitHub

Check in seq_flow_lite (#10750)

parent fdecf385
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/beam_search.h"
+#include "tflite_ops/beam_search.h" // seq_flow_lite
 #include <algorithm>
 #include <cstdint>
@@ -21,10 +21,10 @@ limitations under the License.
 #include <vector>
 #include "base/logging.h"
-#include "third_party/absl/strings/str_join.h"
-#include "third_party/tensorflow/lite/kernels/internal/tensor_ctypes.h"
-#include "third_party/tensorflow/lite/kernels/internal/types.h"
-#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h"
+#include "absl/strings/str_join.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+#include "tflite_ops/quantization_util.h" // seq_flow_lite
 namespace seq_flow_lite {
 namespace ops {
@@ -86,6 +86,7 @@ void SequenceTracker::AddSequence(const int32_t *begin, const int32_t *end,
 std::vector<std::vector<int32_t>> SequenceTracker::GetTopBeams() {
   std::vector<std::vector<int32_t>> return_value;
+  return_value.reserve(terminated_topk_.size());
   for (const auto &v : terminated_topk_) {
     return_value.push_back(v.second);
   }
@@ -255,8 +256,8 @@ void BeamSearch::FindTopKQuantizedFromLogitsV1(const TfLiteTensor &tensor,
     }
   }
   // Updating topk across all beams.
-  for (uint32_t k = 0; k < std::min(topk_k, num_classes_); ++k) {
-    const uint32_t curr_beam_index = curr_beam_topk[k] & kClassIndexMask;
+  for (uint32_t curr_beam : curr_beam_topk) {
+    const uint32_t curr_beam_index = curr_beam & kClassIndexMask;
     const uint32_t index = j * num_classes_ + curr_beam_index;
     const float log_prob =
         tensor.params.scale * beam_logits[curr_beam_index] - precomputed;
......
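The rewritten loop above iterates the candidate set directly and recovers the class id with a bitmask, which relies on each curr_beam_topk entry packing a beam id and a class id into one uint32_t. A minimal sketch of that pack/unpack idiom, assuming a hypothetical 24-bit class field (the real kClassIndexMask and field widths are defined in beam_search.cc and may differ):

    #include <cstdint>

    // Assumed layout: low 24 bits hold the class index, high 8 bits the beam.
    constexpr uint32_t kClassIndexMask = (1u << 24) - 1;

    inline uint32_t PackBeamClass(uint32_t beam, uint32_t cls) {
      return (beam << 24) | (cls & kClassIndexMask);
    }

    inline uint32_t UnpackClass(uint32_t packed) {
      return packed & kClassIndexMask;  // same expression as the loop body above
    }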
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
-#define THIRD_PARTY_TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
+#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
+#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
 #include <cstdint>
 #include <functional>
@@ -23,7 +23,7 @@ limitations under the License.
 #include <set>
 #include <vector>
-#include "third_party/tensorflow/lite/c/common.h"
+#include "tensorflow/lite/c/common.h"
 namespace seq_flow_lite {
 namespace ops {
@@ -110,4 +110,4 @@ class BeamSearch {
 } // namespace custom
 } // namespace ops
 } // namespace seq_flow_lite
-#endif // THIRD_PARTY_TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
+#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_BEAM_SEARCH_H_
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/beam_search.h"
+#include "tflite_ops/beam_search.h" // seq_flow_lite
 #include <cstdint>
 #include <functional>
@@ -21,17 +21,17 @@ limitations under the License.
 #include <memory>
 #include <vector>
-#include "testing/base/public/gmock.h"
-#include "testing/base/public/gunit.h"
-#include "third_party/absl/strings/str_join.h"
-#include "third_party/tensorflow/lite/c/c_api_types.h"
-#include "third_party/tensorflow/lite/c/common.h"
-#include "third_party/tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
-#include "third_party/tensorflow/lite/kernels/internal/reference/dequantize.h"
-#include "third_party/tensorflow/lite/kernels/internal/reference/reference_ops.h"
-#include "third_party/tensorflow/lite/kernels/internal/tensor_ctypes.h"
-#include "third_party/tensorflow/lite/kernels/internal/types.h"
-#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/strings/str_join.h"
+#include "tensorflow/lite/c/c_api_types.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/lite/kernels/internal/reference/dequantize.h"
+#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+#include "tflite_ops/quantization_util.h" // seq_flow_lite
 namespace seq_flow_lite {
 namespace ops {
@@ -76,7 +76,7 @@ class BeamSearchImpl : public BeamSearch {
           cur_cache + (selected_beams[beam] * NumClasses());
       for (int j = 0; j < NumClasses(); ++j, index++) {
         next_cache[index] = (selected[j] + next_cache[index]) / 2;
-        data_ptr[index] = ::seq_flow_lite::PodQuantize(
+        data_ptr[index] = PodQuantize(
             next_cache[index], decoder_output_->params.zero_point,
             1.0f / decoder_output_->params.scale);
       }
......
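The test above writes the averaged float cache back into the quantized decoder output through PodQuantize(value, zero_point, inverse_scale). A minimal sketch of what such a uint8 affine-quantization helper typically computes (the real definition lives in tflite_ops/quantization_util.h; the exact rounding and clamping here are assumptions):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    inline uint8_t PodQuantizeSketch(float value, int32_t zero_point,
                                     float inverse_scale) {
      // Scale onto the quantized grid, shift by the zero point, then clamp
      // to the representable uint8 range.
      const int32_t q =
          static_cast<int32_t>(std::round(value * inverse_scale)) + zero_point;
      return static_cast<uint8_t>(std::min(std::max(q, 0), 255));
    }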
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_EXPECTED_VALUE_H_
-#define TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_EXPECTED_VALUE_H_
+#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_EXPECTED_VALUE_H_
+#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_EXPECTED_VALUE_H_
 #include "tensorflow/lite/kernels/register.h"
@@ -27,4 +27,4 @@ TfLiteRegistration* Register_EXPECTED_VALUE();
 } // namespace ops
 } // namespace seq_flow_lite
-#endif // TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_EXPECTED_VALUE_H_
+#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_EXPECTED_VALUE_H_
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef LEARNING_EXPANDER_POD_DEEP_POD_TFLITE_HANDLERS_LAYER_NORM_H_
-#define LEARNING_EXPANDER_POD_DEEP_POD_TFLITE_HANDLERS_LAYER_NORM_H_
+#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_LAYER_NORM_H_
+#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_LAYER_NORM_H_
 #include "tensorflow/lite/kernels/register.h"
@@ -27,4 +27,4 @@ TfLiteRegistration* Register_LAYER_NORM();
 } // namespace ops
 } // namespace seq_flow_lite
-#endif // LEARNING_EXPANDER_POD_DEEP_POD_TFLITE_HANDLERS_LAYER_NORM_H_
+#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_LAYER_NORM_H_
@@ -87,7 +87,7 @@ TEST(LayerNormModelTest, RegularInput) {
                        /*input_max=*/10, /*output_min=*/-10, /*output_max=*/10,
                        /*scale=*/1.0, /*offset=*/0.0, /*axes=*/{2});
   m.SetInput(input);
-  m.Invoke();
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
   EXPECT_THAT(
       m.GetDequantizedOutput(),
       ElementsAreArray(ArrayFloatNear(expected_output, kQuantizedTolerance)));
@@ -106,7 +106,7 @@ TEST(LayerNormModelTest, NegativeScale) {
                        /*input_max=*/10, /*output_min=*/-10, /*output_max=*/10,
                        /*scale=*/-1.0, /*offset=*/0.0, /*axes=*/{2});
   m.SetInput(input);
-  m.Invoke();
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
   EXPECT_THAT(
       m.GetDequantizedOutput(),
       ElementsAreArray(ArrayFloatNear(expected_output, kQuantizedTolerance)));
@@ -125,7 +125,7 @@ TEST(LayerNormModelTest, NegativeOffset) {
                        /*input_max=*/10, /*output_min=*/-10, /*output_max=*/10,
                        /*scale=*/1.0, /*offset=*/-1.0, /*axes=*/{2});
   m.SetInput(input);
-  m.Invoke();
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
   EXPECT_THAT(
       m.GetDequantizedOutput(),
       ElementsAreArray(ArrayFloatNear(expected_output, kQuantizedTolerance)));
@@ -144,7 +144,7 @@ TEST(LayerNormModelTest, NegativeScaleAndOffset) {
                        /*input_max=*/10, /*output_min=*/-10, /*output_max=*/10,
                        /*scale=*/-1.0, /*offset=*/-1.0, /*axes=*/{2});
   m.SetInput(input);
-  m.Invoke();
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
   EXPECT_THAT(
       m.GetDequantizedOutput(),
       ElementsAreArray(ArrayFloatNear(expected_output, kQuantizedTolerance)));
@@ -163,7 +163,7 @@ TEST(LayerNormModelTest, MultipleAxis) {
                        /*input_max=*/3, /*output_min=*/-3, /*output_max=*/3,
                        /*scale=*/1.0, /*offset=*/0.0, /*axes=*/{1, 3});
   m.SetInput(input);
-  m.Invoke();
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
   EXPECT_THAT(
       m.GetDequantizedOutput(),
       ElementsAreArray(ArrayFloatNear(expected_output, kQuantizedTolerance)));
@@ -182,7 +182,7 @@ TEST(LayerNormModelTest, MultipleNegativeAxis) {
                        /*input_max=*/3, /*output_min=*/-3, /*output_max=*/3,
                        /*scale=*/1.0, /*offset=*/0.0, /*axes=*/{-3, -1});
   m.SetInput(input);
-  m.Invoke();
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
   EXPECT_THAT(
       m.GetDequantizedOutput(),
       ElementsAreArray(ArrayFloatNear(expected_output, kQuantizedTolerance)));
@@ -204,7 +204,7 @@ TEST(LayerNormModelTest, MultipleAxisWithLargeDepth) {
                        /*input_max=*/1.0, /*output_min=*/-3.0, /*output_max=*/3.0,
                        /*scale=*/1.0, /*offset=*/0.0, /*axes=*/{1, 3});
   m.SetInput(input);
-  m.Invoke();
+  ASSERT_EQ(m.Invoke(), kTfLiteOk);
   EXPECT_THAT(
       m.GetDequantizedOutput(),
       ElementsAreArray(ArrayFloatNear(expected_output, kQuantizedTolerance)));
......
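Every layer-norm test above trades a bare m.Invoke() for ASSERT_EQ(m.Invoke(), kTfLiteOk), matching the TFLite test API in which SingleOpModel::Invoke() returns a TfLiteStatus instead of CHECK-failing internally. A short sketch of the recurring pattern (model setup elided; names mirror the tests above):

    m.SetInput(input);
    // ASSERT_ aborts this test body on a bad status, so the EXPECT_THAT below
    // never compares against an output tensor that was never written.
    ASSERT_EQ(m.Invoke(), kTfLiteOk);
    EXPECT_THAT(
        m.GetDequantizedOutput(),
        ElementsAreArray(ArrayFloatNear(expected_output, kQuantizedTolerance)));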
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_QUANTIZATION_UTIL_H_
-#define TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_QUANTIZATION_UTIL_H_
+#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_QUANTIZATION_UTIL_H_
+#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_QUANTIZATION_UTIL_H_
 #include <algorithm>
 #include <cmath>
@@ -50,4 +50,4 @@ inline uint8_t PodQuantize(float value, int32_t zero_point,
 } // namespace seq_flow_lite
-#endif // TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_QUANTIZATION_UTIL_H_
+#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_QUANTIZATION_UTIL_H_
@@ -101,7 +101,7 @@ class ProjectionParams {
                    bool exclude_nonalphaspace_unicodes,
                    const std::string& token_separators,
                    bool normalize_repetition, bool add_first_cap_feature,
-                   bool add_all_caps_feature)
+                   bool add_all_caps_feature, bool normalize_spaces)
       : feature_size_(feature_size),
         unicode_handler_(vocabulary, exclude_nonalphaspace_unicodes),
         hasher_(Hasher::CreateHasher(feature_size, hashtype)),
@@ -130,9 +130,9 @@
     }
     word_novelty_offset_ = 2.0f / (1 << word_novelty_bits_);
-    if (!token_separators.empty() || normalize_repetition) {
+    if (!token_separators.empty() || normalize_repetition || normalize_spaces) {
       projection_normalizer_ = std::make_unique<ProjectionNormalizer>(
-          token_separators, normalize_repetition);
+          token_separators, normalize_repetition, normalize_spaces);
     }
   }
   virtual ~ProjectionParams() {}
@@ -242,7 +242,8 @@ class ProjectionParamsV2 : public ProjectionParams {
                          /*exclude_nonalphaspace_unicodes = */ false,
                          /*token_separators = */ "", normalize_repetition,
                          /*add_first_cap_feature = */ false,
-                         /*add_all_caps_feature = */ false) {}
+                         /*add_all_caps_feature = */ false,
+                         /*normalize_spaces = */ false) {}
   ~ProjectionParamsV2() override {}
   TfLiteStatus PreprocessInput(TfLiteTensor* input_t,
@@ -341,6 +342,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   const std::string token_separators =
       m["token_separators"].IsNull() ? "" : m["token_separators"].ToString();
   const bool normalize_repetition = m["normalize_repetition"].AsBool();
+  const bool normalize_spaces = m["normalize_spaces"].AsBool();
   if (!Hasher::SupportedHashType(hashtype)) {
     context->ReportError(context, "Unsupported hashtype %s\n",
                          hashtype.c_str());
@@ -354,7 +356,8 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
       add_bos_tag ? BosTag::kGenerate : BosTag::kNone,
       add_eos_tag ? EosTag::kGenerate : EosTag::kNone,
       exclude_nonalphaspace_unicodes, token_separators, normalize_repetition,
-      add_first_cap_feature == 1.0f, add_all_caps_feature == 1.0f);
+      add_first_cap_feature == 1.0f, add_all_caps_feature == 1.0f,
+      normalize_spaces);
 }
 void* InitV2(TfLiteContext* context, const char* buffer, size_t length) {
......
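Init() reads the new attribute out of the op's flexbuffer map via m["normalize_spaces"].AsBool(). For reference, a minimal sketch of building custom-op options that carry this flag on the authoring side, mirroring the flexbuffers::Builder usage in the tests below (only keys from this diff are shown; the op's other required attributes are elided):

    #include <cstdint>
    #include <vector>
    #include "flatbuffers/flexbuffers.h"

    std::vector<uint8_t> BuildProjectionOptions() {
      flexbuffers::Builder fbb;
      fbb.Map([&] {
        // Attributes consumed by Init() above; remaining keys omitted here.
        fbb.Bool("normalize_repetition", true);
        fbb.Bool("normalize_spaces", true);  // the attribute added by this change
      });
      fbb.Finish();
      return fbb.GetBuffer();
    }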
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
-#define TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
+#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
+#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
 #include "tensorflow/lite/kernels/register.h"
 namespace seq_flow_lite {
@@ -27,8 +27,9 @@ TfLiteRegistration* Register_SEQUENCE_STRING_PROJECTION();
 extern const char kSequenceStringProjectionV2[];
 TfLiteRegistration* Register_SEQUENCE_STRING_PROJECTION_V2();
+} // namespace custom
 } // namespace ops
 } // namespace seq_flow_lite
-#endif // TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
+#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_SEQUENCE_STRING_PROJECTION_H_
@@ -39,6 +39,7 @@ using ::seq_flow_lite::testing::OpEquivTestCase;
 using ::seq_flow_lite::testing::StringTensor;
 using ::seq_flow_lite::testing::TensorflowTfLiteOpTest;
 using ::testing::ElementsAreArray;
+using ::testing::Not;
 using ::tflite::TensorType_FLOAT32;
 using ::tflite::TensorType_STRING;
 using ::tflite::TensorType_UINT8;
@@ -50,7 +51,8 @@ class SequenceStringProjectionModel : public ::tflite::SingleOpModel {
       int doc_size_levels, bool add_eos_tag, ::tflite::TensorType output_type,
       const std::string& token_separators = "",
       bool normalize_repetition = false, float add_first_cap = 0.0,
-      float add_all_caps = 0.0, const std::string& hashtype = kMurmurHash) {
+      float add_all_caps = 0.0, const std::string& hashtype = kMurmurHash,
+      bool normalize_spaces = false) {
     flexbuffers::Builder fbb;
     fbb.Map([&] {
       fbb.Int("feature_size", 4);
@@ -65,6 +67,7 @@ class SequenceStringProjectionModel : public ::tflite::SingleOpModel {
       fbb.Bool("normalize_repetition", normalize_repetition);
       fbb.Float("add_first_cap_feature", add_first_cap);
       fbb.Float("add_all_caps_feature", add_all_caps);
+      fbb.Bool("normalize_spaces", normalize_spaces);
     });
     fbb.Finish();
     output_ = AddOutput({output_type, {}});
@@ -76,13 +79,13 @@ class SequenceStringProjectionModel : public ::tflite::SingleOpModel {
     PopulateStringTensor(input_, {input});
     CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
         << "Cannot allocate tensors";
-    SingleOpModel::Invoke();
+    CHECK_EQ(SingleOpModel::Invoke(), kTfLiteOk);
   }
   TfLiteStatus InvokeFailable(const std::string& input) {
     PopulateStringTensor(input_, {input});
     CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
         << "Cannot allocate tensors";
-    return SingleOpModel::InvokeUnchecked();
+    return SingleOpModel::Invoke();
   }
   template <typename T>
@@ -335,6 +338,32 @@ TEST(SequenceStringProjectionTest, NormalizeRepetition) {
   EXPECT_THAT(output1, ElementsAreArray(output2));
 }
+TEST(SequenceStringProjectionTest, NormalizeSpaces) {
+  SequenceStringProjectionModel model_nonormalize(false, -1, 0, 0, false,
+                                                  TensorType_UINT8, "", false,
+                                                  0.0, 0.0, kMurmurHash, false);
+  SequenceStringProjectionModel model_normalize(false, -1, 0, 0, false,
+                                                TensorType_UINT8, "", false,
+                                                0.0, 0.0, kMurmurHash, true);
+  const char kNoExtraSpaces[] = "Hello there.";
+  const char kExtraSpaces[] = " Hello there. ";
+  model_nonormalize.Invoke(kNoExtraSpaces);
+  auto output_noextra_nonorm = model_nonormalize.GetOutput<uint8_t>();
+  model_nonormalize.Invoke(kExtraSpaces);
+  auto output_extra_nonorm = model_nonormalize.GetOutput<uint8_t>();
+  model_normalize.Invoke(kNoExtraSpaces);
+  auto output_noextra_norm = model_normalize.GetOutput<uint8_t>();
+  model_normalize.Invoke(kExtraSpaces);
+  auto output_extra_norm = model_normalize.GetOutput<uint8_t>();
+  EXPECT_THAT(output_noextra_nonorm, ElementsAreArray(output_noextra_norm));
+  EXPECT_THAT(output_noextra_nonorm, ElementsAreArray(output_extra_norm));
+  EXPECT_THAT(output_noextra_nonorm,
+              Not(ElementsAreArray(output_extra_nonorm)));
+}
 class SequenceStringProjectionTest : public TensorflowTfLiteOpTest {
   std::function<TfLiteRegistration*()> TfLiteOpRegistration() override {
     return ops::custom::Register_SEQUENCE_STRING_PROJECTION;
@@ -710,6 +739,7 @@ std::vector<OpEquivTestCase> SequenceStringProjectionTestCases() {
     test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
     test_cases.push_back(test_case);
   }
+
   {
     OpEquivTestCase test_case;
     test_case.test_name = "NormalizeRepetition";
@@ -794,6 +824,20 @@ std::vector<OpEquivTestCase> SequenceStringProjectionTestCases() {
     test_cases.push_back(test_case);
   }
+  {
+    OpEquivTestCase test_case;
+    test_case.test_name = "NormalizeSpaces";
+    test_case.attributes["vocabulary"] = AttrValue("");
+    test_case.attributes["split_on_space"] = AttrValue(true);
+    test_case.attributes["feature_size"] = AttrValue(8);
+    test_case.attributes["add_eos_tag"] = AttrValue(false);
+    test_case.attributes["add_bos_tag"] = AttrValue(false);
+    test_case.attributes["normalize_spaces"] = AttrValue(true);
+    test_case.input_tensors.push_back(StringTensor({1}, {" Hello there. "}));
+    test_case.output_tensors.emplace_back(FloatTensor({}, {}), kScale, kZero);
+    test_cases.push_back(test_case);
+  }
   return test_cases;
 }
@@ -822,13 +866,13 @@ class SequenceStringProjectionV2Model : public ::tflite::SingleOpModel {
     PopulateStringTensor(input_, input);
     CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
         << "Cannot allocate tensors";
-    ASSERT_EQ(SingleOpModel::InvokeUnchecked(), expected);
+    ASSERT_EQ(SingleOpModel::Invoke(), expected);
   }
   TfLiteStatus InvokeFailable(const std::string& input) {
     PopulateStringTensor(input_, {input});
     CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
         << "Cannot allocate tensors";
-    return SingleOpModel::InvokeUnchecked();
+    return SingleOpModel::Invoke();
   }
   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
......
@@ -309,7 +309,7 @@ void TensorflowTfLiteOpTest::RunTfLiteOp() {
     input_index++;
   }
-  tflite_op_.Invoke();
+  ASSERT_EQ(tflite_op_.Invoke(), kTfLiteOk);
 }
 void TensorflowTfLiteOpTest::CompareOpOutput() {
......
@@ -14,8 +14,8 @@ limitations under the License.
 ==============================================================================*/
 // Tests equivalence between TF and TFLite versions of an op.
-#ifndef TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
-#define TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
+#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
+#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
 #include <string>
 #include <vector>
@@ -146,4 +146,4 @@ class TensorflowTfLiteOpTest
 } // namespace testing
 } // namespace seq_flow_lite
-#endif // TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
+#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TF_TFLITE_DIFF_TEST_UTIL_H_
@@ -13,12 +13,13 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
-#define THIRD_PARTY_TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
+#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
+#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
 #include <memory>
-#include "third_party/tensorflow/lite/c/common.h"
+#include "tensorflow/lite/c/common.h"
 namespace seq_flow_lite {
 namespace ops {
 namespace custom {
@@ -113,4 +114,4 @@ class DynamicCacheOp {
 } // namespace custom
 } // namespace ops
 } // namespace seq_flow_lite
-#endif // THIRD_PARTY_TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
+#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_CACHE_H_
@@ -13,15 +13,16 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/tflite_decoder_handler.h"
+#include "tflite_ops/tflite_decoder_handler.h" // seq_flow_lite
 #include <cstdint>
-#include "third_party/flatbuffers/include/flatbuffers/flexbuffers.h"
-#include "third_party/tensorflow/lite/kernels/internal/tensor_ctypes.h"
-#include "third_party/tensorflow/lite/kernels/kernel_util.h"
-#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/tflite_decoder_cache.h"
-#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h"
+#include "flatbuffers/flexbuffers.h" // flatbuffer
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tflite_ops/quantization_util.h" // seq_flow_lite
+#include "tflite_ops/tflite_decoder_cache.h" // seq_flow_lite
 namespace seq_flow_lite {
 namespace ops {
......
@@ -13,18 +13,19 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
-#define THIRD_PARTY_TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
+#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
+#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
-#include "third_party/tensorflow/lite/kernels/register.h"
+#include "tensorflow/lite/kernels/register.h"
 namespace seq_flow_lite {
 namespace ops {
 namespace custom {
 TfLiteRegistration* Register_UNIFORM_CAUSAL_ATTENTION();
-}
+} // namespace custom
 } // namespace ops
 } // namespace seq_flow_lite
-#endif // THIRD_PARTY_TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
+#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_DECODER_HANDLER_H_
@@ -13,17 +13,17 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/tflite_decoder_handler.h"
+#include "tflite_ops/tflite_decoder_handler.h" // seq_flow_lite
 #include <cstdint>
 #include <cstdlib>
 #include <vector>
-#include "testing/base/public/gmock.h"
-#include "testing/base/public/gunit.h"
-#include "third_party/flatbuffers/include/flatbuffers/flexbuffers.h"
-#include "third_party/tensorflow/lite/c/common.h"
-#include "third_party/tensorflow/lite/kernels/test_util.h"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "flatbuffers/flexbuffers.h" // flatbuffer
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/test_util.h"
 namespace {
......
@@ -12,11 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/tflite_qrnn_pooling.h"
-#include "third_party/tensorflow_models/seq_flow_lite/tflite_ops/quantization_util.h"
+#include "tflite_ops/quantization_util.h" // seq_flow_lite
+#include "tflite_ops/tflite_qrnn_pooling.h" // seq_flow_lite
 namespace seq_flow_lite {
 namespace ops {
+namespace custom {
 namespace {
@@ -126,9 +127,9 @@ TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node) {
   return QRNNPooling(context, multiplier, constant, outputs, final_state,
                      (direction->data.uint8[0] == kPoolingForward));
 }
 } // namespace
-namespace custom {
+
 const char kPoolingOp[] = "PoolingOp";
 void RegisterQRNNPooling(::tflite::ops::builtin::BuiltinOpResolver* resolver) {
@@ -141,4 +142,5 @@ TfLiteRegistration* Register_QRNN_POOLING() {
 }
+} // namespace custom
 } // namespace ops
 } // namespace seq_flow_lite
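With namespace custom now opened at the top of the file, kPoolingOp and the registration helpers all live in seq_flow_lite::ops::custom. A sketch of how a client could wire the op into a resolver, assuming the standard MutableOpResolver::AddCustom mechanism that RegisterQRNNPooling presumably wraps (the wiring below is an illustration, not the repo's actual helper body):

    #include "tensorflow/lite/kernels/register.h"
    #include "tflite_ops/tflite_qrnn_pooling.h"  // seq_flow_lite

    void AddSeqFlowLiteOps(::tflite::ops::builtin::BuiltinOpResolver* resolver) {
      // Registers the QRNN pooling kernel under its exported custom-op name.
      resolver->AddCustom(seq_flow_lite::ops::custom::kPoolingOp,
                          seq_flow_lite::ops::custom::Register_QRNN_POOLING());
    }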
@@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
-#define TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
+#ifndef TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
+#define TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
-#include "third_party/absl/base/macros.h"
-#include "third_party/tensorflow/lite/kernels/register.h"
+#include "absl/base/macros.h"
+#include "tensorflow/lite/kernels/register.h"
 namespace seq_flow_lite {
 namespace ops {
 namespace custom {
 extern const char kPoolingOp[];
@@ -27,7 +27,7 @@ extern const char kPoolingOp[];
 TfLiteRegistration* Register_QRNN_POOLING();
 } // namespace custom
 } // namespace ops
 } // namespace seq_flow_lite
-#endif // TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
+#endif // TENSORFLOW_MODELS_SEQ_FLOW_LITE_TFLITE_OPS_TFLITE_QRNN_POOLING_H_
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# Lint as: python3
 """A utility for PRADO model to do train, eval, inference and model export."""
 import importlib
@@ -22,6 +21,7 @@ from absl import app
 from absl import flags
 from absl import logging
 import tensorflow.compat.v1 as tf
+from tensorflow.compat.v1 import estimator as tf_estimator
 import input_fn_reader  # import root module
 import metric_functions  # import root module
@@ -48,14 +48,14 @@ def load_runner_config():
     return json.loads(f.read())
-def create_model(model, model_config, features, mode):
+def create_model(model, model_config, features, mode, model_name):
   """Creates a sequence labeling model."""
   keras_model = model.Encoder(model_config, mode)
-  logits = keras_model(features["projection"], features["seq_length"])
+  if "pqrnn" in model_name:
+    logits = keras_model(features["projection"], features["seq_length"])
+  else:
+    logits = keras_model(features["token_ids"], features["token_len"])
-  if mode != tf.estimator.ModeKeys.PREDICT:
+  if mode != tf_estimator.ModeKeys.PREDICT:
     if not model_config["multilabel"]:
       loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
           labels=features["label"], logits=logits)
@@ -94,33 +94,33 @@ def model_fn_builder(runner_config):
   def model_fn(features, mode, params):
     """The `model_fn` for TPUEstimator."""
     label_ids = None
-    if mode != tf.estimator.ModeKeys.PREDICT:
+    if mode != tf_estimator.ModeKeys.PREDICT:
       label_ids = features["label"]
     model_config = runner_config["model_config"]
-    loss, logits = create_model(model, model_config, features, mode)
+    loss, logits = create_model(model, model_config, features, mode,
+                                runner_config["name"])
-    if mode == tf.estimator.ModeKeys.TRAIN:
+    if mode == tf_estimator.ModeKeys.TRAIN:
       train_op = create_optimizer(loss, runner_config, params)
-      return tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
+      return tf_estimator.tpu.TPUEstimatorSpec(
           mode=mode, loss=loss, train_op=train_op)
-    elif mode == tf.estimator.ModeKeys.EVAL:
+    elif mode == tf_estimator.ModeKeys.EVAL:
       if not runner_config["model_config"]["multilabel"]:
         metric_fn = metric_functions.classification_metric
       else:
         metric_fn = metric_functions.labeling_metric
       eval_metrics = (metric_fn, [loss, label_ids, logits])
-      return tf.compat.v1.estimator.tpu.TPUEstimatorSpec(
+      return tf_estimator.tpu.TPUEstimatorSpec(
          mode=mode, loss=loss, eval_metrics=eval_metrics)
-    elif mode == tf.estimator.ModeKeys.PREDICT:
+    elif mode == tf_estimator.ModeKeys.PREDICT:
       predictions = {"logits": logits}
       if not runner_config["model_config"]["multilabel"]:
         predictions["predictions"] = tf.nn.softmax(logits)
       else:
         predictions["predictions"] = tf.math.sigmoid(logits)
-      return tf.compat.v1.estimator.EstimatorSpec(
-          mode=mode, predictions=predictions)
+      return tf_estimator.EstimatorSpec(mode=mode, predictions=predictions)
     else:
       assert False, "Expected to be called in TRAIN, EVAL, or PREDICT mode."
@@ -133,13 +133,13 @@ def main(_):
   if FLAGS.output_dir:
     tf.gfile.MakeDirs(FLAGS.output_dir)
-  is_per_host = tf.estimator.tpu.InputPipelineConfig.PER_HOST_V2
-  run_config = tf.estimator.tpu.RunConfig(
+  is_per_host = tf_estimator.tpu.InputPipelineConfig.PER_HOST_V2
+  run_config = tf_estimator.tpu.RunConfig(
       master=FLAGS.master,
       model_dir=FLAGS.output_dir,
       save_checkpoints_steps=runner_config["save_checkpoints_steps"],
       keep_checkpoint_max=20,
-      tpu_config=tf.estimator.tpu.TPUConfig(
+      tpu_config=tf_estimator.tpu.TPUConfig(
          iterations_per_loop=runner_config["iterations_per_loop"],
          num_shards=FLAGS.num_tpu_cores,
          per_host_input_for_training=is_per_host))
@@ -149,7 +149,7 @@ def main(_):
   # If TPU is not available, this will fall back to normal Estimator on CPU
   # or GPU.
   batch_size = runner_config["batch_size"]
-  estimator = tf.estimator.tpu.TPUEstimator(
+  estimator = tf_estimator.tpu.TPUEstimator(
       use_tpu=FLAGS.use_tpu,
       model_fn=model_fn,
      config=run_config,
@@ -160,7 +160,7 @@ def main(_):
   if FLAGS.runner_mode == "train":
     train_input_fn = input_fn_reader.create_input_fn(
         runner_config=runner_config,
-        mode=tf.estimator.ModeKeys.TRAIN,
+        mode=tf_estimator.ModeKeys.TRAIN,
        drop_remainder=True)
     estimator.train(
        input_fn=train_input_fn, max_steps=runner_config["train_steps"])
@@ -168,7 +168,7 @@ def main(_):
     # TPU needs fixed shapes, so if the last batch is smaller, we drop it.
     eval_input_fn = input_fn_reader.create_input_fn(
         runner_config=runner_config,
-        mode=tf.estimator.ModeKeys.EVAL,
+        mode=tf_estimator.ModeKeys.EVAL,
        drop_remainder=True)
     for _ in tf.train.checkpoints_iterator(FLAGS.output_dir, timeout=600):
......
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# Lint as: python3
 """Binary to train PRADO model with TF 2.0."""
 import importlib
@@ -23,6 +22,7 @@ from absl import flags
 from absl import logging
 import tensorflow as tf
+from tensorflow import estimator as tf_estimator
 import input_fn_reader  # import root module
@@ -48,7 +48,7 @@ def load_runner_config():
 def compute_loss(logits, labels, model_config, mode):
   """Creates a sequence labeling model."""
-  if mode != tf.estimator.ModeKeys.PREDICT:
+  if mode != tf_estimator.ModeKeys.PREDICT:
     if not model_config["multilabel"]:
       loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
           labels=labels, logits=logits)
@@ -77,11 +77,11 @@ def main(_):
   if FLAGS.output_dir:
     tf.io.gfile.makedirs(FLAGS.output_dir)
-  train_model = model_fn_builder(runner_config, tf.estimator.ModeKeys.TRAIN)
+  train_model = model_fn_builder(runner_config, tf_estimator.ModeKeys.TRAIN)
   optimizer = tf.keras.optimizers.Adam()
   train_input_fn = input_fn_reader.create_input_fn(
       runner_config=runner_config,
-      mode=tf.estimator.ModeKeys.TRAIN,
+      mode=tf_estimator.ModeKeys.TRAIN,
       drop_remainder=True)
   params = {"batch_size": runner_config["batch_size"]}
   train_ds = train_input_fn(params)
@@ -93,7 +93,7 @@ def main(_):
     logits = train_model(features["projection"], features["seq_length"])
     loss = compute_loss(logits, features["label"],
                         runner_config["model_config"],
-                        tf.estimator.ModeKeys.TRAIN)
+                        tf_estimator.ModeKeys.TRAIN)
    gradients = tape.gradient(loss, train_model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, train_model.trainable_variables))
    train_loss(loss)
......