"docs/en/1_exist_data_model.md" did not exist on "672cf58df6379005ba5cb0d9b0d188dca1f1fe86"
Commit d66941ac authored by Ivan Bogatyy's avatar Ivan Bogatyy Committed by calberti
Browse files

Sync w TF r0.12 & Bazel 0.4.3, internal updates (#953)

parent efa4a6cf
...@@ -117,7 +117,7 @@ class LexiconBuilder : public OpKernel { ...@@ -117,7 +117,7 @@ class LexiconBuilder : public OpKernel {
tag_to_category.SetCategory(token.tag(), token.category()); tag_to_category.SetCategory(token.tag(), token.category());
// Add characters. // Add characters.
vector<tensorflow::StringPiece> char_sp; std::vector<tensorflow::StringPiece> char_sp;
SegmenterUtils::GetUTF8Chars(word, &char_sp); SegmenterUtils::GetUTF8Chars(word, &char_sp);
for (const auto &c : char_sp) { for (const auto &c : char_sp) {
const string c_str = c.ToString(); const string c_str = c.ToString();
......
...@@ -132,10 +132,10 @@ class MorphologyTransitionState : public ParserTransitionState { ...@@ -132,10 +132,10 @@ class MorphologyTransitionState : public ParserTransitionState {
private: private:
// Currently assigned morphological analysis for each token in this sentence. // Currently assigned morphological analysis for each token in this sentence.
vector<int> tag_; std::vector<int> tag_;
// Gold morphological analysis from the input document. // Gold morphological analysis from the input document.
vector<int> gold_tag_; std::vector<int> gold_tag_;
// Tag map used for conversions between integer and string representations // Tag map used for conversions between integer and string representations
// part of speech tags. Not owned. // part of speech tags. Not owned.
......
...@@ -69,7 +69,7 @@ void MorphologyLabelSet::Write(ProtoRecordWriter *writer) const { ...@@ -69,7 +69,7 @@ void MorphologyLabelSet::Write(ProtoRecordWriter *writer) const {
} }
string MorphologyLabelSet::StringForMatch(const TokenMorphology &morph) const { string MorphologyLabelSet::StringForMatch(const TokenMorphology &morph) const {
vector<string> attributes; std::vector<string> attributes;
for (const auto &a : morph.attribute()) { for (const auto &a : morph.attribute()) {
attributes.push_back( attributes.push_back(
tensorflow::strings::StrCat(a.name(), kSeparator, a.value())); tensorflow::strings::StrCat(a.name(), kSeparator, a.value()));
...@@ -80,7 +80,7 @@ string MorphologyLabelSet::StringForMatch(const TokenMorphology &morph) const { ...@@ -80,7 +80,7 @@ string MorphologyLabelSet::StringForMatch(const TokenMorphology &morph) const {
string FullLabelFeatureType::GetFeatureValueName(FeatureValue value) const { string FullLabelFeatureType::GetFeatureValueName(FeatureValue value) const {
const TokenMorphology &morph = label_set_->Lookup(value); const TokenMorphology &morph = label_set_->Lookup(value);
vector<string> attributes; std::vector<string> attributes;
for (const auto &a : morph.attribute()) { for (const auto &a : morph.attribute()) {
attributes.push_back(tensorflow::strings::StrCat(a.name(), ":", a.value())); attributes.push_back(tensorflow::strings::StrCat(a.name(), ":", a.value()));
} }
......
...@@ -70,7 +70,7 @@ class MorphologyLabelSet { ...@@ -70,7 +70,7 @@ class MorphologyLabelSet {
// defined as follows: // defined as follows:
// //
// a == b iff the set of attribute pairs (attribute, value) is identical. // a == b iff the set of attribute pairs (attribute, value) is identical.
vector<TokenMorphology> label_set_; std::vector<TokenMorphology> label_set_;
// Because protocol buffer equality is complicated, we implement our own // Because protocol buffer equality is complicated, we implement our own
// equality operator based on strings. This unordered_map allows us to do the // equality operator based on strings. This unordered_map allows us to do the
......
...@@ -24,11 +24,12 @@ limitations under the License. ...@@ -24,11 +24,12 @@ limitations under the License.
namespace syntaxnet { namespace syntaxnet {
// Registry for the parser feature functions. // Registry for the parser feature functions.
REGISTER_CLASS_REGISTRY("parser feature function", ParserFeatureFunction); REGISTER_SYNTAXNET_CLASS_REGISTRY("parser feature function",
ParserFeatureFunction);
// Registry for the parser state + token index feature functions. // Registry for the parser state + token index feature functions.
REGISTER_CLASS_REGISTRY("parser+index feature function", REGISTER_SYNTAXNET_CLASS_REGISTRY("parser+index feature function",
ParserIndexFeatureFunction); ParserIndexFeatureFunction);
RootFeatureType::RootFeatureType(const string &name, RootFeatureType::RootFeatureType(const string &name,
const FeatureType &wrapped_type, const FeatureType &wrapped_type,
...@@ -228,4 +229,29 @@ class ParserTokenFeatureFunction : public NestedFeatureFunction< ...@@ -228,4 +229,29 @@ class ParserTokenFeatureFunction : public NestedFeatureFunction<
REGISTER_PARSER_IDX_FEATURE_FUNCTION("token", REGISTER_PARSER_IDX_FEATURE_FUNCTION("token",
ParserTokenFeatureFunction); ParserTokenFeatureFunction);
// Parser feature that always fetches the focus (position) of the token.
class FocusFeatureFunction : public ParserIndexFeatureFunction {
public:
// Initializes the feature function.
void Init(TaskContext *context) override {
// Note: this feature can return up to N values, where N is the length of
// the input sentence. Here, we give the arbitrary number 100 since it
// is not used.
set_feature_type(new NumericFeatureType(name(), 100));
}
void Evaluate(const WorkspaceSet &workspaces, const ParserState &object,
int focus, FeatureVector *result) const override {
FeatureValue value = focus;
result->add(feature_type(), value);
}
FeatureValue Compute(const WorkspaceSet &workspaces, const ParserState &state,
int focus, const FeatureVector *result) const override {
return focus;
}
};
REGISTER_PARSER_IDX_FEATURE_FUNCTION("focus", FocusFeatureFunction);
} // namespace syntaxnet } // namespace syntaxnet
...@@ -48,10 +48,11 @@ typedef FeatureFunction<ParserState, int> ParserIndexFeatureFunction; ...@@ -48,10 +48,11 @@ typedef FeatureFunction<ParserState, int> ParserIndexFeatureFunction;
// Utilities to register the two types of parser features. // Utilities to register the two types of parser features.
#define REGISTER_PARSER_FEATURE_FUNCTION(name, component) \ #define REGISTER_PARSER_FEATURE_FUNCTION(name, component) \
REGISTER_FEATURE_FUNCTION(ParserFeatureFunction, name, component) REGISTER_SYNTAXNET_FEATURE_FUNCTION(ParserFeatureFunction, name, component)
#define REGISTER_PARSER_IDX_FEATURE_FUNCTION(name, component) \ #define REGISTER_PARSER_IDX_FEATURE_FUNCTION(name, component) \
REGISTER_FEATURE_FUNCTION(ParserIndexFeatureFunction, name, component) REGISTER_SYNTAXNET_FEATURE_FUNCTION(ParserIndexFeatureFunction, name, \
component)
// Alias for locator type that takes a parser state, and produces a focus // Alias for locator type that takes a parser state, and produces a focus
// integer that can be used on nested ParserIndexFeature objects. // integer that can be used on nested ParserIndexFeature objects.
......
...@@ -210,14 +210,14 @@ class ParserState { ...@@ -210,14 +210,14 @@ class ParserState {
int next_; int next_;
// Parse stack of partially processed tokens. // Parse stack of partially processed tokens.
vector<int> stack_; std::vector<int> stack_;
// List of head positions for the (partial) dependency tree. // List of head positions for the (partial) dependency tree.
vector<int> head_; std::vector<int> head_;
// List of dependency relation labels describing the (partial) dependency // List of dependency relation labels describing the (partial) dependency
// tree. // tree.
vector<int> label_; std::vector<int> label_;
// Score of the parser state. // Score of the parser state.
double score_ = 0.0; double score_ = 0.0;
......
...@@ -17,6 +17,9 @@ ...@@ -17,6 +17,9 @@
# This test trains a parser on a small dataset, then runs it in greedy mode and # This test trains a parser on a small dataset, then runs it in greedy mode and
# in structured mode with beam 1, and checks that the result is identical. # in structured mode with beam 1, and checks that the result is identical.
set -eux set -eux
BINDIR=$TEST_SRCDIR/$TEST_WORKSPACE/syntaxnet BINDIR=$TEST_SRCDIR/$TEST_WORKSPACE/syntaxnet
......
...@@ -20,7 +20,7 @@ limitations under the License. ...@@ -20,7 +20,7 @@ limitations under the License.
namespace syntaxnet { namespace syntaxnet {
// Transition system registry. // Transition system registry.
REGISTER_CLASS_REGISTRY("transition system", ParserTransitionSystem); REGISTER_SYNTAXNET_CLASS_REGISTRY("transition system", ParserTransitionSystem);
void ParserTransitionSystem::PerformAction(ParserAction action, void ParserTransitionSystem::PerformAction(ParserAction action,
ParserState *state) const { ParserState *state) const {
......
...@@ -118,7 +118,7 @@ class ParserTransitionSystem ...@@ -118,7 +118,7 @@ class ParserTransitionSystem
// Returns all next gold actions for the parser during training using the // Returns all next gold actions for the parser during training using the
// dependency relations found in the underlying annotated sentence. // dependency relations found in the underlying annotated sentence.
virtual void GetAllNextGoldActions(const ParserState &state, virtual void GetAllNextGoldActions(const ParserState &state,
vector<ParserAction> *actions) const { std::vector<ParserAction> *actions) const {
ParserAction action = GetNextGoldAction(state); ParserAction action = GetNextGoldAction(state);
*actions = {action}; *actions = {action};
} }
...@@ -201,7 +201,7 @@ class ParserTransitionSystem ...@@ -201,7 +201,7 @@ class ParserTransitionSystem
}; };
#define REGISTER_TRANSITION_SYSTEM(type, component) \ #define REGISTER_TRANSITION_SYSTEM(type, component) \
REGISTER_CLASS_COMPONENT(ParserTransitionSystem, type, component) REGISTER_SYNTAXNET_CLASS_COMPONENT(ParserTransitionSystem, type, component)
} // namespace syntaxnet } // namespace syntaxnet
......
...@@ -88,13 +88,13 @@ bool PopulateTestInputs::Populate( ...@@ -88,13 +88,13 @@ bool PopulateTestInputs::Populate(
PopulateTestInputs::Create PopulateTestInputs::CreateTFMapFromDocumentTokens( PopulateTestInputs::Create PopulateTestInputs::CreateTFMapFromDocumentTokens(
const Sentence &document, const Sentence &document,
std::function<vector<string>(const Token &)> token2str) { std::function<std::vector<string>(const Token &)> token2str) {
return [document, token2str](TaskInput *input) { return [document, token2str](TaskInput *input) {
TermFrequencyMap map; TermFrequencyMap map;
// Build and write the dummy term frequency map. // Build and write the dummy term frequency map.
for (const Token &token : document.token()) { for (const Token &token : document.token()) {
vector<string> strings_for_token = token2str(token); std::vector<string> strings_for_token = token2str(token);
for (const string &s : strings_for_token) map.Increment(s); for (const string &s : strings_for_token) map.Increment(s);
} }
string file_name = AddPart(input, "text", ""); string file_name = AddPart(input, "text", "");
...@@ -116,22 +116,22 @@ PopulateTestInputs::Create PopulateTestInputs::CreateTagToCategoryFromTokens( ...@@ -116,22 +116,22 @@ PopulateTestInputs::Create PopulateTestInputs::CreateTagToCategoryFromTokens(
}; };
} }
vector<string> PopulateTestInputs::TokenCategory(const Token &token) { std::vector<string> PopulateTestInputs::TokenCategory(const Token &token) {
if (token.has_category()) return {token.category()}; if (token.has_category()) return {token.category()};
return {}; return {};
} }
vector<string> PopulateTestInputs::TokenLabel(const Token &token) { std::vector<string> PopulateTestInputs::TokenLabel(const Token &token) {
if (token.has_label()) return {token.label()}; if (token.has_label()) return {token.label()};
return {}; return {};
} }
vector<string> PopulateTestInputs::TokenTag(const Token &token) { std::vector<string> PopulateTestInputs::TokenTag(const Token &token) {
if (token.has_tag()) return {token.tag()}; if (token.has_tag()) return {token.tag()};
return {}; return {};
} }
vector<string> PopulateTestInputs::TokenWord(const Token &token) { std::vector<string> PopulateTestInputs::TokenWord(const Token &token) {
if (token.has_word()) return {token.word()}; if (token.has_word()) return {token.word()};
return {}; return {};
} }
......
...@@ -130,17 +130,17 @@ class PopulateTestInputs { ...@@ -130,17 +130,17 @@ class PopulateTestInputs {
// then saved to FLAGS_test_tmpdir/name. // then saved to FLAGS_test_tmpdir/name.
static Create CreateTFMapFromDocumentTokens( static Create CreateTFMapFromDocumentTokens(
const Sentence &document, const Sentence &document,
std::function<vector<string>(const Token &)> token2str); std::function<std::vector<string>(const Token &)> token2str);
// Creates a StringToStringMap protocol buffer input that maps tags to // Creates a StringToStringMap protocol buffer input that maps tags to
// categories. Uses whatever mapping is present in the document. // categories. Uses whatever mapping is present in the document.
static Create CreateTagToCategoryFromTokens(const Sentence &document); static Create CreateTagToCategoryFromTokens(const Sentence &document);
// Default implementations for "token2str" above. // Default implementations for "token2str" above.
static vector<string> TokenCategory(const Token &token); static std::vector<string> TokenCategory(const Token &token);
static vector<string> TokenLabel(const Token &token); static std::vector<string> TokenLabel(const Token &token);
static vector<string> TokenTag(const Token &token); static std::vector<string> TokenTag(const Token &token);
static vector<string> TokenWord(const Token &token); static std::vector<string> TokenWord(const Token &token);
// Utility function. Sets the TaskInput->part() fields for a new input part. // Utility function. Sets the TaskInput->part() fields for a new input part.
// Returns the file name. // Returns the file name.
......
...@@ -131,7 +131,7 @@ class StdIn : public tensorflow::RandomAccessFile { ...@@ -131,7 +131,7 @@ class StdIn : public tensorflow::RandomAccessFile {
char *scratch) const { char *scratch) const {
memcpy(scratch, buffer_.data(), buffer_.size()); memcpy(scratch, buffer_.data(), buffer_.size());
buffer_ = buffer_.substr(n); buffer_ = buffer_.substr(n);
result->set(scratch, n); *result = tensorflow::StringPiece(scratch, n);
expected_offset_ += n; expected_offset_ += n;
} }
...@@ -161,7 +161,7 @@ class TextReader { ...@@ -161,7 +161,7 @@ class TextReader {
Sentence *Read() { Sentence *Read() {
// Skips emtpy sentences, e.g., blank lines at the beginning of a file or // Skips emtpy sentences, e.g., blank lines at the beginning of a file or
// commented out blocks. // commented out blocks.
vector<Sentence *> sentences; std::vector<Sentence *> sentences;
string key, value; string key, value;
while (sentences.empty() && format_->ReadRecord(buffer_.get(), &value)) { while (sentences.empty() && format_->ReadRecord(buffer_.get(), &value)) {
key = tensorflow::strings::StrCat(filename_, ":", sentence_count_); key = tensorflow::strings::StrCat(filename_, ":", sentence_count_);
......
...@@ -143,7 +143,7 @@ class ParsingReader : public OpKernel { ...@@ -143,7 +143,7 @@ class ParsingReader : public OpKernel {
} }
// Create the outputs for each feature space. // Create the outputs for each feature space.
vector<Tensor *> feature_outputs(features_->NumEmbeddings()); std::vector<Tensor *> feature_outputs(features_->NumEmbeddings());
for (size_t i = 0; i < feature_outputs.size(); ++i) { for (size_t i = 0; i < feature_outputs.size(); ++i) {
OP_REQUIRES_OK(context, context->allocate_output( OP_REQUIRES_OK(context, context->allocate_output(
i, TensorShape({sentence_batch_->size(), i, TensorShape({sentence_batch_->size(),
...@@ -399,7 +399,7 @@ class DecodedParseReader : public ParsingReader { ...@@ -399,7 +399,7 @@ class DecodedParseReader : public ParsingReader {
// pull from the back of the docids queue as long as the sentences have been // pull from the back of the docids queue as long as the sentences have been
// completely processed. If the next document has not been completely // completely processed. If the next document has not been completely
// processed yet, then the docid will not be found in 'sentence_map_'. // processed yet, then the docid will not be found in 'sentence_map_'.
vector<Sentence> sentences; std::vector<Sentence> sentences;
while (!docids_.empty() && while (!docids_.empty() &&
sentence_map_.find(docids_.back()) != sentence_map_.end()) { sentence_map_.find(docids_.back()) != sentence_map_.end()) {
sentences.emplace_back(sentence_map_[docids_.back()]); sentences.emplace_back(sentence_map_[docids_.back()]);
...@@ -427,7 +427,7 @@ class DecodedParseReader : public ParsingReader { ...@@ -427,7 +427,7 @@ class DecodedParseReader : public ParsingReader {
string scoring_type_; string scoring_type_;
mutable std::deque<string> docids_; mutable std::deque<string> docids_;
mutable map<string, Sentence> sentence_map_; mutable std::map<string, Sentence> sentence_map_;
TF_DISALLOW_COPY_AND_ASSIGN(DecodedParseReader); TF_DISALLOW_COPY_AND_ASSIGN(DecodedParseReader);
}; };
......
...@@ -28,11 +28,11 @@ limitations under the License. ...@@ -28,11 +28,11 @@ limitations under the License.
// }; // };
// //
// #define REGISTER_FUNCTION(type, component) // #define REGISTER_FUNCTION(type, component)
// REGISTER_INSTANCE_COMPONENT(Function, type, component); // REGISTER_SYNTAXNET_INSTANCE_COMPONENT(Function, type, component);
// //
// function.cc: // function.cc:
// //
// REGISTER_INSTANCE_REGISTRY("function", Function); // REGISTER_SYNTAXNET_INSTANCE_REGISTRY("function", Function);
// //
// class Cos : public Function { // class Cos : public Function {
// public: // public:
...@@ -218,22 +218,22 @@ class RegisterableInstance { ...@@ -218,22 +218,22 @@ class RegisterableInstance {
static Registry registry_; static Registry registry_;
}; };
#define REGISTER_CLASS_COMPONENT(base, type, component) \ #define REGISTER_SYNTAXNET_CLASS_COMPONENT(base, type, component) \
static base *__##component##__factory() { return new component; } \ static base *__##component##__factory() { return new component; } \
static base::Registry::Registrar __##component##__##registrar( \ static base::Registry::Registrar __##component##__##registrar( \
base::registry(), type, #component, __FILE__, __LINE__, \ base::registry(), type, #component, __FILE__, __LINE__, \
__##component##__factory) __##component##__factory)
#define REGISTER_CLASS_REGISTRY(type, classname) \ #define REGISTER_SYNTAXNET_CLASS_REGISTRY(type, classname) \
template <> \ template <> \
classname::Registry RegisterableClass<classname>::registry_ = { \ classname::Registry RegisterableClass<classname>::registry_ = { \
type, #classname, __FILE__, __LINE__, NULL} type, #classname, __FILE__, __LINE__, NULL}
#define REGISTER_INSTANCE_COMPONENT(base, type, component) \ #define REGISTER_SYNTAXNET_INSTANCE_COMPONENT(base, type, component) \
static base::Registry::Registrar __##component##__##registrar( \ static base::Registry::Registrar __##component##__##registrar( \
base::registry(), type, #component, __FILE__, __LINE__, new component) base::registry(), type, #component, __FILE__, __LINE__, new component)
#define REGISTER_INSTANCE_REGISTRY(type, classname) \ #define REGISTER_SYNTAXNET_INSTANCE_REGISTRY(type, classname) \
template <> \ template <> \
classname::Registry RegisterableInstance<classname>::registry_ = { \ classname::Registry RegisterableInstance<classname>::registry_ = { \
type, #classname, __FILE__, __LINE__, NULL} type, #classname, __FILE__, __LINE__, NULL}
......
...@@ -38,7 +38,7 @@ const std::unordered_set<int> SegmenterUtils::kBreakChars({ ...@@ -38,7 +38,7 @@ const std::unordered_set<int> SegmenterUtils::kBreakChars({
}); });
void SegmenterUtils::GetUTF8Chars(const string &text, void SegmenterUtils::GetUTF8Chars(const string &text,
vector<tensorflow::StringPiece> *chars) { std::vector<tensorflow::StringPiece> *chars) {
const char *start = text.c_str(); const char *start = text.c_str();
const char *end = text.c_str() + text.size(); const char *end = text.c_str() + text.size();
while (start < end) { while (start < end) {
...@@ -50,7 +50,7 @@ void SegmenterUtils::GetUTF8Chars(const string &text, ...@@ -50,7 +50,7 @@ void SegmenterUtils::GetUTF8Chars(const string &text,
void SegmenterUtils::SetCharsAsTokens( void SegmenterUtils::SetCharsAsTokens(
const string &text, const string &text,
const vector<tensorflow::StringPiece> &chars, const std::vector<tensorflow::StringPiece> &chars,
Sentence *sentence) { Sentence *sentence) {
sentence->clear_token(); sentence->clear_token();
sentence->set_text(text); sentence->set_text(text);
......
...@@ -32,15 +32,15 @@ class SegmenterUtils { ...@@ -32,15 +32,15 @@ class SegmenterUtils {
// Takes a text and convert it into a vector, where each element is a utf8 // Takes a text and convert it into a vector, where each element is a utf8
// character. // character.
static void GetUTF8Chars(const string &text, static void GetUTF8Chars(const string &text,
vector<tensorflow::StringPiece> *chars); std::vector<tensorflow::StringPiece> *chars);
// Sets tokens in the sentence so that each token is a single character. // Sets tokens in the sentence so that each token is a single character.
// Assigns the start/end byte offsets. // Assigns the start/end byte offsets.
// //
// If the sentence is not empty, the current tokens will be cleared. // If the sentence is not empty, the current tokens will be cleared.
static void SetCharsAsTokens(const string &text, static void SetCharsAsTokens(
const vector<tensorflow::StringPiece> &chars, const string &text, const std::vector<tensorflow::StringPiece> &chars,
Sentence *sentence); Sentence *sentence);
// Returns true for UTF-8 characters that cannot be 'real' tokens. This is // Returns true for UTF-8 characters that cannot be 'real' tokens. This is
// defined as any whitespace, line break or paragraph break. // defined as any whitespace, line break or paragraph break.
......
...@@ -59,9 +59,9 @@ static Sentence GetKoSentence() { ...@@ -59,9 +59,9 @@ static Sentence GetKoSentence() {
// Gets the start end bytes of the given chars in the given text. // Gets the start end bytes of the given chars in the given text.
static void GetStartEndBytes(const string &text, static void GetStartEndBytes(const string &text,
const vector<tensorflow::StringPiece> &chars, const std::vector<tensorflow::StringPiece> &chars,
vector<int> *starts, std::vector<int> *starts,
vector<int> *ends) { std::vector<int> *ends) {
SegmenterUtils segment_utils; SegmenterUtils segment_utils;
for (const tensorflow::StringPiece &c : chars) { for (const tensorflow::StringPiece &c : chars) {
int start; int end; int start; int end;
...@@ -75,14 +75,14 @@ static void GetStartEndBytes(const string &text, ...@@ -75,14 +75,14 @@ static void GetStartEndBytes(const string &text,
TEST(SegmenterUtilsTest, GetCharsTest) { TEST(SegmenterUtilsTest, GetCharsTest) {
// Create test sentence. // Create test sentence.
const Sentence sentence = GetKoSentence(); const Sentence sentence = GetKoSentence();
vector<tensorflow::StringPiece> chars; std::vector<tensorflow::StringPiece> chars;
SegmenterUtils::GetUTF8Chars(sentence.text(), &chars); SegmenterUtils::GetUTF8Chars(sentence.text(), &chars);
// Check the number of characters is correct. // Check the number of characters is correct.
CHECK_EQ(chars.size(), 12); CHECK_EQ(chars.size(), 12);
vector<int> starts; std::vector<int> starts;
vector<int> ends; std::vector<int> ends;
GetStartEndBytes(sentence.text(), chars, &starts, &ends); GetStartEndBytes(sentence.text(), chars, &starts, &ends);
// Check start positions. // Check start positions.
...@@ -118,12 +118,12 @@ TEST(SegmenterUtilsTest, GetCharsTest) { ...@@ -118,12 +118,12 @@ TEST(SegmenterUtilsTest, GetCharsTest) {
TEST(SegmenterUtilsTest, SetCharsAsTokensTest) { TEST(SegmenterUtilsTest, SetCharsAsTokensTest) {
// Create test sentence. // Create test sentence.
const Sentence sentence = GetKoSentence(); const Sentence sentence = GetKoSentence();
vector<tensorflow::StringPiece> chars; std::vector<tensorflow::StringPiece> chars;
SegmenterUtils segment_utils; SegmenterUtils segment_utils;
segment_utils.GetUTF8Chars(sentence.text(), &chars); segment_utils.GetUTF8Chars(sentence.text(), &chars);
vector<int> starts; std::vector<int> starts;
vector<int> ends; std::vector<int> ends;
GetStartEndBytes(sentence.text(), chars, &starts, &ends); GetStartEndBytes(sentence.text(), chars, &starts, &ends);
// Check that the new docs word, start and end positions are properly set. // Check that the new docs word, start and end positions are properly set.
......
...@@ -83,7 +83,8 @@ string TermFrequencyMapSetFeature::WorkspaceName() const { ...@@ -83,7 +83,8 @@ string TermFrequencyMapSetFeature::WorkspaceName() const {
} }
namespace { namespace {
void GetUTF8Chars(const string &word, vector<tensorflow::StringPiece> *chars) { void GetUTF8Chars(const string &word,
std::vector<tensorflow::StringPiece> *chars) {
UnicodeText text; UnicodeText text;
text.PointToUTF8(word.c_str(), word.size()); text.PointToUTF8(word.c_str(), word.size());
for (UnicodeText::const_iterator it = text.begin(); it != text.end(); ++it) { for (UnicodeText::const_iterator it = text.begin(); it != text.end(); ++it) {
...@@ -98,9 +99,10 @@ int UTF8FirstLetterNumBytes(const char *utf8_str) { ...@@ -98,9 +99,10 @@ int UTF8FirstLetterNumBytes(const char *utf8_str) {
} // namespace } // namespace
void CharNgram::GetTokenIndices(const Token &token, vector<int> *values) const { void CharNgram::GetTokenIndices(const Token &token,
std::vector<int> *values) const {
values->clear(); values->clear();
vector<tensorflow::StringPiece> char_sp; std::vector<tensorflow::StringPiece> char_sp;
if (use_terminators_) char_sp.push_back("^"); if (use_terminators_) char_sp.push_back("^");
GetUTF8Chars(token.word(), &char_sp); GetUTF8Chars(token.word(), &char_sp);
if (use_terminators_) char_sp.push_back("$"); if (use_terminators_) char_sp.push_back("$");
...@@ -121,7 +123,7 @@ void CharNgram::GetTokenIndices(const Token &token, vector<int> *values) const { ...@@ -121,7 +123,7 @@ void CharNgram::GetTokenIndices(const Token &token, vector<int> *values) const {
} }
void MorphologySet::GetTokenIndices(const Token &token, void MorphologySet::GetTokenIndices(const Token &token,
vector<int> *values) const { std::vector<int> *values) const {
values->clear(); values->clear();
const TokenMorphology &token_morphology = const TokenMorphology &token_morphology =
token.GetExtension(TokenMorphology::morphology); token.GetExtension(TokenMorphology::morphology);
...@@ -401,7 +403,8 @@ string AffixTableFeature::GetFeatureValueName(FeatureValue value) const { ...@@ -401,7 +403,8 @@ string AffixTableFeature::GetFeatureValueName(FeatureValue value) const {
} }
// Registry for the Sentence + token index feature functions. // Registry for the Sentence + token index feature functions.
REGISTER_CLASS_REGISTRY("sentence+index feature function", SentenceFeature); REGISTER_SYNTAXNET_CLASS_REGISTRY("sentence+index feature function",
SentenceFeature);
// Register the features defined in the header. // Register the features defined in the header.
REGISTER_SENTENCE_IDX_FEATURE("word", Word); REGISTER_SENTENCE_IDX_FEATURE("word", Word);
......
...@@ -110,7 +110,7 @@ class TokenLookupSetFeature : public SentenceFeature { ...@@ -110,7 +110,7 @@ class TokenLookupSetFeature : public SentenceFeature {
// feature value set. The index is relative to the start of the sentence. // feature value set. The index is relative to the start of the sentence.
virtual void LookupToken(const WorkspaceSet &workspaces, virtual void LookupToken(const WorkspaceSet &workspaces,
const Sentence &sentence, int index, const Sentence &sentence, int index,
vector<int> *values) const = 0; std::vector<int> *values) const = 0;
// Given a feature value, returns a string representation. // Given a feature value, returns a string representation.
virtual string GetFeatureValueName(int value) const = 0; virtual string GetFeatureValueName(int value) const = 0;
...@@ -137,7 +137,7 @@ class TokenLookupSetFeature : public SentenceFeature { ...@@ -137,7 +137,7 @@ class TokenLookupSetFeature : public SentenceFeature {
// Returns a pre-computed token value from the cache. This assumes the cache // Returns a pre-computed token value from the cache. This assumes the cache
// is populated. // is populated.
const vector<int> &GetCachedValueSet(const WorkspaceSet &workspaces, const std::vector<int> &GetCachedValueSet(const WorkspaceSet &workspaces,
const Sentence &sentence, const Sentence &sentence,
int focus) const { int focus) const {
// Do bounds checking on focus. // Do bounds checking on focus.
...@@ -152,7 +152,7 @@ class TokenLookupSetFeature : public SentenceFeature { ...@@ -152,7 +152,7 @@ class TokenLookupSetFeature : public SentenceFeature {
void Evaluate(const WorkspaceSet &workspaces, const Sentence &sentence, void Evaluate(const WorkspaceSet &workspaces, const Sentence &sentence,
int focus, FeatureVector *result) const override { int focus, FeatureVector *result) const override {
if (focus >= 0 && focus < sentence.token_size()) { if (focus >= 0 && focus < sentence.token_size()) {
const vector<int> &elements = const std::vector<int> &elements =
GetCachedValueSet(workspaces, sentence, focus); GetCachedValueSet(workspaces, sentence, focus);
for (auto &value : elements) { for (auto &value : elements) {
result->add(this->feature_type(), value); result->add(this->feature_type(), value);
...@@ -233,7 +233,7 @@ class TermFrequencyMapSetFeature : public TokenLookupSetFeature { ...@@ -233,7 +233,7 @@ class TermFrequencyMapSetFeature : public TokenLookupSetFeature {
// Returns index of raw word text. // Returns index of raw word text.
virtual void GetTokenIndices(const Token &token, virtual void GetTokenIndices(const Token &token,
vector<int> *values) const = 0; std::vector<int> *values) const = 0;
// Requests the resource inputs. // Requests the resource inputs.
void Setup(TaskContext *context) override; void Setup(TaskContext *context) override;
...@@ -261,7 +261,7 @@ class TermFrequencyMapSetFeature : public TokenLookupSetFeature { ...@@ -261,7 +261,7 @@ class TermFrequencyMapSetFeature : public TokenLookupSetFeature {
// Given a position in a sentence and workspaces, looks up the corresponding // Given a position in a sentence and workspaces, looks up the corresponding
// feature value set. The index is relative to the start of the sentence. // feature value set. The index is relative to the start of the sentence.
void LookupToken(const WorkspaceSet &workspaces, const Sentence &sentence, void LookupToken(const WorkspaceSet &workspaces, const Sentence &sentence,
int index, vector<int> *values) const override { int index, std::vector<int> *values) const override {
GetTokenIndices(sentence.token(index), values); GetTokenIndices(sentence.token(index), values);
} }
...@@ -376,7 +376,8 @@ class CharNgram : public TermFrequencyMapSetFeature { ...@@ -376,7 +376,8 @@ class CharNgram : public TermFrequencyMapSetFeature {
} }
// Returns index of raw word text. // Returns index of raw word text.
void GetTokenIndices(const Token &token, vector<int> *values) const override; void GetTokenIndices(const Token &token,
std::vector<int> *values) const override;
private: private:
// Size parameter (n) for the ngrams. // Size parameter (n) for the ngrams.
...@@ -401,7 +402,8 @@ class MorphologySet : public TermFrequencyMapSetFeature { ...@@ -401,7 +402,8 @@ class MorphologySet : public TermFrequencyMapSetFeature {
} }
// Returns index of raw word text. // Returns index of raw word text.
void GetTokenIndices(const Token &token, vector<int> *values) const override; void GetTokenIndices(const Token &token,
std::vector<int> *values) const override;
}; };
class LexicalCategoryFeature : public TokenLookupFeature { class LexicalCategoryFeature : public TokenLookupFeature {
...@@ -635,7 +637,7 @@ typedef FeatureExtractor<Sentence, int> SentenceExtractor; ...@@ -635,7 +637,7 @@ typedef FeatureExtractor<Sentence, int> SentenceExtractor;
// Utility to register the sentence_instance::Feature functions. // Utility to register the sentence_instance::Feature functions.
#define REGISTER_SENTENCE_IDX_FEATURE(name, type) \ #define REGISTER_SENTENCE_IDX_FEATURE(name, type) \
REGISTER_FEATURE_FUNCTION(SentenceFeature, name, type) REGISTER_SYNTAXNET_FEATURE_FUNCTION(SentenceFeature, name, type)
} // namespace syntaxnet } // namespace syntaxnet
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment