Commit 64675fc7 authored by calberti, committed by GitHub

New transition systems and features for syntaxnet (#301)

* Morpher and segmenter transition systems and new features (quotes, punctuation, capitalization, character ngrams, morphology attributes).
parent a591478c
@@ -144,7 +144,7 @@ class StdIn : public tensorflow::RandomAccessFile {
// Reads sentence protos from a text file.
class TextReader {
public:
explicit TextReader(const TaskInput &input) {
explicit TextReader(const TaskInput &input, TaskContext *context) {
CHECK_EQ(input.record_format_size(), 1)
<< "TextReader only supports inputs with one record format: "
<< input.DebugString();
@@ -153,6 +153,7 @@ class TextReader {
<< input.DebugString();
filename_ = TaskContext::InputFile(input);
format_.reset(DocumentFormat::Create(input.record_format(0)));
format_->Setup(context);
Reset();
}
@@ -202,7 +203,7 @@ class TextReader {
// Writes sentence protos to a text conll file.
class TextWriter {
public:
explicit TextWriter(const TaskInput &input) {
explicit TextWriter(const TaskInput &input, TaskContext *context) {
CHECK_EQ(input.record_format_size(), 1)
<< "TextWriter only supports files with one record format: "
<< input.DebugString();
@@ -211,6 +212,7 @@ class TextWriter {
<< input.DebugString();
filename_ = TaskContext::InputFile(input);
format_.reset(DocumentFormat::Create(input.record_format(0)));
format_->Setup(context);
if (filename_ != "-") {
TF_CHECK_OK(
tensorflow::Env::Default()->NewWritableFile(filename_, &file_));
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "syntaxnet/segmenter_utils.h"
#include "util/utf8/unicodetext.h"
#include "util/utf8/unilib.h"
#include "util/utf8/unilib_utf8_utils.h"
namespace syntaxnet {
// Separator code points (Unicode category Zs) from http://www.unicode.org/Public/UNIDATA/PropList.txt
// NB: This list is not necessarily exhaustive.
const std::unordered_set<int> SegmenterUtils::kBreakChars({
0x2028, // line separator
0x2029, // paragraph separator
0x0020, // space
0x00a0, // no-break space
0x1680, // Ogham space mark
0x180e, // Mongolian vowel separator
0x202f, // narrow no-break space
0x205f, // medium mathematical space
0x3000, // ideographic space
0xe5e5, // Google addition
0x2000, 0x2001, 0x2002, 0x2003, 0x2004, 0x2005, 0x2006, 0x2007, 0x2008,
0x2009, 0x200a
});
void SegmenterUtils::GetUTF8Chars(const string &text,
vector<tensorflow::StringPiece> *chars) {
const char *start = text.c_str();
const char *end = text.c_str() + text.size();
while (start < end) {
int char_length = UniLib::OneCharLen(start);
chars->emplace_back(start, char_length);
start += char_length;
}
}
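// Example for GetUTF8Chars (illustrative): the text "a서b" yields three
// StringPieces pointing into the original buffer: "a" (1 byte), "서"
// (3 bytes) and "b" (1 byte).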
void SegmenterUtils::SetCharsAsTokens(
const string &text,
const vector<tensorflow::StringPiece> &chars,
Sentence *sentence) {
sentence->clear_token();
sentence->set_text(text);
for (int i = 0; i < chars.size(); ++i) {
Token *tok = sentence->add_token();
tok->set_word(chars[i].ToString()); // NOLINT
int start_byte, end_byte;
GetCharStartEndBytes(text, chars[i], &start_byte, &end_byte);
tok->set_start(start_byte);
tok->set_end(end_byte);
}
}
bool SegmenterUtils::IsValidSegment(const Sentence &sentence,
const Token &token) {
// Check that the token is not empty, both by string and by bytes.
if (token.word().empty()) return false;
if (token.start() > token.end()) return false;
// Check that token boundaries are inside the text.
if (token.start() < 0) return false;
if (token.end() >= sentence.text().size()) return false;
// Check that the token starts and ends on UTF-8 character boundaries.
const char s = sentence.text()[token.start()];
const char e = sentence.text()[token.end() + 1];
if (UniLib::IsTrailByte(s)) return false;
if (UniLib::IsTrailByte(e)) return false;
return true;
}
} // namespace syntaxnet
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SYNTAXNET_SEGMENTER_UTILS_H_
#define SYNTAXNET_SEGMENTER_UTILS_H_
#include <string>
#include <vector>
#include <unordered_set>
#include "syntaxnet/sentence.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "util/utf8/unicodetext.h"
namespace syntaxnet {
// A set of common convenience functions.
class SegmenterUtils {
public:
// Takes text and converts it into a vector, where each element is a UTF-8
// character.
static void GetUTF8Chars(const string &text,
vector<tensorflow::StringPiece> *chars);
// Sets tokens in the sentence so that each token is a single character.
// Assigns the start/end byte offsets.
//
// If the sentence is not empty, the current tokens will be cleared.
static void SetCharsAsTokens(const string &text,
const vector<tensorflow::StringPiece> &chars,
Sentence *sentence);
// Returns true for UTF-8 characters that cannot be 'real' tokens. This is
// defined as any whitespace, line break or paragraph break.
static bool IsBreakChar(const string &word) {
if (word == "\n" || word == "\t") return true;
UnicodeText text;
text.PointToUTF8(word.c_str(), word.length());
CHECK_EQ(text.size(), 1);
return kBreakChars.find(*text.begin()) != kBreakChars.end();
}
// Returns the break level for the next token based on the current character.
static Token::BreakLevel BreakLevel(const string &word) {
UnicodeText text;
text.PointToUTF8(word.c_str(), word.length());
auto point = *text.begin();
if (word == "\n" || point == kLineSeparator) {
return Token::LINE_BREAK;
} else if (point == kParagraphSeparator) {
return Token::SENTENCE_BREAK; // No PARAGRAPH_BREAK in sentence proto.
} else if (word == "\t" || kBreakChars.find(point) != kBreakChars.end()) {
return Token::SPACE_BREAK;
}
return Token::NO_BREAK;
}
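// Examples for BreakLevel, following the definitions above: "\n" yields
// Token::LINE_BREAK, U+2029 (paragraph separator) yields
// Token::SENTENCE_BREAK, " " yields Token::SPACE_BREAK, and "x" yields
// Token::NO_BREAK.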
// Convenience function for computing start/end byte offsets of a character
// StringPiece relative to original text.
static void GetCharStartEndBytes(const string &text,
tensorflow::StringPiece c,
int *start,
int *end) {
*start = c.data() - text.data();
*end = *start + c.size() - 1;
}
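// E.g., for GetCharStartEndBytes with text "서울", the StringPiece of the
// second character gets *start == 3 and *end == 5 (byte offsets are
// inclusive).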
// Returns true if this segment is a valid segment. Currently checks:
// 1) It is non-empty
// 2) It is valid UTF8
static bool IsValidSegment(const Sentence &sentence, const Token &token);
// Set for utf8 break characters.
static const std::unordered_set<int> kBreakChars;
static const int kLineSeparator = 0x2028;
static const int kParagraphSeparator = 0x2029;
};
} // namespace syntaxnet
#endif // SYNTAXNET_SEGMENTER_UTILS_H_
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "syntaxnet/segmenter_utils.h"
#include <string>
#include <vector>
#include "syntaxnet/char_properties.h"
#include "syntaxnet/sentence.pb.h"
#include <gmock/gmock.h>
#include "tensorflow/core/lib/strings/strcat.h"
namespace syntaxnet {
// Creates a Korean sentence and also initializes the token field.
static Sentence GetKoSentence() {
Sentence sentence;
string text = "서울시는 2012년부터";
// Add tokens.
sentence.set_text(text);
Token *tok = sentence.add_token();
tok->set_word("서울시");
tok->set_start(0);
tok->set_end(8);
tok = sentence.add_token();
tok->set_word("는");
tok->set_start(9);
tok->set_end(11);
tok = sentence.add_token();
tok->set_word("2012");
tok->set_start(13);
tok->set_end(16);
tok = sentence.add_token();
tok->set_word("년");
tok->set_start(17);
tok->set_end(19);
tok = sentence.add_token();
tok->set_word("부터");
tok->set_start(20);
tok->set_end(25);
return sentence;
}
// Gets the start/end bytes of the given chars in the given text.
static void GetStartEndBytes(const string &text,
const vector<tensorflow::StringPiece> &chars,
vector<int> *starts,
vector<int> *ends) {
SegmenterUtils segment_utils;
for (const tensorflow::StringPiece &c : chars) {
int start; int end;
segment_utils.GetCharStartEndBytes(text, c, &start, &end);
starts->push_back(start);
ends->push_back(end);
}
}
// Test the GetUTF8Chars function.
TEST(SegmenterUtilsTest, GetCharsTest) {
// Create test sentence.
const Sentence sentence = GetKoSentence();
vector<tensorflow::StringPiece> chars;
SegmenterUtils::GetUTF8Chars(sentence.text(), &chars);
// Check the number of characters is correct.
CHECK_EQ(chars.size(), 12);
vector<int> starts;
vector<int> ends;
GetStartEndBytes(sentence.text(), chars, &starts, &ends);
// Check start positions.
CHECK_EQ(starts[0], 0);
CHECK_EQ(starts[1], 3);
CHECK_EQ(starts[2], 6);
CHECK_EQ(starts[3], 9);
CHECK_EQ(starts[4], 12);
CHECK_EQ(starts[5], 13);
CHECK_EQ(starts[6], 14);
CHECK_EQ(starts[7], 15);
CHECK_EQ(starts[8], 16);
CHECK_EQ(starts[9], 17);
CHECK_EQ(starts[10], 20);
CHECK_EQ(starts[11], 23);
// Check end positions.
CHECK_EQ(ends[0], 2);
CHECK_EQ(ends[1], 5);
CHECK_EQ(ends[2], 8);
CHECK_EQ(ends[3], 11);
CHECK_EQ(ends[4], 12);
CHECK_EQ(ends[5], 13);
CHECK_EQ(ends[6], 14);
CHECK_EQ(ends[7], 15);
CHECK_EQ(ends[8], 16);
CHECK_EQ(ends[9], 19);
CHECK_EQ(ends[10], 22);
CHECK_EQ(ends[11], 25);
}
// Test the SetCharsAsTokens function.
TEST(SegmenterUtilsTest, SetCharsAsTokensTest) {
// Create test sentence.
const Sentence sentence = GetKoSentence();
vector<tensorflow::StringPiece> chars;
SegmenterUtils segment_utils;
segment_utils.GetUTF8Chars(sentence.text(), &chars);
vector<int> starts;
vector<int> ends;
GetStartEndBytes(sentence.text(), chars, &starts, &ends);
// Check that the new sentence's word, start and end positions are properly
// set.
Sentence new_sentence;
segment_utils.SetCharsAsTokens(sentence.text(), chars, &new_sentence);
CHECK_EQ(new_sentence.token_size(), chars.size());
for (int t = 0; t < sentence.token_size(); ++t) {
CHECK_EQ(new_sentence.token(t).word(), chars[t]);
CHECK_EQ(new_sentence.token(t).start(), starts[t]);
CHECK_EQ(new_sentence.token(t).end(), ends[t]);
}
// Re-running should remove the old tokens.
segment_utils.SetCharsAsTokens(sentence.text(), chars, &new_sentence);
CHECK_EQ(new_sentence.token_size(), chars.size());
for (int t = 0; t < sentence.token_size(); ++t) {
CHECK_EQ(new_sentence.token(t).word(), chars[t]);
CHECK_EQ(new_sentence.token(t).start(), starts[t]);
CHECK_EQ(new_sentence.token(t).end(), ends[t]);
}
}
} // namespace syntaxnet
@@ -59,3 +59,18 @@ message Token {
extensions 1000 to max;
}
// Stores information about the morphology of a token.
message TokenMorphology {
extend Token {
optional TokenMorphology morphology = 63949837;
}
// Morphology is represented by a set of attribute values.
message Attribute {
required string name = 1;
required string value = 2;
}
// This attribute field is designated to hold a single disambiguated analysis.
repeated Attribute attribute = 3;
};
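// Example in proto text format (illustrative), mirroring the extension usage
// in the feature tests:
//
//   token {
//     word: 'man'
//     [syntaxnet.TokenMorphology.morphology] {
//       attribute { name: 'morph' value: 'Sg' }
//       attribute { name: 'morph' value: 'Masc' }
//     }
//   }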
@@ -24,7 +24,7 @@ limitations under the License.
namespace syntaxnet {
void SentenceBatch::Init(TaskContext *context) {
reader_.reset(new TextReader(*context->GetInput(input_name_)));
reader_.reset(new TextReader(*context->GetInput(input_name_), context));
size_ = 0;
}
@@ -14,9 +14,11 @@ limitations under the License.
==============================================================================*/
#include "syntaxnet/sentence_features.h"
#include "syntaxnet/char_properties.h"
#include "syntaxnet/registry.h"
#include "util/utf8/unicodetext.h"
#include "util/utf8/unilib.h"
#include "util/utf8/unilib_utf8_utils.h"
namespace syntaxnet {
@@ -55,6 +57,83 @@ string TermFrequencyMapFeature::WorkspaceName() const {
min_freq_, max_num_terms_);
}
TermFrequencyMapSetFeature::~TermFrequencyMapSetFeature() {
if (term_map_ != nullptr) {
SharedStore::Release(term_map_);
term_map_ = nullptr;
}
}
void TermFrequencyMapSetFeature::Setup(TaskContext *context) {
context->GetInput(input_name_, "text", "");
}
void TermFrequencyMapSetFeature::Init(TaskContext *context) {
min_freq_ = GetIntParameter("min-freq", 0);
max_num_terms_ = GetIntParameter("max-num-terms", 0);
file_name_ = context->InputFile(*context->GetInput(input_name_));
term_map_ = SharedStoreUtils::GetWithDefaultName<TermFrequencyMap>(
file_name_, min_freq_, max_num_terms_);
TokenLookupSetFeature::Init(context);
}
string TermFrequencyMapSetFeature::WorkspaceName() const {
return SharedStoreUtils::CreateDefaultName(
"term-frequency-map-set", input_name_, min_freq_, max_num_terms_);
}
namespace {
void GetUTF8Chars(const string &word, vector<tensorflow::StringPiece> *chars) {
UnicodeText text;
text.PointToUTF8(word.c_str(), word.size());
for (UnicodeText::const_iterator it = text.begin(); it != text.end(); ++it) {
chars->push_back(tensorflow::StringPiece(it.utf8_data(), it.utf8_length()));
}
}
int UTF8FirstLetterNumBytes(const char *utf8_str) {
if (*utf8_str == '\0') return 0;
return UniLib::OneCharLen(utf8_str);
}
} // namespace
void CharNgram::GetTokenIndices(const Token &token, vector<int> *values) const {
values->clear();
vector<tensorflow::StringPiece> char_sp;
if (use_terminators_) char_sp.push_back("^");
GetUTF8Chars(token.word(), &char_sp);
if (use_terminators_) char_sp.push_back("$");
for (int start = 0; start < char_sp.size(); ++start) {
string char_ngram;
for (int index = 0;
index < max_char_ngram_length_ && start + index < char_sp.size();
++index) {
tensorflow::StringPiece c = char_sp[start + index];
if (c == " ") break; // Never add char ngrams containing spaces.
tensorflow::strings::StrAppend(&char_ngram, c);
int value = LookupIndex(char_ngram);
if (value != -1) { // Skip unknown values.
values->push_back(value);
}
}
}
}
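// Example for CharNgram (illustrative): for the word "saw" with
// max_char_ngram_length_ == 3 and use_terminators_ == false, the candidate
// ngrams are, in order: "s", "sa", "saw", "a", "aw", "w". Only ngrams found
// in the term map contribute values; unknown ngrams are skipped.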
void MorphologySet::GetTokenIndices(const Token &token,
vector<int> *values) const {
values->clear();
const TokenMorphology &token_morphology =
token.GetExtension(TokenMorphology::morphology);
for (const TokenMorphology::Attribute &att : token_morphology.attribute()) {
int value =
LookupIndex(tensorflow::strings::StrCat(att.name(), "=", att.value()));
if (value != -1) { // Skip unknown values.
values->push_back(value);
}
}
}
string Hyphen::GetFeatureValueName(FeatureValue value) const {
switch (value) {
case NO_HYPHEN:
@@ -70,6 +149,152 @@ FeatureValue Hyphen::ComputeValue(const Token &token) const {
return (word.find('-') < word.length() ? HAS_HYPHEN : NO_HYPHEN);
}
void Capitalization::Setup(TaskContext *context) {
utf8_ = (GetParameter("utf8") == "true");
}
// Runs ComputeValueWithFocus for each token in the sentence.
void Capitalization::Preprocess(WorkspaceSet *workspaces,
Sentence *sentence) const {
if (workspaces->Has<VectorIntWorkspace>(Workspace())) return;
VectorIntWorkspace *workspace =
new VectorIntWorkspace(sentence->token_size());
for (int i = 0; i < sentence->token_size(); ++i) {
const int value = ComputeValueWithFocus(sentence->token(i), i);
workspace->set_element(i, value);
}
workspaces->Set<VectorIntWorkspace>(Workspace(), workspace);
}
string Capitalization::GetFeatureValueName(FeatureValue value) const {
switch (value) {
case LOWERCASE:
return "LOWERCASE";
case UPPERCASE:
return "UPPERCASE";
case CAPITALIZED:
return "CAPITALIZED";
case CAPITALIZED_SENTENCE_INITIAL:
return "CAPITALIZED_SENTENCE_INITIAL";
case NON_ALPHABETIC:
return "NON_ALPHABETIC";
}
return "<INVALID>";
}
FeatureValue Capitalization::ComputeValueWithFocus(const Token &token,
int focus) const {
const string &word = token.word();
// Check whether there is an uppercase or lowercase character.
bool has_upper = false;
bool has_lower = false;
if (utf8_) {
LOG(FATAL) << "Not implemented.";
} else {
const char *str = word.c_str();
for (int i = 0; i < word.length(); ++i) {
const char c = str[i];
has_upper = (has_upper || (c >= 'A' && c <= 'Z'));
has_lower = (has_lower || (c >= 'a' && c <= 'z'));
}
}
// Compute simple values.
if (!has_upper && has_lower) return LOWERCASE;
if (has_upper && !has_lower) return UPPERCASE;
if (!has_upper && !has_lower) return NON_ALPHABETIC;
// Else has_upper && has_lower; a normal capitalized word. Check the break
// level to determine whether the capitalized word is sentence-initial.
const bool sentence_initial = (focus == 0);
return sentence_initial ? CAPITALIZED_SENTENCE_INITIAL : CAPITALIZED;
}
string PunctuationAmount::GetFeatureValueName(FeatureValue value) const {
switch (value) {
case NO_PUNCTUATION:
return "NO_PUNCTUATION";
case SOME_PUNCTUATION:
return "SOME_PUNCTUATION";
case ALL_PUNCTUATION:
return "ALL_PUNCTUATION";
}
return "<INVALID>";
}
FeatureValue PunctuationAmount::ComputeValue(const Token &token) const {
const string &word = token.word();
bool has_punctuation = false;
bool all_punctuation = true;
const char *start = word.c_str();
const char *end = word.c_str() + word.size();
while (start < end) {
int char_length = UTF8FirstLetterNumBytes(start);
bool char_is_punct = is_punctuation_or_symbol(start, char_length);
all_punctuation &= char_is_punct;
has_punctuation |= char_is_punct;
if (!all_punctuation && has_punctuation) return SOME_PUNCTUATION;
start += char_length;
}
if (!all_punctuation) return NO_PUNCTUATION;
return ALL_PUNCTUATION;
}
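// Examples (assuming '.' and '!' are classified by is_punctuation_or_symbol):
// "say" yields NO_PUNCTUATION, "say." yields SOME_PUNCTUATION, and "!!"
// yields ALL_PUNCTUATION.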
string Quote::GetFeatureValueName(FeatureValue value) const {
switch (value) {
case NO_QUOTE:
return "NO_QUOTE";
case OPEN_QUOTE:
return "OPEN_QUOTE";
case CLOSE_QUOTE:
return "CLOSE_QUOTE";
case UNKNOWN_QUOTE:
return "UNKNOWN_QUOTE";
}
return "<INVALID>";
}
FeatureValue Quote::ComputeValue(const Token &token) const {
const string &word = token.word();
// Penn Treebank open and close quotes are multi-character.
if (word == "``") return OPEN_QUOTE;
if (word == "''") return CLOSE_QUOTE;
if (word.length() == 1) {
int char_len = UTF8FirstLetterNumBytes(word.c_str());
bool is_open = is_open_quote(word.c_str(), char_len);
bool is_close = is_close_quote(word.c_str(), char_len);
if (is_open && !is_close) return OPEN_QUOTE;
if (is_close && !is_open) return CLOSE_QUOTE;
if (is_open && is_close) return UNKNOWN_QUOTE;
}
return NO_QUOTE;
}
void Quote::Preprocess(WorkspaceSet *workspaces, Sentence *sentence) const {
if (workspaces->Has<VectorIntWorkspace>(Workspace())) return;
VectorIntWorkspace *workspace =
new VectorIntWorkspace(sentence->token_size());
// For a double quote ("), it is unknown whether it is an open or close quote
// without looking at the prior tokens in the sentence. in_quote is true iff
// an odd number of " marks have been seen so far in the sentence (similar to
// the behavior of some tokenizers).
bool in_quote = false;
for (int i = 0; i < sentence->token_size(); ++i) {
int quote_type = ComputeValue(sentence->token(i));
if (quote_type == UNKNOWN_QUOTE) {
// Update based on in_quote and flip in_quote.
quote_type = in_quote ? CLOSE_QUOTE : OPEN_QUOTE;
in_quote = !in_quote;
}
workspace->set_element(i, quote_type);
}
workspaces->Set<VectorIntWorkspace>(Workspace(), workspace);
}
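// Example (illustrative): in a sentence containing exactly two plain " tokens,
// both compute UNKNOWN_QUOTE above; preprocessing resolves the first to
// OPEN_QUOTE and the second to CLOSE_QUOTE.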
string Digit::GetFeatureValueName(FeatureValue value) const {
switch (value) {
case NO_DIGIT:
@@ -130,8 +355,7 @@ static AffixTable *CreateAffixTable(const string &filename,
void AffixTableFeature::Setup(TaskContext *context) {
context->GetInput(input_name_, "recordio", "affix-table");
affix_length_ = GetIntParameter("length", 0);
CHECK_GE(affix_length_, 0)
<< "Length must be specified for affix preprocessor.";
CHECK_GE(affix_length_, 0) << "Length must be specified for affix feature.";
TokenLookupFeature::Setup(context);
}
@@ -181,6 +405,7 @@ REGISTER_CLASS_REGISTRY("sentence+index feature function", SentenceFeature);
// Register the features defined in the header.
REGISTER_SENTENCE_IDX_FEATURE("word", Word);
REGISTER_SENTENCE_IDX_FEATURE("char", Char);
REGISTER_SENTENCE_IDX_FEATURE("lcword", LowercaseWord);
REGISTER_SENTENCE_IDX_FEATURE("tag", Tag);
REGISTER_SENTENCE_IDX_FEATURE("offset", Offset);
@@ -188,5 +413,10 @@ REGISTER_SENTENCE_IDX_FEATURE("hyphen", Hyphen);
REGISTER_SENTENCE_IDX_FEATURE("digit", Digit);
REGISTER_SENTENCE_IDX_FEATURE("prefix", PrefixFeature);
REGISTER_SENTENCE_IDX_FEATURE("suffix", SuffixFeature);
REGISTER_SENTENCE_IDX_FEATURE("char-ngram", CharNgram);
REGISTER_SENTENCE_IDX_FEATURE("morphology-set", MorphologySet);
REGISTER_SENTENCE_IDX_FEATURE("capitalization", Capitalization);
REGISTER_SENTENCE_IDX_FEATURE("punctuation-amount", PunctuationAmount);
REGISTER_SENTENCE_IDX_FEATURE("quote", Quote);
} // namespace syntaxnet
@@ -23,6 +23,7 @@ limitations under the License.
#include "syntaxnet/affix.h"
#include "syntaxnet/feature_extractor.h"
#include "syntaxnet/feature_types.h"
#include "syntaxnet/segmenter_utils.h"
#include "syntaxnet/shared_store.h"
#include "syntaxnet/task_context.h"
#include "syntaxnet/workspace.h"
@@ -85,6 +86,88 @@ class TokenLookupFeature : public SentenceFeature {
return workspaces.Get<VectorIntWorkspace>(workspace_).element(focus);
}
int Workspace() const { return workspace_; }
private:
int workspace_;
};
// A multi-purpose specialization of the feature. Processes the tokens in a
// Sentence by looking up a value set for each token and storing that in
// a VectorVectorIntWorkspace. Given a set of base values of size NumValues(),
// reserves an extra value for unknown tokens.
class TokenLookupSetFeature : public SentenceFeature {
public:
void Init(TaskContext *context) override {
set_feature_type(new ResourceBasedFeatureType<TokenLookupSetFeature>(
name(), this, {{NumValues(), "<OUTSIDE>"}}));
}
// Number of unique values.
virtual int64 NumValues() const = 0;
// Given a position in a sentence and workspaces, looks up the corresponding
// feature value set. The index is relative to the start of the sentence.
virtual void LookupToken(const WorkspaceSet &workspaces,
const Sentence &sentence, int index,
vector<int> *values) const = 0;
// Given a feature value, returns a string representation.
virtual string GetFeatureValueName(int value) const = 0;
// Name of the shared workspace.
virtual string WorkspaceName() const = 0;
// TokenLookupSetFeatures use VectorVectorIntWorkspaces by default.
void RequestWorkspaces(WorkspaceRegistry *registry) override {
workspace_ = registry->Request<VectorVectorIntWorkspace>(WorkspaceName());
}
// Default preprocessing: looks up a value set for each token in the Sentence.
void Preprocess(WorkspaceSet *workspaces, Sentence *sentence) const override {
if (workspaces->Has<VectorVectorIntWorkspace>(workspace_)) return;
VectorVectorIntWorkspace *workspace =
new VectorVectorIntWorkspace(sentence->token_size());
for (int i = 0; i < sentence->token_size(); ++i) {
LookupToken(*workspaces, *sentence, i, workspace->mutable_elements(i));
}
workspaces->Set<VectorVectorIntWorkspace>(workspace_, workspace);
}
// Returns a pre-computed token value from the cache. This assumes the cache
// is populated.
const vector<int> &GetCachedValueSet(const WorkspaceSet &workspaces,
const Sentence &sentence,
int focus) const {
// Do bounds checking on focus.
CHECK_GE(focus, 0);
CHECK_LT(focus, sentence.token_size());
// Return value from cache.
return workspaces.Get<VectorVectorIntWorkspace>(workspace_).elements(focus);
}
// Adds any precomputed features at the given focus, if present.
void Evaluate(const WorkspaceSet &workspaces, const Sentence &sentence,
int focus, FeatureVector *result) const override {
if (focus >= 0 && focus < sentence.token_size()) {
const vector<int> &elements =
GetCachedValueSet(workspaces, sentence, focus);
for (auto &value : elements) {
result->add(this->feature_type(), value);
}
}
}
// Returns the precomputed value, or NumValues() for features outside
// the sentence.
FeatureValue Compute(const WorkspaceSet &workspaces, const Sentence &sentence,
int focus, const FeatureVector *result) const override {
if (focus < 0 || focus >= sentence.token_size()) return NumValues();
return workspaces.Get<VectorIntWorkspace>(workspace_).element(focus);
}
private:
int workspace_;
};
@@ -134,6 +217,83 @@ class TermFrequencyMapFeature : public TokenLookupFeature {
int max_num_terms_;
};
// Specialization of the TokenLookupSetFeature class to use a TermFrequencyMap
// to perform the mapping. This takes two options: "min-freq" (discard terms
// with a lower frequency than this) and "max-num-terms" (read in at most this
// many terms).
class TermFrequencyMapSetFeature : public TokenLookupSetFeature {
public:
// Initializes with the input name only; the options needed to compute the
// actual workspace name are not known until Init().
explicit TermFrequencyMapSetFeature(const string &input_name)
: input_name_(input_name), min_freq_(0), max_num_terms_(0) {}
// Releases shared resources.
~TermFrequencyMapSetFeature() override;
// Computes the set of feature values for the token.
virtual void GetTokenIndices(const Token &token,
vector<int> *values) const = 0;
// Requests the resource inputs.
void Setup(TaskContext *context) override;
// Obtains resources using the shared store. At this point options are known
// so the full name can be computed.
void Init(TaskContext *context) override;
// Number of unique values.
int64 NumValues() const override { return term_map_->Size(); }
// Special value for strings not in the map.
FeatureValue UnknownValue() const { return term_map_->Size(); }
// Gets pointer to the underlying map.
const TermFrequencyMap *term_map() const { return term_map_; }
// Returns the term index or the unknown value. Used inside GetTokenIndices()
// specializations for convenience.
int LookupIndex(const string &term) const {
return term_map_->LookupIndex(term, -1);
}
// Given a position in a sentence and workspaces, looks up the corresponding
// feature value set. The index is relative to the start of the sentence.
void LookupToken(const WorkspaceSet &workspaces, const Sentence &sentence,
int index, vector<int> *values) const override {
GetTokenIndices(sentence.token(index), values);
}
// Uses the TermFrequencyMap to look up the string associated with a value.
string GetFeatureValueName(int value) const override {
if (value == UnknownValue()) return "<UNKNOWN>";
if (value >= 0 && value < NumValues()) {
return term_map_->GetTerm(value);
}
LOG(ERROR) << "Invalid feature value: " << value;
return "<INVALID>";
}
// Name of the shared workspace.
string WorkspaceName() const override;
private:
// Shortcut pointer to shared map. Not owned.
const TermFrequencyMap *term_map_ = nullptr;
// Name of the input for the term map.
string input_name_;
// Filename of the underlying resource.
string file_name_;
// Minimum frequency for term map.
int min_freq_;
// Maximum number of terms for term map.
int max_num_terms_;
};
class Word : public TermFrequencyMapFeature {
public:
Word() : TermFrequencyMapFeature("word-map") {}
@@ -144,6 +304,36 @@ class Word : public TermFrequencyMapFeature {
}
};
class Char : public TermFrequencyMapFeature {
public:
Char() : TermFrequencyMapFeature("char-map") {}
FeatureValue ComputeValue(const Token &token) const override {
const string &form = token.word();
if (SegmenterUtils::IsBreakChar(form)) return BreakCharValue();
return term_map().LookupIndex(form, UnknownValue());
}
// Special value for breaks.
FeatureValue BreakCharValue() const { return term_map().Size(); }
// Special value for non-break strings not in the map.
FeatureValue UnknownValue() const { return term_map().Size() + 1; }
// Number of unique values.
int64 NumValues() const override { return term_map().Size() + 2; }
string GetFeatureValueName(FeatureValue value) const override {
if (value == BreakCharValue()) return "<BREAK_CHAR>";
if (value == UnknownValue()) return "<UNKNOWN>";
if (value >= 0 && value < term_map().Size()) {
return term_map().GetTerm(value);
}
LOG(ERROR) << "Invalid feature value: " << value;
return "<INVALID>";
}
};
class LowercaseWord : public TermFrequencyMapFeature {
public:
LowercaseWord() : TermFrequencyMapFeature("lc-word-map") {}
@@ -172,6 +362,47 @@ class Label : public TermFrequencyMapFeature {
}
};
class CharNgram : public TermFrequencyMapSetFeature {
public:
CharNgram() : TermFrequencyMapSetFeature("char-ngram-map") {}
~CharNgram() override {}
void Setup(TaskContext *context) override {
TermFrequencyMapSetFeature::Setup(context);
max_char_ngram_length_ = context->Get("lexicon_max_char_ngram_length", 3);
use_terminators_ =
context->Get("lexicon_char_ngram_include_terminators", false);
}
// Computes the char-ngram values for the token's word.
void GetTokenIndices(const Token &token, vector<int> *values) const override;
private:
// Size parameter (n) for the ngrams.
int max_char_ngram_length_ = 3;
// Whether to pad the word with ^ and $ before extracting ngrams.
bool use_terminators_ = false;
};
class MorphologySet : public TermFrequencyMapSetFeature {
public:
MorphologySet() : TermFrequencyMapSetFeature("morphology-map") {}
~MorphologySet() override {}
void Setup(TaskContext *context) override {
TermFrequencyMapSetFeature::Setup(context);
}
int64 NumValues() const override {
return term_map()->Size() - 1;
}
// Computes the morphology attribute values for the token.
void GetTokenIndices(const Token &token, vector<int> *values) const override;
};
class LexicalCategoryFeature : public TokenLookupFeature {
public:
LexicalCategoryFeature(const string &name, int cardinality)
@@ -180,7 +411,7 @@ class LexicalCategoryFeature : public TokenLookupFeature {
FeatureValue NumValues() const override { return cardinality_; }
// Returns the identifier for the workspace for this preprocessor.
// Returns the identifier for the workspace for this feature.
string WorkspaceName() const override {
return tensorflow::strings::StrCat(name_, ":", cardinality_);
}
@@ -193,7 +424,7 @@ class LexicalCategoryFeature : public TokenLookupFeature {
const int cardinality_;
};
// Preprocessor that computes whether a word has a hyphen or not.
// Feature that computes whether a word has a hyphen or not.
class Hyphen : public LexicalCategoryFeature {
public:
// Enumeration of values.
@@ -213,7 +444,100 @@ class Hyphen : public LexicalCategoryFeature {
FeatureValue ComputeValue(const Token &token) const override;
};
// Preprocessor that computes whether a word has a hyphen or not.
// Feature that categorizes the capitalization of the word. If the option
// utf8=true is specified, lowercase and uppercase checks are meant to be done
// with UTF8-compliant functions (this path is currently not implemented).
class Capitalization : public LexicalCategoryFeature {
public:
// Enumeration of values.
enum Category {
LOWERCASE = 0, // normal word
UPPERCASE = 1, // all-caps
CAPITALIZED = 2, // has one cap and one non-cap
CAPITALIZED_SENTENCE_INITIAL = 3, // same as above but sentence-initial
NON_ALPHABETIC = 4, // contains no alphabetic characters
CARDINALITY = 5,
};
// Default constructor.
Capitalization() : LexicalCategoryFeature("capitalization", CARDINALITY) {}
// Sets one of the options for the capitalization.
void Setup(TaskContext *context) override;
// Capitalization needs special preprocessing because token category can
// depend on whether the token is at the start of the sentence.
void Preprocess(WorkspaceSet *workspaces, Sentence *sentence) const override;
// Returns a string representation of the enum value.
string GetFeatureValueName(FeatureValue value) const override;
// Returns the category value for the token.
FeatureValue ComputeValue(const Token &token) const override {
LOG(FATAL) << "Capitalization should use ComputeValueWithFocus.";
return 0;
}
// Returns the category value for the token.
FeatureValue ComputeValueWithFocus(const Token &token, int focus) const;
private:
// Whether to use UTF8 compliant functions to check capitalization.
bool utf8_ = false;
};
// A ternary feature that computes whether the focus token contains no, some,
// or only punctuation.
class PunctuationAmount : public LexicalCategoryFeature {
public:
// Enumeration of values.
enum Category {
NO_PUNCTUATION = 0,
SOME_PUNCTUATION = 1,
ALL_PUNCTUATION = 2,
CARDINALITY = 3,
};
// Default constructor.
PunctuationAmount()
: LexicalCategoryFeature("punctuation-amount", CARDINALITY) {}
// Returns a string representation of the enum value.
string GetFeatureValueName(FeatureValue value) const override;
// Returns the category value for the token.
FeatureValue ComputeValue(const Token &token) const override;
};
// A feature that returns whether the word is an open or close quotation
// mark, based on its relative position to other quotation marks in the
// sentence.
class Quote : public LexicalCategoryFeature {
public:
// Enumeration of values.
enum Category {
NO_QUOTE = 0,
OPEN_QUOTE = 1,
CLOSE_QUOTE = 2,
UNKNOWN_QUOTE = 3,
CARDINALITY = 4,
};
// Default constructor.
Quote() : LexicalCategoryFeature("quote", CARDINALITY) {}
// Returns a string representation of the enum value.
string GetFeatureValueName(FeatureValue value) const override;
// Returns the category value for the token.
FeatureValue ComputeValue(const Token &token) const override;
// Override preprocess to compute open and close quotes from prior context of
// the sentence.
void Preprocess(WorkspaceSet *workspaces, Sentence *instance) const override;
};
// Feature that computes whether a word has digits or not.
class Digit : public LexicalCategoryFeature {
public:
// Enumeration of values.
@@ -234,9 +558,9 @@ class Digit : public LexicalCategoryFeature {
FeatureValue ComputeValue(const Token &token) const override;
};
// TokenLookupPreprocessor object to compute prefixes and suffixes of words. The
// TokenLookupFeature object to compute prefixes and suffixes of words. The
// AffixTable is stored in the SharedStore. This is very similar to the
// implementation of TermFrequencyMapPreprocessor, but using an AffixTable to
// implementation of TermFrequencyMapFeature, but using an AffixTable to
// perform the lookups. There are only two specializations, for prefixes and
// suffixes.
class AffixTableFeature : public TokenLookupFeature {
@@ -26,6 +26,7 @@ limitations under the License.
#include "syntaxnet/utils.h"
#include "syntaxnet/workspace.h"
#include <gmock/gmock.h>
#include "tensorflow/core/platform/test.h"
using testing::UnorderedElementsAreArray;
@@ -83,6 +84,27 @@ class SentenceFeaturesTest : public ::testing::Test {
return values;
}
// Adds an input to the task context.
void AddInputToContext(const string &name, const string &file_pattern,
const string &file_format,
const string &record_format) {
TaskInput *input = context_.GetInput(name);
TaskInput::Part *part = input->add_part();
part->set_file_pattern(file_pattern);
part->set_file_format(file_format);
part->set_record_format(record_format);
}
// Checks that a vector workspace is equal to a target vector.
void CheckVectorWorkspace(const VectorIntWorkspace &workspace,
vector<int> target) {
vector<int> src;
for (int i = 0; i < workspace.size(); ++i) {
src.push_back(workspace.element(i));
}
EXPECT_THAT(src, testing::ContainerEq(target));
}
Sentence sentence_;
WorkspaceSet workspaces_;
@@ -99,13 +121,18 @@ class CommonSentenceFeaturesTest : public SentenceFeaturesTest {
: SentenceFeaturesTest(
"text: 'I saw a man with a telescope.' "
"token { word: 'I' start: 0 end: 0 tag: 'PRP' category: 'PRON'"
" head: 1 label: 'nsubj' break_level: NO_BREAK } "
" head: 1 label: 'nsubj' break_level: NO_BREAK } "
"token { word: 'saw' start: 2 end: 4 tag: 'VBD' category: 'VERB'"
" label: 'ROOT' break_level: SPACE_BREAK } "
" label: 'ROOT' break_level: SPACE_BREAK } "
"token { word: 'a' start: 6 end: 6 tag: 'DT' category: 'DET'"
" head: 3 label: 'det' break_level: SPACE_BREAK } "
" head: 3 label: 'det' break_level: SPACE_BREAK } "
"token { word: 'man' start: 8 end: 10 tag: 'NN' category: 'NOUN'"
" head: 1 label: 'dobj' break_level: SPACE_BREAK } "
" head: 1 label: 'dobj' break_level: SPACE_BREAK"
" [syntaxnet.TokenMorphology.morphology] { "
" attribute { name:'morph' value:'Sg' } "
" attribute { name:'morph' value:'Masc' } "
" } "
"} "
"token { word: 'with' start: 12 end: 15 tag: 'IN' category: 'ADP'"
" head: 1 label: 'prep' break_level: SPACE_BREAK } "
"token { word: 'a' start: 17 end: 17 tag: 'DT' category: 'DET'"
@@ -152,4 +179,96 @@ TEST_F(CommonSentenceFeaturesTest, OffsetPlusTag) {
EXPECT_EQ("<OUTSIDE>", ExtractFeature(9));
}
TEST_F(CommonSentenceFeaturesTest, CharNgramFeature) {
TermFrequencyMap char_ngram_map;
char_ngram_map.Increment("a");
char_ngram_map.Increment("aw");
char_ngram_map.Increment("sa");
creators_.Add(
"char-ngram-map", "text", "",
[&char_ngram_map](const string &path) { char_ngram_map.Save(path); });
// Test that CharNgram works as expected.
PrepareFeature("char-ngram");
EXPECT_EQ("", utils::Join(ExtractMultiFeature(-1), ","));
EXPECT_EQ("", utils::Join(ExtractMultiFeature(0), ","));
EXPECT_EQ("sa,a,aw", utils::Join(ExtractMultiFeature(1), ","));
EXPECT_EQ("a", utils::Join(ExtractMultiFeature(2), ","));
EXPECT_EQ("a", utils::Join(ExtractMultiFeature(3), ","));
EXPECT_EQ("", utils::Join(ExtractMultiFeature(8), ","));
}
TEST_F(CommonSentenceFeaturesTest, MorphologySetFeature) {
TermFrequencyMap morphology_map;
morphology_map.Increment("morph=Sg");
morphology_map.Increment("morph=Sg");
morphology_map.Increment("morph=Masc");
morphology_map.Increment("morph=Masc");
morphology_map.Increment("morph=Pl");
creators_.Add(
"morphology-map", "text", "",
[&morphology_map](const string &path) { morphology_map.Save(path); });
// Test that MorphologySet works as expected.
PrepareFeature("morphology-set");
EXPECT_EQ("", utils::Join(ExtractMultiFeature(-1), ","));
EXPECT_EQ("", utils::Join(ExtractMultiFeature(0), ","));
EXPECT_EQ("morph=Sg,morph=Masc", utils::Join(ExtractMultiFeature(3), ","));
}
TEST_F(CommonSentenceFeaturesTest, CapitalizationProcessesCorrectly) {
Capitalization feature;
feature.RequestWorkspaces(&registry_);
workspaces_.Reset(registry_);
feature.Preprocess(&workspaces_, &sentence_);
// Check the workspace contains what we expect.
EXPECT_TRUE(workspaces_.Has<VectorIntWorkspace>(feature.Workspace()));
const VectorIntWorkspace &workspace =
workspaces_.Get<VectorIntWorkspace>(feature.Workspace());
constexpr int UPPERCASE = Capitalization::UPPERCASE;
constexpr int LOWERCASE = Capitalization::LOWERCASE;
constexpr int NON_ALPHABETIC = Capitalization::NON_ALPHABETIC;
CheckVectorWorkspace(workspace,
{UPPERCASE, LOWERCASE, LOWERCASE, LOWERCASE, LOWERCASE,
LOWERCASE, LOWERCASE, NON_ALPHABETIC});
}
class CharFeatureTest : public SentenceFeaturesTest {
protected:
CharFeatureTest()
: SentenceFeaturesTest(
"text: '一 个 测 试 员 ' "
"token { word: '一' start: 0 end: 2 } "
"token { word: '个' start: 3 end: 5 } "
"token { word: '测' start: 6 end: 8 } "
"token { word: '试' start: 9 end: 11 } "
"token { word: '员' start: 12 end: 14 } "
"token { word: ' ' start: 15 end: 15 } "
"token { word: '\t' start: 16 end: 16 } ") {}
};
TEST_F(CharFeatureTest, CharFeature) {
TermFrequencyMap char_map;
char_map.Increment("一");
char_map.Increment("个");
char_map.Increment("试");
char_map.Increment("员");
creators_.Add(
"char-map", "text", "",
[&char_map](const string &path) { char_map.Save(path); });
// Test that Char works as expected.
PrepareFeature("char");
EXPECT_EQ("<OUTSIDE>", ExtractFeature(-1));
EXPECT_EQ("一", ExtractFeature(0));
EXPECT_EQ("个", ExtractFeature(1));
EXPECT_EQ("<UNKNOWN>", ExtractFeature(2)); // "测" is not in the char map.
EXPECT_EQ("试", ExtractFeature(3));
EXPECT_EQ("员", ExtractFeature(4));
EXPECT_EQ("<BREAK_CHAR>", ExtractFeature(5));
EXPECT_EQ("<BREAK_CHAR>", ExtractFeature(6));
EXPECT_EQ("<OUTSIDE>", ExtractFeature(7));
}
} // namespace syntaxnet
@@ -25,8 +25,10 @@ limitations under the License.
#include <string>
#include "syntaxnet/parser_features.h"
#include "syntaxnet/parser_state.h"
#include "syntaxnet/parser_transitions.h"
#include "syntaxnet/sentence_features.h"
#include "syntaxnet/shared_store.h"
#include "syntaxnet/task_context.h"
#include "syntaxnet/term_frequency_map.h"
@@ -98,7 +100,9 @@ class TaggerTransitionState : public ParserTransitionState {
for (size_t i = 0; i < tag_.size(); ++i) {
Token *token = sentence->mutable_token(i);
token->set_tag(TagAsString(Tag(i)));
token->set_category(tag_to_category_->GetCategory(token->tag()));
if (tag_to_category_) {
token->set_category(tag_to_category_->GetCategory(token->tag()));
}
}
}
@@ -146,6 +150,7 @@ class TaggerTransitionSystem : public ParserTransitionSystem {
// Determines tag map location.
void Setup(TaskContext *context) override {
input_tag_map_ = context->GetInput("tag-map", "text", "");
join_category_to_pos_ = context->GetBoolParameter("join_category_to_pos");
input_tag_to_category_ = context->GetInput("tag-to-category", "text", "");
}
@@ -154,15 +159,21 @@ class TaggerTransitionSystem : public ParserTransitionSystem {
const string tag_map_path = TaskContext::InputFile(*input_tag_map_);
tag_map_ = SharedStoreUtils::GetWithDefaultName<TermFrequencyMap>(
tag_map_path, 0, 0);
const string tag_to_category_path =
TaskContext::InputFile(*input_tag_to_category_);
tag_to_category_ = SharedStoreUtils::GetWithDefaultName<TagToCategoryMap>(
tag_to_category_path);
if (!join_category_to_pos_) {
const string tag_to_category_path =
TaskContext::InputFile(*input_tag_to_category_);
tag_to_category_ = SharedStoreUtils::GetWithDefaultName<TagToCategoryMap>(
tag_to_category_path);
}
}
// The SHIFT action uses the same value as the corresponding action type.
static ParserAction ShiftAction(int tag) { return tag; }
// The tagger transition system doesn't look at the dependency tree, so it
// allows non-projective trees.
bool AllowsNonProjective() const override { return true; }
// Returns the number of action types.
int NumActionTypes() const override { return 1; }
@@ -251,8 +262,32 @@ class TaggerTransitionSystem : public ParserTransitionSystem {
// Tag to category map. Owned through SharedStore.
const TagToCategoryMap *tag_to_category_ = nullptr;
bool join_category_to_pos_ = false;
};
REGISTER_TRANSITION_SYSTEM("tagger", TaggerTransitionSystem);
// Feature function for retrieving the tag assigned to a token by the tagger
// transition system.
class PredictedTagFeatureFunction
: public BasicParserSentenceFeatureFunction<Tag> {
public:
PredictedTagFeatureFunction() {}
// Gets the TaggerTransitionState from the parser state and reads the assigned
// tag at the focus index. Returns -1 if the focus is not within the sentence.
FeatureValue Compute(const WorkspaceSet &workspaces, const ParserState &state,
int focus, const FeatureVector *result) const override {
if (focus < 0 || focus >= state.sentence().token_size()) return -1;
return static_cast<const TaggerTransitionState *>(state.transition_state())
->Tag(focus);
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(PredictedTagFeatureFunction);
};
REGISTER_PARSER_IDX_FEATURE_FUNCTION("pred-tag", PredictedTagFeatureFunction);
} // namespace syntaxnet
@@ -60,6 +60,12 @@ input {
file_pattern: 'OUTPATH/category-map'
}
}
input {
name: 'char-map'
Part {
file_pattern: 'OUTPATH/char-map'
}
}
input {
name: 'prefix-table'
Part {
@@ -63,6 +63,11 @@ class CoNLLSyntaxFormat : public DocumentFormat {
public:
CoNLLSyntaxFormat() {}
void Setup(TaskContext *context) override {
join_category_to_pos_ = context->GetBoolParameter("join_category_to_pos");
add_pos_as_attribute_ = context->GetBoolParameter("add_pos_as_attribute");
}
// Reads up to the first empty line and returns false if the end of file is
// reached.
bool ReadRecord(tensorflow::io::InputBuffer *buffer,
string *record) override {
@@ -121,6 +126,7 @@ class CoNLLSyntaxFormat : public DocumentFormat {
const string &word = fields[1];
const string &cpostag = fields[3];
const string &tag = fields[4];
const string &attributes = fields[5];
const int head = utils::ParseUsing<int>(fields[6], 0, utils::ParseInt32);
const string &label = fields[7];
@@ -139,6 +145,9 @@
if (!tag.empty()) token->set_tag(tag);
if (!cpostag.empty()) token->set_category(cpostag);
if (!label.empty()) token->set_label(label);
if (!attributes.empty()) AddMorphAttributes(attributes, token);
if (join_category_to_pos_) JoinCategoryToPos(token);
if (add_pos_as_attribute_) AddPosAsAttribute(token);
}
if (sentence->token_size() > 0) {
@@ -158,16 +167,18 @@
*key = sentence.docid();
vector<string> lines;
for (int i = 0; i < sentence.token_size(); ++i) {
Token token = sentence.token(i);
if (join_category_to_pos_) SplitCategoryFromPos(&token);
if (add_pos_as_attribute_) RemovePosFromAttributes(&token);
vector<string> fields(10);
fields[0] = tensorflow::strings::Printf("%d", i + 1);
fields[1] = sentence.token(i).word();
fields[1] = token.word();
fields[2] = "_";
fields[3] = sentence.token(i).category();
fields[4] = sentence.token(i).tag();
fields[5] = "_";
fields[6] =
tensorflow::strings::Printf("%d", sentence.token(i).head() + 1);
fields[7] = sentence.token(i).label();
fields[3] = token.category();
fields[4] = token.tag();
fields[5] = GetMorphAttributes(token);
fields[6] = tensorflow::strings::Printf("%d", token.head() + 1);
fields[7] = token.label();
fields[8] = "_";
fields[9] = "_";
lines.push_back(utils::Join(fields, "\t"));
@@ -176,6 +187,95 @@ class CoNLLSyntaxFormat : public DocumentFormat {
}
private:
// Creates a TokenMorphology object out of a list of attribute values of the
// form: a1=v1|a2=v2|... or v1|v2|...
void AddMorphAttributes(const string &attributes, Token *token) {
TokenMorphology *morph =
token->MutableExtension(TokenMorphology::morphology);
vector<string> att_vals = utils::Split(attributes, '|');
for (int i = 0; i < att_vals.size(); ++i) {
vector<string> att_val = utils::Split(att_vals[i], '=');
CHECK_LE(att_val.size(), 2)
<< "Error parsing morphology features "
<< "column, must be of format "
<< "a1=v1|a2=v2|... or v1|v2|... <field>: " << attributes;
// Format is either:
// 1) a1=v1|a2=v2..., e.g., Czech CoNLL data, or,
// 2) v1|v2|..., e.g., German CoNLL data.
const pair<string, string> name_value =
att_val.size() == 2 ? std::make_pair(att_val[0], att_val[1])
: std::make_pair(att_val[0], "on");
// We currently don't expect an empty attribute value, but might have an
// empty attribute name due to data input errors.
if (name_value.second.empty()) {
LOG(WARNING) << "Invalid attributes string: " << attributes
<< " for token: " << token->ShortDebugString();
continue;
}
if (!name_value.first.empty()) {
TokenMorphology::Attribute *attribute = morph->add_attribute();
attribute->set_name(name_value.first);
attribute->set_value(name_value.second);
}
}
}
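// Examples: "Case=Nom|Number=Sg" produces the attribute pairs (Case, Nom)
// and (Number, Sg); the value-only form "Sg|Masc" produces (Sg, on) and
// (Masc, on).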
// Creates a list of attribute values of the form a1=v1|a2=v2|... or v1|v2|...
// from a TokenMorphology object.
string GetMorphAttributes(const Token &token) {
const TokenMorphology &morph =
token.GetExtension(TokenMorphology::morphology);
if (morph.attribute_size() == 0) return "_";
string attributes;
for (const TokenMorphology::Attribute &attribute : morph.attribute()) {
if (!attributes.empty()) tensorflow::strings::StrAppend(&attributes, "|");
tensorflow::strings::StrAppend(&attributes, attribute.name());
if (attribute.value() != "on") {
tensorflow::strings::StrAppend(&attributes, "=", attribute.value());
}
}
return attributes;
}
void JoinCategoryToPos(Token *token) {
token->set_tag(
tensorflow::strings::StrCat(token->category(), "++", token->tag()));
token->clear_category();
}
void SplitCategoryFromPos(Token *token) {
const string &tag = token->tag();
const size_t pos = tag.find("++");
if (pos != string::npos) {
token->set_category(tag.substr(0, pos));
token->set_tag(tag.substr(pos + 2));
}
}
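// E.g., a joined tag "NOUN++NN" is split back into category "NOUN" and tag
// "NN"; tags without "++" are left unchanged.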
void AddPosAsAttribute(Token *token) {
if (!token->tag().empty()) {
TokenMorphology *morph =
token->MutableExtension(TokenMorphology::morphology);
TokenMorphology::Attribute *attribute = morph->add_attribute();
attribute->set_name("fPOS");
attribute->set_value(token->tag());
}
}
void RemovePosFromAttributes(Token *token) {
// Assumes the "fPOS" attribute, if present, is the last one.
TokenMorphology *morph =
token->MutableExtension(TokenMorphology::morphology);
if (morph->attribute_size() > 0 &&
    morph->attribute().rbegin()->name() == "fPOS") {
morph->mutable_attribute()->RemoveLast();
}
}
bool join_category_to_pos_ = false;
bool add_pos_as_attribute_ = false;
TF_DISALLOW_COPY_AND_ASSIGN(CoNLLSyntaxFormat);
};
@@ -62,7 +62,7 @@ string Join(const std::vector<T> &s, const char *sep) {
return result;
}
string JoinPath(std::initializer_list<StringPiece> paths);
string JoinPath(std::initializer_list<tensorflow::StringPiece> paths);
size_t RemoveLeadingWhitespace(tensorflow::StringPiece *text);
@@ -165,6 +165,64 @@ class PunctuationUtil {
void NormalizeDigits(string *form);
// Helper type to mark missing c-tor argument types
// for Type's c-tor in LazyStaticPtr<Type, ...>.
struct NoArg {};
template <typename Type, typename Arg1 = NoArg, typename Arg2 = NoArg,
typename Arg3 = NoArg>
class LazyStaticPtr {
public:
typedef Type element_type; // per smart pointer convention
// Pretend to be a pointer to Type (never NULL due to on-demand creation):
Type &operator*() const { return *get(); }
Type *operator->() const { return get(); }
// Named accessor/initializer:
Type *get() const {
if (!ptr_) Initialize(this);
return ptr_;
}
public:
// All the data is public and LazyStaticPtr has no constructors so that we can
// initialize LazyStaticPtr objects with the "= { arg_value, ... }" syntax.
// Clients of LazyStaticPtr must not access the data members directly.
// Arguments for Type's c-tor
// (unused NoArg-typed arguments consume either no space, or 1 byte to
// ensure address uniqueness):
Arg1 arg1_;
Arg2 arg2_;
Arg3 arg3_;
// The object we create and show.
mutable Type *ptr_;
private:
template <typename A1, typename A2, typename A3>
static Type *Factory(const A1 &a1, const A2 &a2, const A3 &a3) {
return new Type(a1, a2, a3);
}
template <typename A1, typename A2>
static Type *Factory(const A1 &a1, const A2 &a2, NoArg a3) {
return new Type(a1, a2);
}
template <typename A1>
static Type *Factory(const A1 &a1, NoArg a2, NoArg a3) {
return new Type(a1);
}
static Type *Factory(NoArg a1, NoArg a2, NoArg a3) { return new Type(); }
static void Initialize(const LazyStaticPtr *lsp) {
lsp->ptr_ = Factory(lsp->arg1_, lsp->arg2_, lsp->arg3_);
}
};
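// Example usage (illustrative sketch): a lazily constructed object,
// initialized with the brace syntax described above. The string is not
// constructed until the first dereference.
//
//   static LazyStaticPtr<string, const char *> kGreeting = { "hello" };
//   const size_t n = kGreeting->size();  // Constructs "hello" on demand.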
} // namespace utils
} // namespace syntaxnet
@@ -185,6 +185,8 @@ class VectorIntWorkspace : public Workspace {
// Sets the i'th element.
void set_element(int i, int value) { elements_[i] = value; }
int size() const { return elements_.size(); }
private:
// The enclosed vector.
vector<int> elements_;
@@ -462,6 +462,12 @@ inline string UnicodeTextToUTF8(const UnicodeText& t) {
return string(t.utf8_data(), t.utf8_length());
}
// This template function declaration is used in defining arraysize.
// Note that the function doesn't need an implementation, as we only
// use its type.
template <typename T, size_t N>
char (&ArraySizeHelper(T (&array)[N]))[N];
#define arraysize(array) (sizeof(ArraySizeHelper(array)))
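// E.g., given "int buf[16];", arraysize(buf) is a compile-time constant equal
// to 16; unlike sizeof(buf) / sizeof(buf[0]), it fails to compile on pointers.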
// For debugging. Return a string of integers, written in uppercase
// hex (%X), corresponding to the codepoints within the text. Each
@@ -25,10 +25,6 @@
namespace {
template <typename T, size_t N>
char (&ArraySizeHelper(T (&array)[N]))[N];
#define arraysize(array) (sizeof(ArraySizeHelper(array)))
class UnicodeTextTest : public testing::Test {
protected:
UnicodeTextTest() : empty_text_() {
@@ -21,6 +21,7 @@
// They are also exported from unilib.h for legacy reasons.
#include "syntaxnet/base.h"
#include "third_party/utf/utf.h"
namespace UniLib {
@@ -32,6 +33,19 @@ inline bool IsValidCodepoint(char32 c) {
|| (c >= 0xE000 && c <= 0x10FFFF);
}
// Returns true if 'str' is the start of a structurally valid UTF-8
// sequence and is not a surrogate codepoint. Returns false if str.empty()
// or if str.length() < UniLib::OneCharLen(str[0]). Otherwise, this function
// will access the first n bytes of str, where n is UniLib::OneCharLen(str[0]).
inline bool IsUTF8ValidCodepoint(StringPiece str) {
char32 c;
int consumed;
// It's OK if str.length() > consumed.
return !str.empty()
&& isvalidcharntorune(str.data(), str.size(), &c, &consumed)
&& IsValidCodepoint(c);
}
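// E.g., IsUTF8ValidCodepoint(StringPiece("a")) is true, while a lone
// continuation byte such as StringPiece("\x80") yields false.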
// Returns the length (number of bytes) of the Unicode code point
// starting at src, based on inspecting just that one byte. This
// requires that src point to a well-formed UTF-8 string; the result