Unverified commit 51f4ecad, authored by prabhukaliamoorthi, committed by GitHub

Add demo app and update op handlers (#9503)

parent 7310b0f8
 licenses(["notice"])

 package(
-    default_visibility = [
-        "//:__subpackages__",
-    ],
+    default_visibility = [":friends"],
 )
+
+package_group(
+    name = "friends",
+    packages = [
+        "//...",
+    ],
+)
...
@@ -35,8 +35,8 @@ models computes them on the fly.
 Train a PRADO model on civil comments dataset
 ```shell
-bazel run -c opt prado:trainer -- \
-  --config_path=$(pwd)/prado/civil_comments_prado.txt \
+bazel run -c opt :trainer -- \
+  --config_path=$(pwd)/configs/civil_comments_prado.txt \
   --runner_mode=train --logtostderr --output_dir=/tmp/prado
 ```
@@ -51,9 +51,9 @@ bazel run -c opt sgnn:train -- --logtostderr --output_dir=/tmp/sgnn
 Evaluate PRADO model:
 ```shell
-bazel run -c opt prado:trainer -- \
-  --config_path=$(pwd)/prado/civil_comments_prado.txt \
-  --runner_mode=eval --output_dir= --logtostderr
+bazel run -c opt :trainer -- \
+  --config_path=$(pwd)/configs/civil_comments_prado.txt \
+  --runner_mode=eval --logtostderr --output_dir=/tmp/prado
 ```
 Evaluate SGNN model:
...
@@ -8,7 +8,8 @@
 "split_on_space": true,
 "embedding_regularizer_scale": 35e-3,
 "embedding_size": 64,
-"heads": [0, 64, 64, 0, 0],
+"bigram_channels": 64,
+"trigram_channels": 64,
 "feature_size": 512,
 "network_regularizer_scale": 1e-4,
 "keep_prob": 0.5,
...
sh_binary(
name = "move_ops",
srcs = ["move_ops.sh"],
data = [
"//tf_ops:sequence_string_projection_op_py",
"//tf_ops:sequence_string_projection_op_v2_py",
"//tf_ops:tf_custom_ops_py",
],
)
#!/bin/bash
# Copyright 2020 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
RUNFILES_DIR=$(pwd)
cp -f "${RUNFILES_DIR}/tf_ops/libsequence_string_projection_op_py_gen_op.so" \
"${BUILD_WORKSPACE_DIRECTORY}/tf_ops"
cp -f "${RUNFILES_DIR}/tf_ops/sequence_string_projection_op.py" \
"${BUILD_WORKSPACE_DIRECTORY}/tf_ops"
cp -f "${RUNFILES_DIR}/tf_ops/libsequence_string_projection_op_v2_py_gen_op.so" \
"${BUILD_WORKSPACE_DIRECTORY}/tf_ops"
cp -f "${RUNFILES_DIR}/tf_ops/sequence_string_projection_op_v2.py" \
"${BUILD_WORKSPACE_DIRECTORY}/tf_ops"
cp -f "${RUNFILES_DIR}/tf_ops/libtf_custom_ops_py_gen_op.so" \
"${BUILD_WORKSPACE_DIRECTORY}/tf_ops"
cp -f "${RUNFILES_DIR}/tf_ops/tf_custom_ops_py.py" \
"${BUILD_WORKSPACE_DIRECTORY}/tf_ops"
# Copyright 2020 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from distutils import spawn
from distutils.command import build
import os
import subprocess
import setuptools
class _BuildCommand(build.build):
  sub_commands = [
      ('bazel_build', lambda self: True),
  ] + build.build.sub_commands


class _BazelBuildCommand(setuptools.Command):

  def initialize_options(self):
    pass

  def finalize_options(self):
    self._bazel_cmd = spawn.find_executable('bazel')

  def run(self):
    subprocess.check_call(
        [self._bazel_cmd, 'run', '-c', 'opt', '//demo/colab:move_ops'],
        cwd=os.path.dirname(os.path.realpath(__file__)))


setuptools.setup(
    name='seq_flow_lite',
    version='0.1',
    packages=['tf_ops'],
    package_data={'': ['*.so']},
    cmdclass={
        'build': _BuildCommand,
        'bazel_build': _BazelBuildCommand,
    },
    description='Test')
#!/bin/bash
# Copyright 2020 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
cd "$(dirname "$0")"
mv setup.py ../..
touch ../../tf_ops/__init__.py
# A demo app for invoking a PRADO TFLite model.
licenses(["notice"])
package(
default_visibility = ["//:friends"], # sequence projection
)
cc_binary(
name = "prado_tflite_example",
srcs = ["prado_tflite_example.cc"],
data = [
"data/tflite.fb",
],
deps = [
"@org_tensorflow//tensorflow/lite:framework",
"@org_tensorflow//tensorflow/lite:string_util",
"//tflite_ops:expected_value", # sequence projection
"//tflite_ops:quantization_util", # sequence projection
"//tflite_ops:sequence_string_projection", # sequence projection
],
)
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstddef>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/string_util.h"
#include "tflite_ops/expected_value.h" // seq_flow_lite
#include "tflite_ops/quantization_util.h" // seq_flow_lite
#include "tflite_ops/sequence_string_projection.h" // seq_flow_lite
namespace {
const int kTextInput = 0;
const int kClassOutput = 0;
const int kNumberOfInputs = 1;
const int kNumberOfOutputs = 1;
const int kClassOutputRank = 2;
const int kClassOutputBatchSizeIndex = 0;
const int kBatchSize = 1;
const int kClassOutputClassIndex = 1;
constexpr char kTfliteDemoFile[] =
"demo/prado/data/tflite.fb";
std::unique_ptr<tflite::Interpreter> CreateInterpreter(
const std::string& tflite_flat_buffer) {
// This pointer points to a memory location contained in tflite_flat_buffer,
// hence it need not be deleted.
const tflite::Model* model = tflite::GetModel(tflite_flat_buffer.data());
std::unique_ptr<tflite::Interpreter> interpreter;
tflite::ops::builtin::BuiltinOpResolver resolver;
resolver.AddCustom(
"SEQUENCE_STRING_PROJECTION",
tflite::ops::custom::Register_SEQUENCE_STRING_PROJECTION());
resolver.AddCustom("ExpectedValueOp",
tflite::ops::custom::Register_EXPECTED_VALUE());
tflite::InterpreterBuilder(model, resolver,
/*error_reporter=*/nullptr)(&interpreter);
if (!interpreter) {
std::cout << "Unable to create tflite interpreter\n";
}
return interpreter;
}
std::vector<float> InvokeModel(
const std::string& text,
std::unique_ptr<tflite::Interpreter>& interpreter) {
std::vector<float> classes;
auto inputs = interpreter->inputs();
if (inputs.size() != kNumberOfInputs) {
std::cerr << "Model does not accept the right number of inputs.";
return classes;
}
// Set input to the model.
TfLiteTensor* input = interpreter->tensor(inputs[kTextInput]);
tflite::DynamicBuffer buf;
buf.AddString(text.data(), text.length());
buf.WriteToTensorAsVector(input);
// Allocate buffers.
interpreter->AllocateTensors();
// Invoke inference on the model.
interpreter->Invoke();
// Extract outputs and perform sanity checks on them.
auto outputs = interpreter->outputs();
if (outputs.size() != kNumberOfOutputs) {
std::cerr << "Model does not produce right number of outputs.";
return classes;
}
TfLiteTensor* class_output = interpreter->tensor(outputs[kClassOutput]);
if (class_output->type != kTfLiteUInt8) {
std::cerr << "Tensor output types are not as expected.";
return classes;
}
if (class_output->dims->size != kClassOutputRank) {
std::cerr << "Tensor output should be rank " << kClassOutputRank;
return classes;
}
const auto output_dims = class_output->dims->data;
if (output_dims[kClassOutputBatchSizeIndex] != kBatchSize) {
std::cerr << "Batch size is expected to be " << kBatchSize;
return classes;
}
// Extract output from the output tensor and populate results.
const size_t num_classes = output_dims[kClassOutputClassIndex];
for (size_t i = 0; i < num_classes; ++i) {
// Find class probability or log probability for the class index
classes.push_back(tflite::PodDequantize(*class_output, i));
}
return classes;
}
std::string GetTfliteDemoFile() {
std::string tflite_flat_buffer;
std::ifstream file(kTfliteDemoFile,
std::ios::in | std::ios::binary | std::ios::ate);
if (!file.is_open()) {
std::cerr << "Unable to open demo tflite file.\n";
return tflite_flat_buffer;
}
size_t size = file.tellg();
file.seekg(0, file.beg);
tflite_flat_buffer.resize(size);
file.read(const_cast<char*>(tflite_flat_buffer.data()), size);
file.close();
return tflite_flat_buffer;
}
} // namespace
int main(int argc, char** argv) {
// The flatbuffer must remain valid until the interpreter is destroyed.
std::string tflite_flat_buffer = GetTfliteDemoFile();
if (tflite_flat_buffer.empty()) {
return EXIT_FAILURE;
}
auto interpreter = CreateInterpreter(tflite_flat_buffer);
if (!interpreter) {
return EXIT_FAILURE;
}
while (true) {
std::string sentence;
std::cout << "Enter input: ";
std::getline(std::cin, sentence);
std::vector<float> classes = InvokeModel(sentence, interpreter);
for (float class_value : classes) {
std::cout << class_value << std::endl;
}
}
return EXIT_SUCCESS;
}
@@ -72,7 +72,7 @@ py_strict_library(
         # package tensorflow
         # "//tf_ops:sequence_string_projection_op"  # sequence projection
         "//tf_ops:sequence_string_projection_op_py",  # sequence projection
-        "//tf_ops:sequence_string_projection_op_v2",  # sequence projection
+        # "//tf_ops:sequence_string_projection_op_v2"  # sequence projection
         "//tf_ops:sequence_string_projection_op_v2_py",  # sequence projection
     ],
 )
@@ -29,12 +29,14 @@ class BaseQDense(base_layers.BaseLayer):
                activation=tf.keras.layers.ReLU(),
                bias=True,
                rank=2,
+               normalize=True,
                **kwargs):
     self.units = units
     self.rank = rank
     assert rank >= 2 and rank <= 4
     self.activation = activation
     self.bias = bias
+    self.normalize = normalize
     self.qoutput = quantization_layers.ActivationQuantization(**kwargs)
     self._create_normalizer(**kwargs)
     super(BaseQDense, self).__init__(**kwargs)
@@ -56,6 +58,7 @@ class BaseQDense(base_layers.BaseLayer):
     outputs = tf.matmul(inputs, self.w)
     if self.bias:
       outputs = tf.nn.bias_add(outputs, self.b)
+    if self.normalize:
       outputs = normalize_method(outputs)
     if self.activation:
       outputs = self.activation(outputs)
...
@@ -136,5 +136,5 @@ class LayerNormalization(base_layers.BaseLayer):
       tensor = (tensor - mean) / tf.sqrt(variance + 1e-6)
       return tensor * self.scale + self.offset
     else:
-      return tf_custom_ops_py.layer_norm_v2(
+      return tf_custom_ops_py.layer_norm(
           tensor, self.scale, self.offset, axes=self.axes)
@@ -39,6 +39,7 @@ class ProjectionLayer(base_layers.BaseLayer):
     _get_params("max_seq_len", 0)
     _get_params("add_eos_tag", False)
     _get_params("add_bos_tag", False)
+    _get_params("hashtype", "murmur")
     _get_params("split_on_space", True)
     _get_params("token_separators", "")
     _get_params("vocabulary", "")
@@ -56,6 +57,7 @@ class ProjectionLayer(base_layers.BaseLayer):
         input=inputs,
         feature_size=self.feature_size,
         max_splits=self.max_seq_len - 1,
+        hashtype=self.hashtype,
         distortion_probability=self.distortion_probability,
         split_on_space=self.split_on_space,
         token_separators=self.token_separators,
@@ -69,7 +71,7 @@ class ProjectionLayer(base_layers.BaseLayer):
     if self.mode not in modes and self.max_seq_len > 0:
       short_by = self.max_seq_len - tf.shape(projection)[1]
       projection = tf.pad(projection, [[0, 0], [0, short_by], [0, 0]])
-      batch_size = inputs.get_shape().as_list()[0]
+      batch_size = self.get_batch_dimension(inputs)
       projection = tf.reshape(projection,
                               [batch_size, self.max_seq_len, self.feature_size])
     if self.mode in modes:
...
@@ -27,6 +27,9 @@ def classification_metric(per_example_loss, label_ids, logits):
   }


+THRESHOLDS = [0.5]
+
+
 def labeling_metric(per_example_loss, label_ids, logits):
   """Compute eval metrics."""
   scores = tf.math.sigmoid(logits)
@@ -35,4 +38,10 @@ def labeling_metric(per_example_loss, label_ids, logits):
   for idx in range(num_classes):
     return_dict["auc/" + str(idx)] = tf.metrics.auc(label_ids[:, idx],
                                                     scores[:, idx])
+    return_dict["precision@" + str(THRESHOLDS) + "/" +
+                str(idx)] = tf.metrics.precision_at_thresholds(
+                    label_ids[:, idx], scores[:, idx], thresholds=THRESHOLDS)
+    return_dict["recall@" + str(THRESHOLDS) + "/" +
+                str(idx)] = tf.metrics.recall_at_thresholds(
+                    label_ids[:, idx], scores[:, idx], thresholds=THRESHOLDS)
   return return_dict
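The thresholded metrics added above report precision and recall at a fixed cutoff of 0.5. As an illustration of what those ops compute, here is a plain-Python sanity check; `precision_recall_at_threshold` is a hypothetical helper, not part of the repository, and the data below is made up.

```python
def precision_recall_at_threshold(labels, scores, threshold=0.5):
    # Binarize scores at the threshold, then count true/false positives
    # and false negatives the way tf.metrics.{precision,recall}_at_thresholds
    # would for a single threshold.
    preds = [1 if s >= threshold else 0 for s in scores]
    tp = sum(1 for l, p in zip(labels, preds) if l == 1 and p == 1)
    fp = sum(1 for l, p in zip(labels, preds) if l == 0 and p == 1)
    fn = sum(1 for l, p in zip(labels, preds) if l == 1 and p == 0)
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return precision, recall
```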
@@ -101,6 +101,8 @@ cc_library(
         ":text_distorter",
         "@tensorflow_includes//:includes",
         "@tensorflow_solib//:framework_lib",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/random",
     ],
     alwayslink = 1,
 )
...
@@ -26,7 +26,7 @@ limitations under the License.
 bool IsDigit(const std::string& text) {
   Rune rune;
   for (size_t i = 0; i < text.length();) {
-    const int bytes_read = charntorune(&rune, text.data(), 1);
+    const int bytes_read = chartorune(&rune, const_cast<char *>(text.data()));
     if (rune == Runeerror || bytes_read == 0) break;
     if (rune >= static_cast<Rune>('0') && rune <= static_cast<Rune>('9')) {
       return true;
...
@@ -14,18 +14,223 @@ limitations under the License.
 ==============================================================================*/
 #include "tf_ops/projection_util.h"  // seq_flow_lite

+#include <cassert>
 #include <cstddef>
+#include <cstdint>
 #include <iostream>
 #include <memory>
 #include <sstream>
+#include <unordered_set>

 namespace {
-constexpr size_t kInvalid = -1;
+constexpr int kInvalid = -1;
 constexpr char kSpace = ' ';
 }  // namespace
class MurmurHash : public HashEngine {
public:
void GetHashCodes(const std::string& word, std::vector<uint64_t>* hash_codes,
int feature_size) override {
uint64_t hash_low = 0;
uint64_t hash_high = 0;
for (int i = 0; i < feature_size; i += 64) {
if (i == 0) {
auto hash = MurmurHash128(word.c_str(), word.size());
hash_low = hash.first;
hash_high = hash.second;
} else {
GetMoreBits(hash_low, hash_high, &hash_low, &hash_high);
}
hash_codes->push_back(hash_low);
hash_codes->push_back(hash_high);
}
}
private:
static constexpr uint64_t kMul = 0xc6a4a7935bd1e995ULL;
static constexpr uint64_t kMul2 = 0x9e3779b97f4a7835ULL;
inline uint64_t ShiftMix(uint64_t val) { return val ^ (val >> 47); }
inline uint64_t MurmurStep(uint64_t hash, uint64_t data) {
hash ^= ShiftMix(data * kMul) * kMul;
hash *= kMul;
return hash;
}
inline uint64_t Load64VariableLength(const void* p, int len) {
assert(len >= 1 && len <= 8);
const char* buf = static_cast<const char*>(p);
uint64_t val = 0;
--len;
do {
val = (val << 8) | buf[len];
// (--len >= 0) is about 10 % faster than (len--) in some benchmarks.
} while (--len >= 0);
// No ToHost64(...) needed. The bytes are accessed in little-endian manner
// on every architecture.
return val;
}
void GetMoreBits(uint64_t hash, uint64_t hash2, uint64_t* rlow,
uint64_t* rhigh) {
hash = ShiftMix(hash) * kMul;
hash2 ^= hash;
*rhigh = ShiftMix(hash);
*rlow = ShiftMix(hash2 * kMul2) * kMul2;
}
std::pair<uint64_t, uint64_t> MurmurHash128(const char* buf,
const size_t len) {
// Initialize the hashing value.
uint64_t hash = len * kMul;
// hash2 will be xored by hash during the hash computation iterations.
// In the end we use an alternative mixture multiplier for mixing
// the bits in hash2.
uint64_t hash2 = 0;
// Let's remove the bytes not divisible by the sizeof(uint64_t).
// This allows the inner loop to process the data as 64 bit integers.
const size_t len_aligned = len & ~0x7;
const char* end = buf + len_aligned;
for (const char* p = buf; p != end; p += 8) {
// Manually unrolling this loop 2x did not help on Intel Core 2.
hash = MurmurStep(hash, Load64VariableLength(p, 8));
hash2 ^= hash;
}
if ((len & 0x7) != 0) {
const uint64_t data = Load64VariableLength(end, len & 0x7);
hash ^= data;
hash *= kMul;
hash2 ^= hash;
}
hash = ShiftMix(hash) * kMul;
hash2 ^= hash;
hash = ShiftMix(hash);
// mul2 is a prime just above golden ratio. mul2 is used to ensure that the
// impact of the last few bytes is different to the upper and lower 64 bits.
hash2 = ShiftMix(hash2 * kMul2) * kMul2;
return std::make_pair(hash, hash2);
}
};
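The two mixing primitives that the class above builds on, `ShiftMix` and `MurmurStep`, can be transcribed almost literally into Python; the explicit 64-bit masks stand in for C++ unsigned-integer wraparound. This is an illustrative sketch, not code from the repository.

```python
MASK64 = (1 << 64) - 1        # emulate uint64_t wraparound
K_MUL = 0xC6A4A7935BD1E995    # kMul from the C++ code above

def shift_mix(val):
    # val ^ (val >> 47), as in MurmurHash::ShiftMix.
    return (val ^ (val >> 47)) & MASK64

def murmur_step(h, data):
    # One mixing step: hash ^= ShiftMix(data * kMul) * kMul; hash *= kMul.
    h = (h ^ ((shift_mix((data * K_MUL) & MASK64) * K_MUL) & MASK64)) & MASK64
    return (h * K_MUL) & MASK64
```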
class XFixHash : public HashEngine {
public:
explicit XFixHash(int bits_per_char)
: bits_per_char_(bits_per_char), bit_mask_((1ULL << bits_per_char) - 1) {}
void GetHashCodes(const std::string& word, std::vector<uint64_t>* hash_codes,
int feature_size) override {
auto token_ptr = reinterpret_cast<const uint8_t*>(word.c_str());
size_t token_size = word.size();
int token_idx = 0;
uint64_t hash_low = token_size * kMul;
uint64_t hash_high = token_size * kMul2;
uint64_t frhash = kMul;
uint64_t brhash = kMul2;
for (int i = 0; i < feature_size; i += 64) {
for (int j = i ? 0 : bits_per_char_; j < 64;
j += bits_per_char_, token_idx = (token_idx + 1) % token_size) {
frhash = ((frhash << 8) | token_ptr[token_idx]) * kMul;
brhash =
((brhash << 8) | token_ptr[token_size - 1 - token_idx]) * kMul2;
hash_low = (hash_low << bits_per_char_) | (frhash & bit_mask_);
hash_high = (hash_high << bits_per_char_) | (brhash & bit_mask_);
}
hash_codes->push_back(hash_low);
hash_codes->push_back(hash_high);
}
}
private:
const uint64_t kMul = 0xc6a4a7935bd1e995ULL;
const uint64_t kMul2 = 0x9e3779b97f4a7835ULL;
const int bits_per_char_;
const uint64_t bit_mask_;
};
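The forward half of XFixHash is a rolling hash: each step shifts in one byte of the token (wrapping around), multiplies by the constant, and contributes the low `bits_per_char` bits. A rough Python rendering of just that stream, under the same constants, might look like this; `xfix_stream` is a hypothetical name for illustration only.

```python
MASK64 = (1 << 64) - 1
K_MUL = 0xC6A4A7935BD1E995

def xfix_stream(word, bits_per_char, steps):
    # Emit `steps` chunks of `bits_per_char` bits from the forward rolling
    # hash, cycling through the token bytes as XFixHash does.
    data = word.encode("utf-8")
    bit_mask = (1 << bits_per_char) - 1
    frhash = K_MUL
    out = []
    idx = 0
    for _ in range(steps):
        frhash = (((frhash << 8) | data[idx]) * K_MUL) & MASK64
        out.append(frhash & bit_mask)
        idx = (idx + 1) % len(data)
    return out
```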
class UnicodeHash : public HashEngine {
public:
// bits_per_unicode should be a divisor of 64.
explicit UnicodeHash(int bits_per_unicode)
: bits_per_unicode_(bits_per_unicode),
bit_mask_(((1ULL << bits_per_unicode) - 1) << (64 - bits_per_unicode)) {
}
void GetHashCodes(const std::string& word, std::vector<uint64_t>* hash_codes,
int feature_size) override {
auto word_ptr = word.c_str();
int utflength = utflen(const_cast<char*>(word_ptr));
// Both `feature_size` and `bits_per_unicode` are bit lengths.
const int max_usable_runes = feature_size * 2 / bits_per_unicode_;
if (max_usable_runes < utflength) {
const int unicode_skip = (utflength - max_usable_runes) / 2;
for (int i = 0; i < unicode_skip; ++i) {
Rune rune;
word_ptr += chartorune(&rune, const_cast<char*>(word_ptr));
}
utflength = max_usable_runes;
}
std::vector<uint64_t> unicode_hashes;
unicode_hashes.reserve(utflength);
for (int i = 0; i < utflength; ++i) {
Rune rune;
word_ptr += chartorune(&rune, const_cast<char*>(word_ptr));
unicode_hashes.push_back((rune * kMul) & bit_mask_);
}
uint64_t hash = 0;
int k = 0;
for (int i = 0; i < feature_size * 2; i += 64) {
for (int j = 0; j < 64; j += bits_per_unicode_) {
if (k < unicode_hashes.size()) {
hash = (hash >> bits_per_unicode_) | unicode_hashes[k++];
} else {
hash = hash >> bits_per_unicode_;
}
}
hash_codes->push_back(hash);
}
}
private:
const uint64_t kMul = 0xc6a4a7935bd1e995ULL;
const int bits_per_unicode_;
const uint64_t bit_mask_;
};
bool Hasher::SupportedHashType(const std::string& hash_type) {
std::unordered_set<std::string> supported({kMurmurHash, kUnicodeHash8,
kUnicodeHash16, kXfixHash8,
kXfixHash16, kXfixHash32});
return supported.find(hash_type) != supported.end();
}
Hasher* Hasher::CreateHasher(int feature_size, const std::string& hash_type) {
if (SupportedHashType(hash_type)) {
if (hash_type == kMurmurHash) {
return new Hasher(feature_size, new MurmurHash());
} else if (hash_type == kUnicodeHash8) {
return new Hasher(feature_size, new UnicodeHash(8));
} else if (hash_type == kUnicodeHash16) {
return new Hasher(feature_size, new UnicodeHash(16));
} else if (hash_type == kXfixHash8) {
return new Hasher(feature_size, new XFixHash(8));
} else if (hash_type == kXfixHash16) {
return new Hasher(feature_size, new XFixHash(16));
} else {
return new Hasher(feature_size, new XFixHash(32));
}
}
return nullptr;
}
Hasher::Hasher(int feature_size, HashEngine* hash_engine)
: feature_size_(feature_size), hash_engine_(hash_engine) {
hash_engine_->GetHashCodes(empty_string_, &null_hash_codes_, feature_size_);
}
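The factory above dispatches on the `hash_type` string constants declared in projection_util.h, returning `nullptr` for unsupported types. The same shape in Python, as a purely illustrative sketch (the dict values are placeholders, not real engine objects):

```python
# Hypothetical rendering of the CreateHasher dispatch table.
ENGINE_FACTORIES = {
    "murmur": lambda: ("MurmurHash", None),
    "unicodehash8": lambda: ("UnicodeHash", 8),
    "unicodehash16": lambda: ("UnicodeHash", 16),
    "xfixhash8": lambda: ("XFixHash", 8),
    "xfixhash16": lambda: ("XFixHash", 16),
    "xfixhash32": lambda: ("XFixHash", 32),
}

def create_hasher(feature_size, hash_type="murmur"):
    # Mirrors Hasher::CreateHasher: None for unsupported hash types.
    factory = ENGINE_FACTORIES.get(hash_type)
    if factory is None:
        return None
    return {"feature_size": feature_size, "engine": factory()}
```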
 std::string ProjectionUnicodeHandler::LowerCaseUTF8WithSupportedUnicodes(
-    const std::pair<const char*, size_t>& source) const {
+    const std::pair<const char*, size_t>& source, bool* first_cap,
+    bool* all_caps) const {
   // Ideally the size of target should be less than or equal to source. But
   // when we do to_lower the number of bytes needed to encode a unicode
   // character could increase. To account for this 4 times the source length
@@ -35,17 +240,20 @@ std::string ProjectionUnicodeHandler::LowerCaseUTF8WithSupportedUnicodes(
   auto target = std::unique_ptr<char[]>(new char[len * 4]);
   auto target_ptr = target.get();
   int i = 0;
+  bool first_char = true;
+  bool first_cap_value = false;
+  bool all_caps_value = false;
   while (i < len) {
     Rune rune;
-    const int bytes_read = charntorune(&rune, csource + i, len - i);
-    if (bytes_read == 0) {
+    const int bytes_read = chartorune(&rune, const_cast<char*>(csource + i));
+    if (bytes_read == 0 || bytes_read > len - i) {
       break;
     }
     i += bytes_read;
     if (rune != Runeerror) {
       Rune lower = tolowerrune(rune);
-      // Skip processing the unicode if exclude_nonalphaspace_unicodes_ is true
-      // and the unicode is not alpha and not space.
+      // Skip processing the unicode if exclude_nonalphaspace_unicodes_ is
+      // true and the unicode is not alpha and not space.
       const Rune kSpaceRune = ' ';
       if (exclude_nonalphaspace_unicodes_ && !isalpharune(lower) &&
           lower != kSpaceRune) {
@@ -54,8 +262,24 @@ std::string ProjectionUnicodeHandler::LowerCaseUTF8WithSupportedUnicodes(
       if (IsUnrestrictedVocabulary() || IsValidUnicode(lower)) {
         const int bytes_written = runetochar(target_ptr, &lower);
         target_ptr += bytes_written;
+        const bool lower_case = (lower == rune);
+        if (first_char) {
+          first_cap_value = !lower_case;
+          all_caps_value = !lower_case;
+        } else {
+          first_cap_value &= lower_case;
+          all_caps_value &= !lower_case;
+        }
+        first_char = false;
       }
     }
   }
+  if (first_cap) {
+    *first_cap = first_cap_value;
+  }
+  if (all_caps) {
+    *all_caps = all_caps_value;
+  }
   return std::string(target.get(), target_ptr);
 }
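The capitalization tracking added to this function reduces to a small state machine over the kept characters. An ASCII-only Python rendering of that logic, for illustration (`caps_flags` is a hypothetical name, not in the repository):

```python
def caps_flags(word):
    # Mirrors the first_cap/all_caps tracking: first_cap holds iff only the
    # first character is upper case; all_caps holds iff every character is.
    first_cap = all_caps = False
    first_char = True
    for ch in word:
        lower_case = ch == ch.lower()
        if first_char:
            first_cap = not lower_case
            all_caps = not lower_case
        else:
            first_cap = first_cap and lower_case
            all_caps = all_caps and (not lower_case)
        first_char = False
    return first_cap, all_caps
```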
@@ -65,8 +289,8 @@ void ProjectionUnicodeHandler::InitializeVocabulary(
   for (size_t i = 0, index = 0; i < vocabulary.length();) {
     Rune rune;
     const int bytes_read =
-        charntorune(&rune, vocabulary.c_str() + i, vocabulary.length() - i);
-    if (!bytes_read) {
+        chartorune(&rune, const_cast<char*>(vocabulary.c_str() + i));
+    if (!bytes_read || bytes_read > (vocabulary.length() - i)) {
       break;
     }
     i += bytes_read;
@@ -84,7 +308,8 @@ void ProjectionUnicodeHandler::InitializeVocabulary(
 }

 // Starting from input_ptr[from], search for the next occurrence of ' ',
-// Don't search beyond input_ptr[length](non-inclusive), return -1 if not found.
+// Don't search beyond input_ptr[length](non-inclusive), return -1 if not
+// found.
 inline size_t FindNextSpace(const char* input_ptr, size_t from, size_t length) {
   size_t space_index;
   for (space_index = from; space_index < length; space_index++) {
@@ -141,8 +366,8 @@ void SplitByCharInternal(std::vector<T>* tokens, const char* input_ptr,
                          size_t len, size_t max_tokens) {
   Rune rune;
   for (size_t i = 0; i < len;) {
-    auto bytes_read = charntorune(&rune, input_ptr + i, len - i);
-    if (bytes_read == 0) break;
+    auto bytes_read = chartorune(&rune, const_cast<char*>(input_ptr + i));
+    if (bytes_read == 0 || bytes_read > (len - i)) break;
     tokens->emplace_back(input_ptr + i, bytes_read);
     if (max_tokens != kInvalid && tokens->size() == max_tokens) {
       break;
@@ -181,7 +406,7 @@ std::string JoinPairsBySpace(
 }

 std::vector<std::pair<const char*, size_t>> ProjectionUnicodeHandler::Tokenize(
-    const char* str, size_t len, bool by_space, int max_tokens) const {
+    const char* str, size_t len, bool by_space, int max_tokens) {
   return by_space ? SplitBySpaceAsPairs(str, len, max_tokens)
                   : SplitByCharAsPairs(str, len, max_tokens);
 }
@@ -14,122 +14,58 @@ limitations under the License.
 ==============================================================================*/
 #ifndef TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TF_OPS_PROJECTION_UTIL_H_
 #define TENSORFLOW_MODELS_SEQUENCE_PROJECTION_TF_OPS_PROJECTION_UTIL_H_

-#include <cassert>
+#include <memory>
 #include <string>
 #include <unordered_map>
 #include <vector>

 #include "libutf/utf.h"

-inline int charntorune(Rune* r, const char* s, int n) {
-  const int bytes_read = chartorune(r, const_cast<char *>(s));
-  if (bytes_read > n) {
-    *r = Runeerror;
-    return 0;
-  }
-  return bytes_read;
-}
+constexpr int kFirstCapOffset = 3;
+constexpr int kAllCapsOffset = 4;
+constexpr int kWordNoveltyOffset = 1;
+constexpr int kDocSizeOffset = 2;
+
+const char kMurmurHash[] = "murmur";
+const char kXfixHash8[] = "xfixhash8";
+const char kXfixHash16[] = "xfixhash16";
+const char kXfixHash32[] = "xfixhash32";
+const char kUnicodeHash8[] = "unicodehash8";
+const char kUnicodeHash16[] = "unicodehash16";
+
+class HashEngine {
+ public:
+  virtual void GetHashCodes(const std::string& word,
+                            std::vector<uint64_t>* hash_codes,
+                            int feature_size) = 0;
+  virtual ~HashEngine() {}
+};

 // A hashing wrapper class that can hash a string and generate a hash code with
 // requested number of features (two bit values). Some of the implementations
 // are copied from murmurhash.
 class Hasher {
  public:
-  explicit Hasher(int feature_size) : feature_size_(feature_size) {
-    GetHashCodesInternal(empty_string_, &null_hash_codes_);
-  }
-  void GetHashCodes(const std::string& word,
+  static Hasher* CreateHasher(int feature_size,
+                              const std::string& hash_type = kMurmurHash);
+  static bool SupportedHashType(const std::string& hash_type);
+  bool GetHashCodes(const std::string& word,
                     std::vector<uint64_t>* hash_codes) {
+    if (!hash_engine_) return false;
     if (word.empty()) {
       *hash_codes = null_hash_codes_;
     } else {
       hash_codes->clear();
-      GetHashCodesInternal(word, hash_codes);
+      hash_engine_->GetHashCodes(word, hash_codes, feature_size_);
     }
+    return true;
   }

  private:
-  static constexpr uint64_t kMul = 0xc6a4a7935bd1e995ULL;
-  static constexpr uint64_t kMul2 = 0x9e3779b97f4a7835ULL;
-  inline uint64_t ShiftMix(uint64_t val) { return val ^ (val >> 47); }
-  inline uint64_t MurmurStep(uint64_t hash, uint64_t data) {
-    hash ^= ShiftMix(data * kMul) * kMul;
-    hash *= kMul;
-    return hash;
-  }
-  inline uint64_t Load64VariableLength(const void* p, int len) {
-    assert(len >= 1 && len <= 8);
-    const char* buf = static_cast<const char*>(p);
-    uint64_t val = 0;
-    --len;
-    do {
-      val = (val << 8) | buf[len];
-      // (--len >= 0) is about 10% faster than (len--) in some benchmarks.
-    } while (--len >= 0);
-    // No ToHost64(...) needed. The bytes are accessed in little-endian manner
-    // on every architecture.
-    return val;
-  }
-  void GetMoreBits(uint64_t hash, uint64_t hash2, uint64_t* rlow,
-                   uint64_t* rhigh) {
-    hash = ShiftMix(hash) * kMul;
-    hash2 ^= hash;
-    *rhigh = ShiftMix(hash);
-    *rlow = ShiftMix(hash2 * kMul2) * kMul2;
-  }
-  std::pair<uint64_t, uint64_t> MurmurHash128(const char* buf,
-                                              const size_t len) {
-    // Initialize the hashing value.
-    uint64_t hash = len * kMul;
-    // hash2 will be xored by hash during the hash computation iterations.
-    // In the end we use an alternative mixture multiplier for mixing
-    // the bits in hash2.
-    uint64_t hash2 = 0;
-    // Let's remove the bytes not divisible by the sizeof(uint64_t).
-    // This allows the inner loop to process the data as 64 bit integers.
-    const size_t len_aligned = len & ~0x7;
-    const char* end = buf + len_aligned;
-    for (const char* p = buf; p != end; p += 8) {
-      // Manually unrolling this loop 2x did not help on Intel Core 2.
-      hash = MurmurStep(hash, Load64VariableLength(p, 8));
-      hash2 ^= hash;
-    }
-    if ((len & 0x7) != 0) {
-      const uint64_t data = Load64VariableLength(end, len & 0x7);
-      hash ^= data;
-      hash *= kMul;
-      hash2 ^= hash;
-    }
-    hash = ShiftMix(hash) * kMul;
-    hash2 ^= hash;
-    hash = ShiftMix(hash);
-    // kMul2 is a prime just above the golden ratio. It is used to ensure that
-    // the impact of the last few bytes differs for the upper and lower 64
-    // bits.
-    hash2 = ShiftMix(hash2 * kMul2) * kMul2;
-    return std::make_pair(hash, hash2);
-  }
-  void GetHashCodesInternal(const std::string& word,
-                            std::vector<uint64_t>* hash_codes) {
-    uint64_t hash_low = 0;
-    uint64_t hash_high = 0;
-    for (int i = 0; i < feature_size_; i += 64) {
-      if (i == 0) {
-        auto hash = MurmurHash128(word.c_str(), word.size());
-        hash_low = hash.first;
-        hash_high = hash.second;
-      } else {
-        GetMoreBits(hash_low, hash_high, &hash_low, &hash_high);
-      }
-      hash_codes->push_back(hash_low);
-      hash_codes->push_back(hash_high);
-    }
-  }
+  explicit Hasher(int feature_size, HashEngine* hash_engine);
+
   const std::string empty_string_ = "<null>";
   const int feature_size_;
+  std::unique_ptr<HashEngine> hash_engine_;
   std::vector<uint64_t> null_hash_codes_;
 };
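The refactor above hides the concrete murmur code behind the `HashEngine` interface, while keeping its contract: emit a pair of 64-bit words for every 64 requested features, deriving later pairs from earlier ones. A toy engine (`ToyHashEngine` is hypothetical and chains `std::hash` instead of the real murmur mixing) illustrates the shape of that contract:

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <string>
#include <vector>

// Toy stand-in for a HashEngine implementation. It follows the same contract
// as the deleted GetHashCodesInternal(): push two 64-bit words per 64
// requested features, so the caller receives 2 * ceil(feature_size / 64)
// codes, and extend the bit stream by remixing the previous pair.
class ToyHashEngine {
 public:
  void GetHashCodes(const std::string& word, std::vector<uint64_t>* hash_codes,
                    int feature_size) {
    uint64_t low = std::hash<std::string>{}(word);
    uint64_t high = low ^ 0x9e3779b97f4a7835ULL;  // arbitrary second stream
    for (int i = 0; i < feature_size; i += 64) {
      hash_codes->push_back(low);
      hash_codes->push_back(high);
      // Derive more bits from the previous pair, like GetMoreBits() did.
      low = std::hash<uint64_t>{}(low);
      high = std::hash<uint64_t>{}(high ^ low);
    }
  }
};
```

Determinism matters here: the same token must map to the same codes at training and inference time, which is why the engine is a pure function of the input word.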
@@ -156,7 +92,8 @@ class ProjectionUnicodeHandler {
   // Performs language independent lower case and returns a string with
   // supported unicode segments.
   std::string LowerCaseUTF8WithSupportedUnicodes(
-      const std::pair<const char*, size_t>& source) const;
+      const std::pair<const char*, size_t>& source, bool* first_cap = nullptr,
+      bool* all_caps = nullptr) const;

   // Returns a boolean flag indicating if the unicode segment is part of the
   // vocabulary.
@@ -179,15 +116,14 @@ class ProjectionUnicodeHandler {
   // Tokenizes input by space or unicode point segmentation. Limit to
   // max_tokens, when it is not -1.
-  std::vector<std::pair<const char*, size_t>> Tokenize(const std::string& input,
-                                                       bool by_space,
-                                                       int max_tokens) const {
+  static std::vector<std::pair<const char*, size_t>> Tokenize(
+      const std::string& input, bool by_space, int max_tokens) {
     return Tokenize(input.c_str(), input.size(), by_space, max_tokens);
   }
-  std::vector<std::pair<const char*, size_t>> Tokenize(const char* str,
-                                                       size_t len,
-                                                       bool by_space,
-                                                       int max_tokens) const;
+  static std::vector<std::pair<const char*, size_t>> Tokenize(const char* str,
+                                                              size_t len,
+                                                              bool by_space,
+                                                              int max_tokens);

  private:
   // Parses and extracts supported unicode segments from a utf8 string.
...
@@ -25,13 +25,13 @@ limitations under the License.
 using ::tensorflow::int32;
 using ::tensorflow::int64;
-using ::tensorflow::uint64;
 using ::tensorflow::OpKernel;
 using ::tensorflow::OpKernelConstruction;
 using ::tensorflow::OpKernelContext;
 using ::tensorflow::Tensor;
 using ::tensorflow::TensorShape;
 using ::tensorflow::TensorShapeUtils;
+using ::tensorflow::uint64;
 using ::tensorflow::errors::InvalidArgument;
 using tensorflow::shape_inference::DimensionHandle;
@@ -56,7 +56,11 @@ class SequenceStringProjectionOp : public OpKernel {
   explicit SequenceStringProjectionOp(OpKernelConstruction* context)
       : OpKernel(context) {
     OP_REQUIRES_OK(context, context->GetAttr("feature_size", &feature_size_));
-    hasher_ = absl::make_unique<Hasher>(feature_size_);
+    std::string hashtype;
+    OP_REQUIRES_OK(context, context->GetAttr("hashtype", &hashtype));
+    hasher_ =
+        absl::WrapUnique<Hasher>(Hasher::CreateHasher(feature_size_, hashtype));
+    CHECK(hasher_);
     float distortion_probability = 0.0;
     OP_REQUIRES_OK(context, context->GetAttr("distortion_probability",
                                              &distortion_probability));
@@ -110,6 +114,22 @@ class SequenceStringProjectionOp : public OpKernel {
       projection_normalizer_ = absl::make_unique<ProjectionNormalizer>(
           separators, normalize_repetition);
     }
+
+    OP_REQUIRES_OK(context, context->GetAttr("add_first_cap_feature",
+                                             &add_first_cap_feature_));
+    CHECK_GE(add_first_cap_feature_, 0.0);
+    CHECK_LE(add_first_cap_feature_, 1.0);
+    if (add_first_cap_feature_ > 0.0) {
+      CHECK_GE(feature_size_, 3);
+    }
+
+    OP_REQUIRES_OK(context, context->GetAttr("add_all_caps_feature",
+                                             &add_all_caps_feature_));
+    CHECK_GE(add_all_caps_feature_, 0.0);
+    CHECK_LE(add_all_caps_feature_, 1.0);
+    if (add_all_caps_feature_ > 0.0) {
+      CHECK_GE(feature_size_, 4);
+    }
   }
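The constructor checks above enforce two invariants: a cap-feature probability must lie in [0, 1], and enabling a feature requires enough slots at the tail of the feature vector (at least 3 for first-cap, 4 for all-caps). A minimal sketch of that validation logic (`ValidateCapFeature` is a hypothetical helper, not part of the op):

```cpp
#include <cassert>
#include <string>

// Mirrors the CHECK_GE/CHECK_LE guards in the constructor above: the
// probability must be in [0, 1], and if the feature is enabled (prob > 0)
// feature_size must leave room for its slot near the end of the vector.
// Returns an error message, or an empty string on success.
std::string ValidateCapFeature(float prob, int feature_size, int min_size) {
  if (prob < 0.0f || prob > 1.0f) return "probability out of [0, 1]";
  if (prob > 0.0f && feature_size < min_size) return "feature_size too small";
  return "";
}
```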
   void Compute(OpKernelContext* ctx) override {
@@ -173,13 +193,15 @@ class SequenceStringProjectionOp : public OpKernel {
       doc_size_feature = std::min(doc_size_feature, 1.0f) * 2.0f - 1.0f;
       for (int64 j = -bos_tag_; j < num_tokens + eos_tag_; ++j) {
         std::string word;
+        bool first_cap = false;
+        bool all_caps = false;
         if (j < 0) {
           // Use a special tag for begin of sentence.
           word = kBeginTokenTSP;
         } else if (j < num_tokens) {
           auto uword = icu::UnicodeString::fromUTF8(
               unicode_handler_->LowerCaseUTF8WithSupportedUnicodes(
-                  words_batches[i][j]));
+                  words_batches[i][j], &first_cap, &all_caps));
           word = text_distorter_->DistortText(&uword);
         } else {
           // Use a special tag for end of sentence.
@@ -196,14 +218,31 @@ class SequenceStringProjectionOp : public OpKernel {
         }
         if (word_novelty_bits_ != 0 && !hash_codes.empty()) {
           const auto word_hash = hash_codes[0];
-          projection[offset0 + feature_size_ - 1] =
+          projection[offset0 + feature_size_ - kWordNoveltyOffset] =
               std::min((word_counter[word_hash]++ * word_novelty_offset_),
                        1.0f) *
                   2.0f -
               1.0f;
         }
         if (doc_size_levels_ != 0) {
-          projection[offset0 + feature_size_ - 2] = doc_size_feature;
+          projection[offset0 + feature_size_ - kDocSizeOffset] =
+              doc_size_feature;
+        }
+        if (add_first_cap_feature_ > 0.0f) {
+          if (text_distorter_->BernouilleSample(add_first_cap_feature_)) {
+            projection[offset0 + feature_size_ - kFirstCapOffset] =
+                first_cap ? 1.0 : -1.0;
+          } else {
+            projection[offset0 + feature_size_ - kFirstCapOffset] = 0.0;
+          }
+        }
+        if (add_all_caps_feature_ > 0.0f) {
+          if (text_distorter_->BernouilleSample(add_all_caps_feature_)) {
+            projection[offset0 + feature_size_ - kAllCapsOffset] =
+                all_caps ? 1.0 : -1.0;
+          } else {
+            projection[offset0 + feature_size_ - kAllCapsOffset] = 0.0;
+          }
         }
         offset0 += feature_size_;
       }
@@ -227,6 +266,8 @@ class SequenceStringProjectionOp : public OpKernel {
   int word_novelty_bits_;
   int doc_size_levels_;
   float word_novelty_offset_;
+  float add_first_cap_feature_;
+  float add_all_caps_feature_;
 };
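The capitalization features added in the loop above are probability-gated: with probability `p` the slot carries +1 or -1 (flag set or unset), otherwise it is zeroed, so the model learns not to rely on the signal always being present. Assuming `BernouilleSample` is a plain Bernoulli draw, the per-token logic reduces to this sketch (`CapFeature` is a hypothetical name):

```cpp
#include <cassert>
#include <random>

// With probability p, emit the ternary capitalization feature (+1 if the
// flag is set, -1 if not); otherwise drop it to 0, mimicking the
// text_distorter_->BernouilleSample() gating in the kernel above.
float CapFeature(bool flag, float p, std::mt19937* rng) {
  std::bernoulli_distribution keep(p);
  if (!keep(*rng)) return 0.0f;
  return flag ? 1.0f : -1.0f;
}
```

The output stays ternary ({-1, 0, +1}), consistent with the op's documented ternary projection values.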
 REGISTER_KERNEL_BUILDER(
@@ -241,16 +282,19 @@ REGISTER_OP("SequenceStringProjection")
     .Attr("feature_size: int")
     .Attr("distortion_probability: float = 0.0")
     .Attr("vocabulary: string = ''")
+    .Attr("hashtype: string = 'murmur'")
     .Attr("max_splits: int = -1")
     .Attr("exclude_nonalphaspace_unicodes: bool = False")
     .Attr("add_bos_tag: bool = False")
     .Attr("add_eos_tag: bool = True")
+    .Attr("add_first_cap_feature: float = 0.0")
+    .Attr("add_all_caps_feature: float = 0.0")
     .Attr("word_novelty_bits: int = 0")
     .Attr("doc_size_levels: int = 0")
     .Attr("split_on_space: bool = True")
     .Attr("token_separators: string = ''")
     .Attr("normalize_repetition: bool = false")
-    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
+    .SetShapeFn([](InferenceContext* c) {
       DimensionHandle size;
       int32 feature_size;
@@ -285,10 +329,11 @@ Attribute(s):
     will be allowed in the input text before fingerprinting. Another way to
     say it is that the vocabulary is an optional character allowlist for the
     input text. It helps normalize the text.
+- hashtype: Hashing method to use for the projection.
 - max_splits: Maximum number of tokens that are allowed. It helps restrict the
     max token length of the projection output. When the value is -1 the op
     does not restrict the number of tokens in the output.
-- exclude_nonalphaspace_unicodes: When set to true excludes unicodes that are
+- exclude_nonalphaspace_unicodes: When true excludes all unicodes that are
     not alphabets or space character. This is multilingual. Though the effect
     of this flag can be achieved using vocabulary, the vocabulary will have to
     be very large for multilingual input.
@@ -301,6 +346,12 @@ Attribute(s):
     output the document size in log scale. This is an experimental feature.
 - split_on_space: When true tokenization is done on space segmentation.
     Otherwise tokenization is done by segmenting on unicode boundary.
+- add_first_cap_feature: Probability with which a feature that helps
+    discriminate whether the input token is CamelCase is added to the
+    resulting projection tensor.
+- add_all_caps_feature: Probability with which a feature that helps
+    discriminate whether the input token is ALLCAPS is added to the
+    resulting projection tensor.

 Output(s):
 - projection: Floating point tensor with ternary values of shape
...
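The four offset constants introduced in projection_util.h tie the attribute documentation above to the kernel's indexing: each optional feature occupies a fixed slot counted back from the end of a token's feature vector, which is why enabling the cap features requires a minimum `feature_size`. A minimal sketch (`FeatureSlot` is a hypothetical helper):

```cpp
#include <cassert>

// Slot offsets as defined in projection_util.h by this change.
constexpr int kWordNoveltyOffset = 1;
constexpr int kDocSizeOffset = 2;
constexpr int kFirstCapOffset = 3;
constexpr int kAllCapsOffset = 4;

// Index of an optional feature within one token's feature vector, mirroring
// the kernel's `projection[offset0 + feature_size_ - kOffset]` pattern.
int FeatureSlot(int feature_size, int offset_from_end) {
  return feature_size - offset_from_end;
}
```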