Commit 64675fc7 authored by calberti's avatar calberti Committed by GitHub
Browse files

New transition systems and features for syntaxnet (#301)

* Morpher and segmenter transition systems and new features (quotes, punctuation, capitalization, character ngrams, morphology attributes).
parent a591478c
...@@ -107,8 +107,8 @@ Bazel should complete reporting all tests passed. ...@@ -107,8 +107,8 @@ Bazel should complete reporting all tests passed.
You can also compile SyntaxNet in a [Docker](https://www.docker.com/what-docker) You can also compile SyntaxNet in a [Docker](https://www.docker.com/what-docker)
container using this [Dockerfile](Dockerfile). container using this [Dockerfile](Dockerfile).
**Note:** If you are running Docker on OSX, make sure that you have enough memory allocated **Note:** If you are running Docker on OSX, make sure that you have enough
for your Docker VM. memory allocated for your Docker VM.
## Getting Started ## Getting Started
...@@ -612,6 +612,7 @@ Original authors of the code in this package include (in alphabetical order): ...@@ -612,6 +612,7 @@ Original authors of the code in this package include (in alphabetical order):
* David Weiss * David Weiss
* Emily Pitler * Emily Pitler
* Greg Coppola * Greg Coppola
* Ji Ma
* Keith Hall * Keith Hall
* Kuzman Ganchev * Kuzman Ganchev
* Michael Collins * Michael Collins
......
...@@ -158,6 +158,31 @@ cc_library( ...@@ -158,6 +158,31 @@ cc_library(
], ],
) )
# Unicode character-class predicates (is_space, is_punct, ...); see
# char_properties.h for the DEFINE_CHAR_PROPERTY macros.
cc_library(
    name = "char_properties",
    srcs = ["char_properties.cc"],
    hdrs = ["char_properties.h"],
    deps = [
        ":registry",
        ":utils",
        "//util/utf8:unicodetext",
    ],
    # Properties self-register via static initializers; keep them linked in.
    alwayslink = 1,
)
# Shared helpers for the word-segmentation transition systems (break-character
# detection, break levels).
cc_library(
    name = "segmenter_utils",
    srcs = ["segmenter_utils.cc"],
    hdrs = ["segmenter_utils.h"],
    deps = [
        ":base",
        ":char_properties",
        ":sentence_proto",
        "//util/utf8:unicodetext",
    ],
    # Keep static registrations from being dropped by the linker.
    alwayslink = 1,
)
cc_library( cc_library(
name = "feature_extractor", name = "feature_extractor",
srcs = ["feature_extractor.cc"], srcs = ["feature_extractor.cc"],
...@@ -199,6 +224,7 @@ cc_library( ...@@ -199,6 +224,7 @@ cc_library(
":affix", ":affix",
":feature_extractor", ":feature_extractor",
":registry", ":registry",
":segmenter_utils",
], ],
) )
...@@ -250,25 +276,51 @@ cc_library( ...@@ -250,25 +276,51 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
# Label set for the morphological-analysis ("morpher") transition system.
cc_library(
    name = "morphology_label_set",
    srcs = ["morphology_label_set.cc"],
    hdrs = ["morphology_label_set.h"],
    deps = [
        ":document_format",
        ":feature_extractor",
        ":proto_io",
        ":registry",
        ":sentence_proto",
        ":utils",
    ],
)
cc_library( cc_library(
name = "parser_transitions", name = "parser_transitions",
srcs = [ srcs = [
"arc_standard_transitions.cc", "arc_standard_transitions.cc",
"binary_segment_state.cc",
"binary_segment_transitions.cc",
"morpher_transitions.cc",
"parser_features.cc",
"parser_state.cc", "parser_state.cc",
"parser_transitions.cc", "parser_transitions.cc",
"tagger_transitions.cc", "tagger_transitions.cc",
], ],
hdrs = [ hdrs = [
"binary_segment_state.h",
"parser_features.h",
"parser_state.h", "parser_state.h",
"parser_transitions.h", "parser_transitions.h",
], ],
deps = [ deps = [
":affix",
":feature_extractor",
":kbest_syntax_proto", ":kbest_syntax_proto",
":morphology_label_set",
":registry", ":registry",
":segmenter_utils",
":sentence_features",
":sentence_proto", ":sentence_proto",
":shared_store", ":shared_store",
":task_context", ":task_context",
":term_frequency_map", ":term_frequency_map",
":workspace",
], ],
alwayslink = 1, alwayslink = 1,
) )
...@@ -288,30 +340,12 @@ cc_library( ...@@ -288,30 +340,12 @@ cc_library(
], ],
) )
cc_library(
name = "parser_features",
srcs = ["parser_features.cc"],
hdrs = ["parser_features.h"],
deps = [
":affix",
":feature_extractor",
":parser_transitions",
":registry",
":sentence_features",
":task_context",
":term_frequency_map",
":workspace",
],
alwayslink = 1,
)
cc_library( cc_library(
name = "embedding_feature_extractor", name = "embedding_feature_extractor",
srcs = ["embedding_feature_extractor.cc"], srcs = ["embedding_feature_extractor.cc"],
hdrs = ["embedding_feature_extractor.h"], hdrs = ["embedding_feature_extractor.h"],
deps = [ deps = [
":feature_extractor", ":feature_extractor",
":parser_features",
":parser_transitions", ":parser_transitions",
":sparse_proto", ":sparse_proto",
":task_context", ":task_context",
...@@ -326,7 +360,6 @@ cc_library( ...@@ -326,7 +360,6 @@ cc_library(
deps = [ deps = [
":embedding_feature_extractor", ":embedding_feature_extractor",
":feature_extractor", ":feature_extractor",
":parser_features",
":parser_transitions", ":parser_transitions",
":sentence_proto", ":sentence_proto",
":sparse_proto", ":sparse_proto",
...@@ -344,7 +377,6 @@ cc_library( ...@@ -344,7 +377,6 @@ cc_library(
"reader_ops.cc", "reader_ops.cc",
], ],
deps = [ deps = [
":parser_features",
":parser_transitions", ":parser_transitions",
":sentence_batch", ":sentence_batch",
":sentence_proto", ":sentence_proto",
...@@ -360,7 +392,6 @@ cc_library( ...@@ -360,7 +392,6 @@ cc_library(
srcs = ["document_filters.cc"], srcs = ["document_filters.cc"],
deps = [ deps = [
":document_format", ":document_format",
":parser_features",
":parser_transitions", ":parser_transitions",
":sentence_batch", ":sentence_batch",
":sentence_proto", ":sentence_proto",
...@@ -376,8 +407,8 @@ cc_library( ...@@ -376,8 +407,8 @@ cc_library(
deps = [ deps = [
":dictionary_proto", ":dictionary_proto",
":document_format", ":document_format",
":parser_features",
":parser_transitions", ":parser_transitions",
":segmenter_utils",
":sentence_batch", ":sentence_batch",
":sentence_proto", ":sentence_proto",
":task_context", ":task_context",
...@@ -438,6 +469,18 @@ filegroup( ...@@ -438,6 +469,18 @@ filegroup(
srcs = glob(["models/parsey_mcparseface/*"]), srcs = glob(["models/parsey_mcparseface/*"]),
) )
# Unit tests for BinarySegmentState (start bookkeeping and writing
# segmentations back to a Sentence).
cc_test(
    name = "binary_segment_state_test",
    size = "small",
    srcs = ["binary_segment_state_test.cc"],
    deps = [
        ":base",
        ":parser_transitions",
        ":term_frequency_map",
        ":test_main",
    ],
)
cc_test( cc_test(
name = "shared_store_test", name = "shared_store_test",
size = "small", size = "small",
...@@ -448,6 +491,26 @@ cc_test( ...@@ -448,6 +491,26 @@ cc_test(
], ],
) )
# Unit tests for the character-property predicates.
cc_test(
    name = "char_properties_test",
    srcs = ["char_properties_test.cc"],
    deps = [
        ":char_properties",
        ":test_main",
    ],
)
# Unit tests for the segmentation helper utilities.
cc_test(
    name = "segmenter_utils_test",
    srcs = ["segmenter_utils_test.cc"],
    deps = [
        ":base",
        ":segmenter_utils",
        ":sentence_proto",
        ":test_main",
    ],
)
cc_test( cc_test(
name = "sentence_features_test", name = "sentence_features_test",
size = "medium", size = "medium",
...@@ -465,6 +528,15 @@ cc_test( ...@@ -465,6 +528,15 @@ cc_test(
], ],
) )
# Unit tests for the morphology label set.
cc_test(
    name = "morphology_label_set_test",
    srcs = ["morphology_label_set_test.cc"],
    deps = [
        ":morphology_label_set",
        ":test_main",
    ],
)
cc_test( cc_test(
name = "arc_standard_transitions_test", name = "arc_standard_transitions_test",
size = "small", size = "small",
...@@ -479,6 +551,17 @@ cc_test( ...@@ -479,6 +551,17 @@ cc_test(
], ],
) )
# Unit tests for the binary segmentation transition system.
cc_test(
    name = "binary_segment_transitions_test",
    size = "small",
    srcs = ["binary_segment_transitions_test.cc"],
    deps = [
        ":parser_transitions",
        ":sentence_proto",
        ":test_main",
    ],
)
cc_test( cc_test(
name = "tagger_transitions_test", name = "tagger_transitions_test",
size = "small", size = "small",
...@@ -499,7 +582,6 @@ cc_test( ...@@ -499,7 +582,6 @@ cc_test(
srcs = ["parser_features_test.cc"], srcs = ["parser_features_test.cc"],
deps = [ deps = [
":feature_extractor", ":feature_extractor",
":parser_features",
":parser_transitions", ":parser_transitions",
":populate_test_inputs", ":populate_test_inputs",
":sentence_proto", ":sentence_proto",
......
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "syntaxnet/binary_segment_state.h"
#include <string>
#include "syntaxnet/segmenter_utils.h"
#include "syntaxnet/sentence.pb.h"
namespace syntaxnet {
// Returns a fresh, empty BinarySegmentState. Nothing is copied because the
// segmentation history (the stack of start-token indices) is stored on the
// ParserState itself (see NumStarts/LastStart in the header).
ParserTransitionState *BinarySegmentState::Clone() const {
  return new BinarySegmentState();
}
// Renders the state as "[w1 w2 ...] rest": the words segmented so far, in
// sentence order and separated by spaces, followed by the characters that
// have not yet been consumed.
string BinarySegmentState::ToString(const ParserState &state) const {
  string result("[");
  const int num_starts = NumStarts(state);
  for (int i = num_starts - 1; i >= 0; --i) {
    // The i-th most recent start opens a word that runs up to (but not
    // including) the next more recent start, or up to the consumed frontier
    // for the most recent word.
    const int first = LastStart(i, state);
    int last;
    if (i > 0) {
      last = LastStart(i - 1, state) - 1;
    } else if (state.EndOfInput()) {
      last = state.sentence().token_size() - 1;
    } else {
      last = state.Next() - 1;
    }
    for (int token = first; token <= last; ++token) {
      result.append(state.GetToken(token).word());
    }
    if (i != 0) result.append(" ");
  }
  result.append("] ");
  // Append the unconsumed characters verbatim.
  for (int token = state.Next(); token < state.NumTokens(); ++token) {
    result.append(state.GetToken(token).word());
  }
  return result;
}
// Writes the segmentation recorded in |state| back into |sentence|: a
// character marked as a start opens a new token, any other non-break
// character is appended to the token on its left, and break characters
// (as decided by SegmenterUtils::IsBreakChar) are dropped entirely.
// |rewrite_root_labels| is unused for segmentation.
void BinarySegmentState::AddParseToDocument(const ParserState &state,
                                            bool rewrite_root_labels,
                                            Sentence *sentence) const {
  if (sentence->token_size() == 0) return;
  // Mark which input characters were tagged as word starts.
  vector<bool> is_starts(sentence->token_size(), false);
  for (int i = 0; i < NumStarts(state); ++i) {
    is_starts[LastStart(i, state)] = true;
  }
  // Break level of the current token is determined based on its previous token.
  Token::BreakLevel break_level = Token::NO_BREAK;
  bool is_first_token = true;
  Sentence new_sentence;
  for (int i = 0; i < sentence->token_size(); ++i) {
    const Token &token = sentence->token(i);
    const string &word = token.word();
    bool is_break = SegmenterUtils::IsBreakChar(word);
    if (is_starts[i] || is_first_token) {
      if (!is_break) {
        // The current character is the first char of a new token/word.
        Token *new_token = new_sentence.add_token();
        new_token->set_start(token.start());
        new_token->set_end(token.end());
        new_token->set_word(word);
        // For the first token, keep the old break level to make sure that the
        // number of sentences stays unchanged.
        new_token->set_break_level(break_level);
        is_first_token = false;
      }
      // If the start character is a break char, is_first_token stays true so
      // the next non-break character still opens the word.
    } else {
      // Append the character to the previous token.
      if (!is_break) {
        int index = new_sentence.token_size() - 1;
        auto *last_token = new_sentence.mutable_token(index);
        last_token->mutable_word()->append(word);
        last_token->set_end(token.end());
      }
    }
    // Update break level. Note we do not introduce new sentences in the
    // transition system, thus anything goes beyond line break would be reduced
    // to line break.
    break_level = is_break ? SegmenterUtils::BreakLevel(word) : Token::NO_BREAK;
    if (break_level >= Token::LINE_BREAK) break_level = Token::LINE_BREAK;
  }
  // Replace the original character-level tokens with the assembled words.
  sentence->mutable_token()->Swap(new_sentence.mutable_token());
}
} // namespace syntaxnet
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SYNTAXNET_BINARY_SEGMENT_STATE_H_
#define SYNTAXNET_BINARY_SEGMENT_STATE_H_
#include "syntaxnet/parser_state.h"
#include "syntaxnet/parser_transitions.h"
namespace syntaxnet {
class Sentence;
// Parser state for binary segmentation transition system. The input of the
// system is a sequence of utf8 characters that are to be segmented into tokens.
// The system contains two types of transitions/actions:
// -START: the token at input is the first character of a new word.
// -MERGE: the token at input is to be merged with its previous token.
//
// A BinarySegmentState is used to store segmentation histories that can be used
// as features. In addition, it also provides the functionality to add
// segmentation results to the document. The function assumes that sentences in
// a document are processed in left-to-right order. See also the comments of
// the FinishDocument function for explanation.
//
// Note on spaces:
// Spaces, or more generally break-characters, should never be any part of a
// word, and the START/MERGE of spaces would be ignored. In addition, if a space
// starts a new word, then the actual first char of that word is the first
// non-space token following the space.
// Some examples:
// -chars: ' ' A B
// -tags: S M M
// -result: 'AB'
//
// -chars: A ' ' B
// -tags: S M M
// -result: 'AB'
//
// -chars: A ' ' B
// -tags: S S M
// -result: 'AB'
//
// -chars: A B ' '
// -tags: S S M
// -result: 'A', 'B'
class BinarySegmentState : public ParserTransitionState {
 public:
  // Returns a fresh, empty state; segmentation history is kept on the
  // ParserState's stack (see NumStarts below), so nothing is copied.
  ParserTransitionState *Clone() const override;

  // No per-state initialization is required.
  void Init(ParserState *state) override {}

  // Returns the number of start tokens that have already been identified. In
  // other words, number of start tokens between the first token of the sentence
  // and state.Input(), with state.Input() excluded.
  static int NumStarts(const ParserState &state) {
    return state.StackSize();
  }

  // Returns the index of the k-th most recent start token (k = 0 is the most
  // recently added).
  static int LastStart(int k, const ParserState &state) {
    DCHECK_GE(k, 0);
    DCHECK_LT(k, NumStarts(state));
    return state.Stack(k);
  }

  // Adds the token at given index as a new start token.
  static void AddStart(int index, ParserState *state) {
    state->Push(index);
  }

  // Adds segmentation results to the given sentence.
  void AddParseToDocument(const ParserState &state,
                          bool rewrite_root_labels,
                          Sentence *sentence) const override;

  // Whether a parsed token should be considered correct for evaluation.
  // Always returns true.
  bool IsTokenCorrect(const ParserState &state, int index) const override {
    return true;
  }

  // Returns a human readable string representation of this state.
  string ToString(const ParserState &state) const override;
};
} // namespace syntaxnet
#endif // SYNTAXNET_BINARY_SEGMENT_STATE_H_
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "syntaxnet/binary_segment_state.h"
#include <memory>
#include "syntaxnet/base.h"
#include "syntaxnet/sentence.pb.h"
#include "syntaxnet/term_frequency_map.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
// Fixture providing a sentence whose tokens are individual utf8 characters
// (including the space separators), i.e. the character-level input the
// binary segmenter consumes. Gold text: "测试 的 句子".
class BinarySegmentStateTest : public ::testing::Test {
 protected:
  void SetUp() override {
    // Prepare a sentence. start/end are byte offsets into the utf8 text.
    const char *str_sentence = "text: '测试 的 句子' "
        "token { word: '测' start: 0 end: 2 } "
        "token { word: '试' start: 3 end: 5 } "
        "token { word: ' ' start: 6 end: 6 } "
        "token { word: '的' start: 7 end: 9 } "
        "token { word: ' ' start: 10 end: 10 } "
        "token { word: '句' start: 11 end: 13 } "
        "token { word: '子' start: 14 end: 16 } ";
    sentence_ = std::unique_ptr<Sentence>(new Sentence());
    TextFormat::ParseFromString(str_sentence, sentence_.get());
  }

  // The test document, parse tree, and sentence.
  std::unique_ptr<Sentence> sentence_;
  TermFrequencyMap label_map_;
};
// Exercises the static bookkeeping helpers: AddStart pushes indices,
// NumStarts counts them, and LastStart(k) reads them newest-first.
TEST_F(BinarySegmentStateTest, AddStartLastStartNumStartsTest) {
  BinarySegmentState *segment_state = new BinarySegmentState();
  ParserState state(sentence_.get(), segment_state, &label_map_);

  // A fresh state records no start tokens.
  EXPECT_EQ(0, BinarySegmentState::NumStarts(state));

  // Record token 0 as the start of a word.
  BinarySegmentState::AddStart(0, &state);
  ASSERT_EQ(1, BinarySegmentState::NumStarts(state));
  EXPECT_EQ(0, BinarySegmentState::LastStart(0, state));

  // Record several more starts; LastStart(k) walks from the most recent
  // (k = 0) back to the oldest.
  for (const int index : {2, 3, 4, 5}) {
    BinarySegmentState::AddStart(index, &state);
  }
  ASSERT_EQ(5, BinarySegmentState::NumStarts(state));
  EXPECT_EQ(5, BinarySegmentState::LastStart(0, state));
  EXPECT_EQ(4, BinarySegmentState::LastStart(1, state));
  EXPECT_EQ(3, BinarySegmentState::LastStart(2, state));
  EXPECT_EQ(2, BinarySegmentState::LastStart(3, state));
  EXPECT_EQ(0, BinarySegmentState::LastStart(4, state));
}
// Verifies AddParseToDocument for three taggings of "测试 的 句子":
// the gold segmentation, spaces tagged MERGE, and non-spaces merged across
// spaces. Offsets below are byte offsets from the fixture.
TEST_F(BinarySegmentStateTest, AddParseToDocumentTest) {
  BinarySegmentState *segment_state = new BinarySegmentState();
  ParserState state(sentence_.get(), segment_state, &label_map_);

  // Test gold segmentation.
  // 0 1 2 3 4 5 6
  // 测 试 ' ' 的 ' ' 句 子
  // S M S S S S M
  segment_state->AddStart(0, &state);
  segment_state->AddStart(2, &state);
  segment_state->AddStart(3, &state);
  segment_state->AddStart(4, &state);
  segment_state->AddStart(5, &state);
  Sentence sentence_with_annotation = *sentence_;
  segment_state->AddParseToDocument(state, false, &sentence_with_annotation);

  // Test the number of tokens as well as the start/end byte-offsets of each
  // token.
  ASSERT_EQ(3, sentence_with_annotation.token_size());

  // The first token is 测试.
  EXPECT_EQ(0, sentence_with_annotation.token(0).start());
  EXPECT_EQ(5, sentence_with_annotation.token(0).end());

  // The second token is 的.
  EXPECT_EQ(7, sentence_with_annotation.token(1).start());
  EXPECT_EQ(9, sentence_with_annotation.token(1).end());

  // The third token is 句子.
  EXPECT_EQ(11, sentence_with_annotation.token(2).start());
  EXPECT_EQ(16, sentence_with_annotation.token(2).end());

  // Test merge space to other tokens. Since spaces, or more generally break
  // characters, should never be a part of any word, they are skipped no matter
  // how they are tagged.
  // 0 1 2 3 4 5 6
  // 测 试 ' ' 的 ' ' 句 子
  // S M M S M M M
  while (!state.StackEmpty()) state.Pop();  // Reset the recorded starts.
  segment_state->AddStart(0, &state);
  segment_state->AddStart(3, &state);
  sentence_with_annotation = *sentence_;
  segment_state->AddParseToDocument(state, false, &sentence_with_annotation);
  ASSERT_EQ(2, sentence_with_annotation.token_size());

  // The first token is 测试. Note even a space is tagged as "merge", it is not
  // attached to its previous word.
  EXPECT_EQ(0, sentence_with_annotation.token(0).start());
  EXPECT_EQ(5, sentence_with_annotation.token(0).end());

  // The second token is 的句子.
  EXPECT_EQ(7, sentence_with_annotation.token(1).start());
  EXPECT_EQ(16, sentence_with_annotation.token(1).end());

  // Test merge a token to space tokens. In such case, the current token would
  // be merged to the first non-space token on its left side.
  // 0 1 2 3 4 5 6
  // 测 试 ' ' 的 ' ' 句 子
  // S M S M S M M
  while (!state.StackEmpty()) state.Pop();  // Reset the recorded starts.
  segment_state->AddStart(0, &state);
  segment_state->AddStart(2, &state);
  segment_state->AddStart(4, &state);
  sentence_with_annotation = *sentence_;
  segment_state->AddParseToDocument(state, false, &sentence_with_annotation);
  ASSERT_EQ(1, sentence_with_annotation.token_size());
  EXPECT_EQ(0, sentence_with_annotation.token(0).start());
  EXPECT_EQ(16, sentence_with_annotation.token(0).end());
}
// A sentence consisting solely of break characters must yield zero tokens,
// regardless of how those characters are tagged.
TEST_F(BinarySegmentStateTest, SpaceDocumentTest) {
  const char *str_sentence = "text: ' \t\t' "
      "token { word: ' ' start: 0 end: 0 } "
      "token { word: '\t' start: 1 end: 1 } "
      "token { word: '\t' start: 2 end: 2 } ";
  TextFormat::ParseFromString(str_sentence, sentence_.get());
  auto *segment_state = new BinarySegmentState();
  ParserState state(sentence_.get(), segment_state, &label_map_);

  // All characters tagged MERGE:
  //  0    1    2
  // ' ' '\t' '\t'
  //  M    M    M
  Sentence annotated = *sentence_;
  segment_state->AddParseToDocument(state, false, &annotated);
  ASSERT_EQ(0, annotated.token_size());

  // All characters tagged START:
  //  0    1    2
  // ' ' '\t' '\t'
  //  S    S    S
  for (const int index : {0, 1, 2}) {
    BinarySegmentState::AddStart(index, &state);
  }
  annotated = *sentence_;
  segment_state->AddParseToDocument(state, false, &annotated);
  ASSERT_EQ(0, annotated.token_size());
}
// A leading space never joins a word: when a space begins the sentence (or is
// tagged START), the word actually begins at the first non-space character.
TEST_F(BinarySegmentStateTest, DocumentBeginWithSpaceTest) {
  const char *str_sentence = "text: ' 空格' "
      "token { word: ' ' start: 0 end: 0 } "
      "token { word: '空' start: 1 end: 3 } "
      "token { word: '格' start: 4 end: 6 } ";
  TextFormat::ParseFromString(str_sentence, sentence_.get());
  BinarySegmentState *segment_state = new BinarySegmentState();
  ParserState state(sentence_.get(), segment_state, &label_map_);

  // All characters tagged MERGE.
  // 0 1 2
  //' ' 空 格
  // M M M
  Sentence sentence_with_annotation = *sentence_;
  segment_state->AddParseToDocument(state, false, &sentence_with_annotation);
  ASSERT_EQ(1, sentence_with_annotation.token_size());

  // The first token is 空格.
  EXPECT_EQ(1, sentence_with_annotation.token(0).start());
  EXPECT_EQ(6, sentence_with_annotation.token(0).end());

  // The leading space tagged START; result is unchanged.
  // 0 1 2
  //' ' 空 格
  // S M M
  while (!state.StackEmpty()) state.Pop();  // Reset the recorded starts.
  segment_state->AddStart(0, &state);
  sentence_with_annotation = *sentence_;
  segment_state->AddParseToDocument(state, false, &sentence_with_annotation);
  ASSERT_EQ(1, sentence_with_annotation.token_size());

  // The first token is 空格.
  EXPECT_EQ(1, sentence_with_annotation.token(0).start());
  EXPECT_EQ(6, sentence_with_annotation.token(0).end());
}
// An empty sentence must produce an empty annotation without crashing.
TEST_F(BinarySegmentStateTest, EmptyDocumentTest) {
  const char *empty_text = "text: '' ";
  TextFormat::ParseFromString(empty_text, sentence_.get());
  auto *segment_state = new BinarySegmentState();
  ParserState state(sentence_.get(), segment_state, &label_map_);
  Sentence annotated = *sentence_;
  segment_state->AddParseToDocument(state, false, &annotated);
  ASSERT_EQ(0, annotated.token_size());
}
} // namespace syntaxnet
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "syntaxnet/binary_segment_state.h"
#include "syntaxnet/parser_state.h"
#include "syntaxnet/parser_transitions.h"
namespace syntaxnet {
// Given an input of utf8 characters, the BinarySegmentTransitionSystem
// conducts word segmentation by performing one of the following two actions:
// -START: starts a new word with the token at state.input, and also advances
// the state.input.
// -MERGE: adds the token at state.input to its previous word, and also advances
// state.input.
//
// Also see nlp/saft/components/segmentation/transition/binary-segment-state.h
// for examples on handling spaces.
class BinarySegmentTransitionSystem : public ParserTransitionSystem {
 public:
  BinarySegmentTransitionSystem() {}

  // Creates the transition state that records segmentation decisions.
  ParserTransitionState *NewTransitionState(bool train_mode) const override {
    return new BinarySegmentState();
  }

  // Action types for the segmentation-transition system.
  enum ParserActionType {
    START = 0,
    MERGE = 1,
    CARDINAL = 2
  };

  // Named accessors for the two actions. Return the enum values rather than
  // raw literals so the mapping is defined in exactly one place.
  static int StartAction() { return START; }
  static int MergeAction() { return MERGE; }

  // The system always starts a new word by default.
  ParserAction GetDefaultAction(const ParserState &state) const override {
    return StartAction();
  }

  // Returns the number of action types.
  int NumActionTypes() const override {
    return CARDINAL;
  }

  // Returns the number of possible actions.
  int NumActions(int num_labels) const override {
    return CARDINAL;
  }

  // Returns the next gold action for a given state according to the underlying
  // annotated sentence. The training data for the transition system is created
  // by the binary-segmenter-data task. If a token's break_level is NO_BREAK,
  // then it is a MERGE, START otherwise. The only exception is that the first
  // token in a sentence for the transition system is always a START.
  ParserAction GetNextGoldAction(const ParserState &state) const override {
    if (state.Next() == 0) return StartAction();
    const Token &token = state.GetToken(state.Next());
    return (token.break_level() != Token::NO_BREAK ?
            StartAction() : MergeAction());
  }

  // Both START and MERGE can be applied to any tokens in the sentence.
  bool IsAllowedAction(
      ParserAction action, const ParserState &state) const override {
    return true;
  }

  // Performs the specified action on a given parser state, without adding the
  // action to the state's history.
  void PerformActionWithoutHistory(
      ParserAction action, ParserState *state) const override {
    // Note when the action is less than 0, it is treated as a START.
    if (action < 0 || action == StartAction()) {
      MutableTransitionState(state)->AddStart(state->Next(), state);
    }
    state->Advance();
  }

  // Allows backoff to best allowable transition.
  bool BackOffToBestAllowableTransition() const override { return true; }

  // A state is a deterministic state iff no tokens have been consumed.
  bool IsDeterministicState(const ParserState &state) const override {
    return state.Next() == 0;
  }

  // For binary segmentation, a state is a final state iff all tokens have been
  // consumed.
  bool IsFinalState(const ParserState &state) const override {
    return state.EndOfInput();
  }

  // Returns a string representation of a parser action.
  string ActionAsString(
      ParserAction action, const ParserState &state) const override {
    return action == StartAction() ? "START" : "MERGE";
  }

  // Downcasts the TransitionState in ParserState to a BinarySegmentState.
  static BinarySegmentState *MutableTransitionState(ParserState *state) {
    return static_cast<BinarySegmentState *>(state->mutable_transition_state());
  }
};

REGISTER_TRANSITION_SYSTEM("binary-segment-transitions",
                           BinarySegmentTransitionSystem);
} // namespace syntaxnet
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "syntaxnet/binary_segment_state.h"
#include "syntaxnet/parser_state.h"
#include "syntaxnet/parser_transitions.h"
#include "syntaxnet/term_frequency_map.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
// Fixture that creates the registered "binary-segment-transitions" system and
// a character-level sentence for "因为 有 这样" with gold break levels.
class SegmentationTransitionTest : public ::testing::Test {
 protected:
  void SetUp() override {
    transition_system_ = std::unique_ptr<ParserTransitionSystem>(
        ParserTransitionSystem::Create("binary-segment-transitions"));

    // Prepare a sentence. start/end are byte offsets into the utf8 text.
    const char *str_sentence = "text: '因为 有 这样' "
        "token { word: '因' start: 0 end: 2 break_level: SPACE_BREAK } "
        "token { word: '为' start: 3 end: 5 break_level: NO_BREAK } "
        "token { word: ' ' start: 6 end: 6 break_level: SPACE_BREAK } "
        "token { word: '有' start: 7 end: 9 break_level: SPACE_BREAK } "
        "token { word: ' ' start: 10 end: 10 break_level: SPACE_BREAK } "
        "token { word: '这' start: 11 end: 13 break_level: SPACE_BREAK } "
        "token { word: '样' start: 14 end: 16 break_level: NO_BREAK } ";
    sentence_ = std::unique_ptr<Sentence>(new Sentence());
    TextFormat::ParseFromString(str_sentence, sentence_.get());
  }

  // Checks that the start indices on |state|'s stack equal |target|, ordered
  // from most recently added to oldest. (The previously declared local
  // `starts` vector was unused and has been removed.)
  void CheckStarts(const ParserState &state, const vector<int> &target) {
    ASSERT_EQ(state.StackSize(), target.size());
    for (int i = 0; i < state.StackSize(); ++i) {
      EXPECT_EQ(state.Stack(i), target[i]);
    }
  }

  // The test document, parse tree, and sentence with tags and partial parses.
  std::unique_ptr<Sentence> sentence_;
  std::unique_ptr<ParserTransitionSystem> transition_system_;
  TermFrequencyMap label_map_;
};
// Following GetNextGoldAction to completion must reproduce the gold
// segmentation: the words 因为, 有, 这样 with spaces dropped.
TEST_F(SegmentationTransitionTest, GoldNextActionTest) {
  BinarySegmentState *segment_state = static_cast<BinarySegmentState *>(
      transition_system_->NewTransitionState(true));
  ParserState state(sentence_.get(), segment_state, &label_map_);

  // Do segmentation by following the gold actions.
  while (transition_system_->IsFinalState(state) == false) {
    ParserAction action = transition_system_->GetNextGoldAction(state);
    transition_system_->PerformActionWithoutHistory(action, &state);
  }

  // Test STARTs: every token with a non-NO_BREAK break level, newest first.
  CheckStarts(state, {5, 4, 3, 2, 0});

  // Test the annotated tokens.
  segment_state->AddParseToDocument(state, false, sentence_.get());
  ASSERT_EQ(sentence_->token_size(), 3);
  EXPECT_EQ(sentence_->token(0).word(), "因为");
  EXPECT_EQ(sentence_->token(1).word(), "有");
  EXPECT_EQ(sentence_->token(2).word(), "这样");

  // Test start/end annotation of each token (byte offsets).
  EXPECT_EQ(sentence_->token(0).start(), 0);
  EXPECT_EQ(sentence_->token(0).end(), 5);
  EXPECT_EQ(sentence_->token(1).start(), 7);
  EXPECT_EQ(sentence_->token(1).end(), 9);
  EXPECT_EQ(sentence_->token(2).start(), 11);
  EXPECT_EQ(sentence_->token(2).end(), 16);
}
// When every decision is the default action (START), each non-space character
// becomes its own single-character word.
TEST_F(SegmentationTransitionTest, DefaultActionTest) {
  auto *segment_state = static_cast<BinarySegmentState *>(
      transition_system_->NewTransitionState(true));
  ParserState state(sentence_.get(), segment_state, &label_map_);

  // Drive the transition system to completion using only default actions.
  while (!transition_system_->IsFinalState(state)) {
    const ParserAction action = transition_system_->GetDefaultAction(state);
    transition_system_->PerformActionWithoutHistory(action, &state);
  }

  // Every character, including the two spaces, was marked START.
  CheckStarts(state, {6, 5, 4, 3, 2, 1, 0});

  // Spaces never form words, so only the five non-space characters survive.
  segment_state->AddParseToDocument(state, false, sentence_.get());
  ASSERT_EQ(sentence_->token_size(), 5);
  EXPECT_EQ(sentence_->token(0).word(), "因");
  EXPECT_EQ(sentence_->token(1).word(), "为");
  EXPECT_EQ(sentence_->token(2).word(), "有");
  EXPECT_EQ(sentence_->token(3).word(), "这");
  EXPECT_EQ(sentence_->token(4).word(), "样");
}
} // namespace syntaxnet
This diff is collapsed.
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// char_properties.h - define is_X() tests for various character properties
//
// Character properties can be defined in two ways:
//
// (1) Set-based:
//
// Enumerate the chars that have the property. Example:
//
// DEFINE_CHAR_PROPERTY_AS_SET(my_fave,
// RANGE('0', '9'),
// '\'',
// 0x00BF, // Spanish inverted question mark
// )
//
// Characters are expressed as Unicode code points; note that ascii codes
// are a subset. RANGE() specifies an inclusive range of code points.
//
// This defines two functions:
//
// bool is_my_fave(const char *str, int len)
// bool is_my_fave(int c)
//
// Each returns true for precisely the 12 characters specified above.
// Each takes a *single* UTF8 char as its argument -- the first expresses
// it as a char * and a length, the second as a Unicode code point.
// Please do not pass a string of multiple UTF8 chars to the first one.
//
// To make is_my_fave() externally accessible, put in your .h file:
//
// DECLARE_CHAR_PROPERTY(my_fave)
//
// (2) Function-based:
//
// Specify a function that assigns the desired chars to a CharProperty
// object. Example:
//
// DEFINE_CHAR_PROPERTY(my_other_fave, prop) {
// for (int i = '0'; i <= '9'; i += 2) {
// prop->AddChar(i);
// }
// prop->AddAsciiPredicate(&ispunct);
// prop->AddCharProperty("currency_symbol");
// }
//
// This defines a function of one arg: CharProperty *prop. The function
// calls various CharProperty methods to populate the prop. The last call
// above, AddCharProperty(), adds the chars from another char property
// ("currency_symbol").
//
// As in the set-based case, put a DECLARE_CHAR_PROPERTY(my_other_fave)
// in your .h if you want is_my_other_fave() to be externally accessible.
//
#ifndef SYNTAXNET_CHAR_PROPERTIES_H_
#define SYNTAXNET_CHAR_PROPERTIES_H_
#include <string> // for string
#include "syntaxnet/registry.h"
#include "syntaxnet/utils.h"
// =====================================================================
// Registry for accessing CharProperties by name
//
// This is for internal use by the CharProperty class and macros; callers
// should not use it explicitly.
//
namespace syntaxnet {
class CharProperty; // forward declaration
// Wrapper around a CharProperty, allowing it to be stored in a registry.
// Concrete subclasses are generated by the REGISTER_CHAR_PROPERTY macro
// below; callers retrieve properties via CharProperty::Lookup() instead of
// using this registry directly.
struct CharPropertyWrapper : RegisterableClass<CharPropertyWrapper> {
  virtual ~CharPropertyWrapper() { }
  // Returns the wrapped property. The registry retains ownership.
  virtual CharProperty *GetCharProperty() = 0;
};
// Registers a CharPropertyWrapper subclass under the name `component` in
// the CharPropertyWrapper registry.
#define REGISTER_CHAR_PROPERTY_WRAPPER(type, component) \
  REGISTER_CLASS_COMPONENT(CharPropertyWrapper, type, component)

// Defines and registers a wrapper exposing the lazily constructed
// CharProperty `lsp` (a utils::LazyStaticPtr) under the name `name`.
#define REGISTER_CHAR_PROPERTY(lsp, name)                         \
  struct name##CharPropertyWrapper : public CharPropertyWrapper { \
    CharProperty *GetCharProperty() { return lsp.get(); }         \
  };                                                              \
  REGISTER_CHAR_PROPERTY_WRAPPER(#name, name##CharPropertyWrapper)
// =====================================================================
// Macros for defining character properties
//
// Define is_X() functions to test whether a single UTF8 character has
// the 'X' char prop. `lsp` is the lazily constructed CharProperty
// singleton (a utils::LazyStaticPtr); the property is built on first use.
#define DEFINE_IS_X_CHAR_PROPERTY_FUNCTIONS(lsp, name) \
  bool is_##name(const char *str, int len) {           \
    return lsp->HoldsFor(str, len);                    \
  }                                                    \
  bool is_##name(int c) {                              \
    return lsp->HoldsFor(c);                           \
  }
// Define a char property by enumerating the unicode char points,
// or RANGE()s thereof, for which it holds. Example:
//
//   DEFINE_CHAR_PROPERTY_AS_SET(my_fave,
//     'q',
//     RANGE('0', '9'),
//     0x20AB,
//   )
//
// "..." (named variadic macro arguments) is a GNU extension.
#define DEFINE_CHAR_PROPERTY_AS_SET(name, unicodes...)                         \
  static const int k_##name##_unicodes[] = {unicodes};                         \
  static utils::LazyStaticPtr<CharProperty, const char *, const int *, size_t> \
      name##_char_property = {#name, k_##name##_unicodes,                      \
                              arraysize(k_##name##_unicodes)};                 \
  REGISTER_CHAR_PROPERTY(name##_char_property, name);                          \
  DEFINE_IS_X_CHAR_PROPERTY_FUNCTIONS(name##_char_property, name)

// Specify a range (inclusive) of Unicode character values.
// Example: RANGE('0', '9') specifies the 10 digits.
// For use as an element in a DEFINE_CHAR_PROPERTY_AS_SET() list.
// The two sentinels bracket a [lower, upper] pair in the flattened int
// list; the CharProperty set constructor recognizes and expands them.
static const int kPreUnicodeRange = -1;
static const int kPostUnicodeRange = -2;
#define RANGE(lower, upper) \
  kPreUnicodeRange, lower, upper, kPostUnicodeRange
// A function that populates a CharProperty; used by function-based
// definitions (see DEFINE_CHAR_PROPERTY below).
typedef void CharPropertyInitializer(CharProperty *prop);

// Define a char property by specifying a block of code that initializes it.
// Example:
//
//   DEFINE_CHAR_PROPERTY(my_other_fave, prop) {
//     for (int i = '0'; i <= '9'; i += 2) {
//       prop->AddChar(i);
//     }
//     prop->AddAsciiPredicate(&ispunct);
//     prop->AddCharProperty("currency_symbol");
//   }
//
// The macro emits the initializer prototype, a lazily constructed
// CharProperty wired to call it, the registry entry, and the is_X()
// accessors; the code block following the macro invocation becomes the
// initializer's body.
#define DEFINE_CHAR_PROPERTY(name, charpropvar)                       \
  static void init_##name##_char_property(CharProperty *charpropvar); \
  static utils::LazyStaticPtr<CharProperty, const char *,             \
                              CharPropertyInitializer *>              \
      name##_char_property = {#name, &init_##name##_char_property};   \
  REGISTER_CHAR_PROPERTY(name##_char_property, name);                 \
  DEFINE_IS_X_CHAR_PROPERTY_FUNCTIONS(name##_char_property, name)     \
  static void init_##name##_char_property(CharProperty *charpropvar)
// =====================================================================
// Macro for declaring character properties
//
// Declares the is_X() accessors for a char property defined in a .cc file.
// NOTE: there must be no line-continuation backslash after the last
// declaration -- a trailing backslash splices the next source line into
// the macro definition, silently absorbing whatever follows the macro.
#define DECLARE_CHAR_PROPERTY(name)                \
  extern bool is_##name(const char *str, int len); \
  extern bool is_##name(int c);
// ===========================================================
// CharProperty - a property that holds for selected Unicode chars
//
// A CharProperty is semantically equivalent to set<char32>.
//
// The characters for which a CharProperty holds are represented as a trie,
// i.e., a tree that is indexed by successive bytes of the UTF-8 encoding
// of the characters. This permits fast lookup (HoldsFor).
//
// A function that defines a subset of [0..255], e.g., isspace.
typedef int AsciiPredicate(int c);
class CharProperty {
 public:
  // Constructor for set-based char properties. `unicodes` uses the
  // flattened encoding produced by DEFINE_CHAR_PROPERTY_AS_SET, including
  // the RANGE() sentinel markers.
  CharProperty(const char *name, const int *unicodes, int num_unicodes);

  // Constructor for function-based char properties; `init_fn` is invoked
  // once to populate the property.
  CharProperty(const char *name, CharPropertyInitializer *init_fn);
  virtual ~CharProperty();

  // Various ways of adding chars to a CharProperty; for use only in
  // CharPropertyInitializer functions.
  void AddChar(int c);
  void AddCharRange(int c1, int c2);
  void AddAsciiPredicate(AsciiPredicate *pred);
  void AddCharProperty(const char *name);
  // Adds chars from a flattened spec in the same encoding as the
  // set-based constructor (RANGE() sentinels allowed).
  void AddCharSpec(const int *unicodes, int num_unicodes);

  // Return true iff the CharProperty holds for a single given UTF8 char.
  bool HoldsFor(const char *str, int len) const;

  // Return true iff the CharProperty holds for a single given Unicode char.
  bool HoldsFor(int c) const;

  // You can use this to enumerate the set elements (it was easier
  // than defining a real iterator). Returns -1 if there are no more.
  // Call with -1 to get the first element. Expects c == -1 or HoldsFor(c).
  int NextElementAfter(int c) const;

  // Return NULL or the CharProperty with the given name. Looks up the name
  // in the CharPropertyWrapper registry.
  static const CharProperty *Lookup(const char *name);

 private:
  // Aborts (via checks) if `c` is not a valid Unicode code point.
  void CheckUnicodeVal(int c) const;
  static string UnicodeToString(int c);

  // Property name used for registry lookup and error messages. Not owned.
  const char *name_;
  // Opaque trie implementation, defined in char_properties.cc.
  struct CharPropertyImplementation *impl_;
  TF_DISALLOW_COPY_AND_ASSIGN(CharProperty);
};
//======================================================================
// Expression-level punctuation
//
// Punctuation that starts a sentence.
DECLARE_CHAR_PROPERTY(start_sentence_punc);
// Punctuation that ends a sentence.
DECLARE_CHAR_PROPERTY(end_sentence_punc);
// Punctuation, such as parens, that opens a "nested expression" of text.
DECLARE_CHAR_PROPERTY(open_expr_punc);
// Punctuation, such as parens, that closes a "nested expression" of text.
DECLARE_CHAR_PROPERTY(close_expr_punc);
// Chars that open a quotation.
DECLARE_CHAR_PROPERTY(open_quote);
// Chars that close a quotation.
DECLARE_CHAR_PROPERTY(close_quote);
// Punctuation chars that open an expression or a quotation.
DECLARE_CHAR_PROPERTY(open_punc);
// Punctuation chars that close an expression or a quotation.
DECLARE_CHAR_PROPERTY(close_punc);
// Punctuation chars that can come at the beginning of a sentence.
DECLARE_CHAR_PROPERTY(leading_sentence_punc);
// Punctuation chars that can come at the end of a sentence.
DECLARE_CHAR_PROPERTY(trailing_sentence_punc);
//======================================================================
// Token-level punctuation
//
// Token-prefix symbols -- glom on to following token
// (esp. if no space after) -- except for currency symbols.
DECLARE_CHAR_PROPERTY(noncurrency_token_prefix_symbol);
// Token-prefix symbols -- glom on to following token (esp. if no space after).
DECLARE_CHAR_PROPERTY(token_prefix_symbol);
// Token-suffix symbols -- glom on to preceding token (esp. if no space
// before).
DECLARE_CHAR_PROPERTY(token_suffix_symbol);
// Subscripts.
DECLARE_CHAR_PROPERTY(subscript_symbol);
// Superscripts.
DECLARE_CHAR_PROPERTY(superscript_symbol);
//======================================================================
// General punctuation
//
// Connector punctuation.
DECLARE_CHAR_PROPERTY(connector_punc);
// Dashes.
DECLARE_CHAR_PROPERTY(dash_punc);
// Other punctuation.
DECLARE_CHAR_PROPERTY(other_punc);
// All punctuation.
DECLARE_CHAR_PROPERTY(punctuation);
//======================================================================
// Special symbols
//
// Currency symbols.
DECLARE_CHAR_PROPERTY(currency_symbol);
// Chinese bookquotes.
DECLARE_CHAR_PROPERTY(open_bookquote);
DECLARE_CHAR_PROPERTY(close_bookquote);
//======================================================================
// Separators
//
// Line separators.
DECLARE_CHAR_PROPERTY(line_separator);
// Paragraph separators.
DECLARE_CHAR_PROPERTY(paragraph_separator);
// Space separators.
DECLARE_CHAR_PROPERTY(space_separator);
// Separators -- all line, paragraph, and space separators.
DECLARE_CHAR_PROPERTY(separator);
//======================================================================
// Alphanumeric Characters
//
// Digits.
DECLARE_CHAR_PROPERTY(digit);
// Japanese Katakana.
DECLARE_CHAR_PROPERTY(katakana);
//======================================================================
// BiDi Directional Formatting Codes
//
// Explicit directional formatting codes (LRM, RLM, LRE, RLE, PDF, LRO, RLO)
// used by the bidirectional algorithm.
//
// Note: Use this only to classify characters. To actually determine
// directionality of BiDi text, look under i18n/bidi.
//
// See http://www.unicode.org/reports/tr9/ for a description of the algorithm
// and http://www.unicode.org/charts/PDF/U2000.pdf for the character codes.
DECLARE_CHAR_PROPERTY(directional_formatting_code);
//======================================================================
// Special collections
//
// NB: This does not check for all punctuation and symbols in the standard;
// just those listed in our code. See the definitions in char_properties.cc.
DECLARE_CHAR_PROPERTY(punctuation_or_symbol);
} // namespace syntaxnet
#endif // SYNTAXNET_CHAR_PROPERTIES_H_
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Tests for char_properties.cc:
//
// (1) Test the DEFINE_CHAR_PROPERTY_AS_SET and DEFINE_CHAR_PROPERTY macros
// by defining a few fake char properties and verifying their contents.
//
// (2) Test the char properties defined in char_properties.cc by spot-checking
// a few chars.
//
#include "syntaxnet/char_properties.h"
#include <ctype.h> // for ispunct, isspace
#include <map>
#include <set>
#include <utility>
#include <vector>
#include <gmock/gmock.h> // for ContainerEq, EXPECT_THAT
#include "tensorflow/core/platform/test.h"
#include "third_party/utf/utf.h"
#include "util/utf8/unilib.h" // for IsValidCodepoint, etc
#include "util/utf8/unilib_utf8_utils.h"
using ::testing::ContainerEq;
namespace syntaxnet {
// Invalid UTF-8 bytes are decoded as the Replacement Character, U+FFFD
// (which is also Runeerror). Invalid code points are encoded in UTF-8
// with the UTF-8 representation of the Replacement Character.
static const char ReplacementCharacterUTF8[3] = {'\xEF', '\xBF', '\xBD'};
// ====================================================================
// CharPropertiesTest
//
class CharPropertiesTest : public testing::Test {
protected:
// Collect a set of chars.
void CollectChars(const std::set<char32> &chars) {
collected_set_.insert(chars.begin(), chars.end());
}
// Collect an array of chars.
void CollectArray(const char32 arr[], int len) {
collected_set_.insert(arr, arr + len);
}
// Collect the chars for which the named CharProperty holds.
void CollectCharProperty(const char *name) {
const CharProperty *prop = CharProperty::Lookup(name);
ASSERT_TRUE(prop != nullptr) << "for " << name;
for (char32 c = 0; c <= 0x10FFFF; ++c) {
if (UniLib::IsValidCodepoint(c) && prop->HoldsFor(c)) {
collected_set_.insert(c);
}
}
}
// Collect the chars for which an ascii predicate holds.
void CollectAsciiPredicate(AsciiPredicate *pred) {
for (char32 c = 0; c < 256; ++c) {
if ((*pred)(c)) {
collected_set_.insert(c);
}
}
}
// Expect the named char property to be true for precisely the chars in
// the collected set.
void ExpectCharPropertyEqualsCollectedSet(const char *name) {
const CharProperty *prop = CharProperty::Lookup(name);
ASSERT_TRUE(prop != nullptr) << "for " << name;
// Test that char property holds for all collected chars. Exercises both
// signatures of CharProperty::HoldsFor().
for (std::set<char32>::const_iterator it = collected_set_.begin();
it != collected_set_.end(); ++it) {
// Test utf8 version of is_X().
const char32 c = *it;
string utf8_char = EncodeAsUTF8(&c, 1);
EXPECT_TRUE(prop->HoldsFor(utf8_char.c_str(), utf8_char.size()));
// Test ucs-2 version of is_X().
EXPECT_TRUE(prop->HoldsFor(static_cast<int>(c)));
}
// Test that the char property holds for precisely the collected chars.
// Somewhat redundant with previous test, but exercises
// CharProperty::NextElementAfter().
std::set<char32> actual_chars;
int c = -1;
while ((c = prop->NextElementAfter(c)) >= 0) {
actual_chars.insert(static_cast<char32>(c));
}
EXPECT_THAT(actual_chars, ContainerEq(collected_set_))
<< " for " << name;
}
// Expect the named char property to be true for at least the chars in
// the collected set.
void ExpectCharPropertyContainsCollectedSet(const char *name) {
const CharProperty *prop = CharProperty::Lookup(name);
ASSERT_TRUE(prop != nullptr) << "for " << name;
for (std::set<char32>::const_iterator it = collected_set_.begin();
it != collected_set_.end(); ++it) {
EXPECT_TRUE(prop->HoldsFor(static_cast<int>(*it)));
}
}
string EncodeAsUTF8(const char32 *in, int size) {
string out;
out.reserve(size);
for (int i = 0; i < size; ++i) {
char buf[UTFmax];
int len = EncodeAsUTF8Char(*in++, buf);
out.append(buf, len);
}
return out;
}
int EncodeAsUTF8Char(char32 in, char *out) {
if (UniLib::IsValidCodepoint(in)) {
return runetochar(out, &in);
} else {
memcpy(out, ReplacementCharacterUTF8, 3);
return 3;
}
}
private:
std::set<char32> collected_set_;
};
//======================================================================
// Declarations of the sample character sets below
// (to test the DECLARE_CHAR_PROPERTY() macro)
//
DECLARE_CHAR_PROPERTY(test_digit);
DECLARE_CHAR_PROPERTY(test_wavy_dash);
DECLARE_CHAR_PROPERTY(test_digit_or_wavy_dash);
DECLARE_CHAR_PROPERTY(test_punctuation_plus);
//======================================================================
// Definitions of sample character sets
//
// Digits '0'..'9', defined via the set-based macro.
DEFINE_CHAR_PROPERTY_AS_SET(test_digit,
  RANGE('0', '9'),
)
// Wavy dashes: ASCII tilde plus the two CJK wave/wavy dash code points.
DEFINE_CHAR_PROPERTY_AS_SET(test_wavy_dash,
  '~',
  0x301C,  // wave dash
  0x3030,  // wavy dash
)
// Digits or wavy dashes: function-based union of the two properties above,
// exercising AddCharProperty().
DEFINE_CHAR_PROPERTY(test_digit_or_wavy_dash, prop) {
  prop->AddCharProperty("test_digit");
  prop->AddCharProperty("test_wavy_dash");
}
// Punctuation plus a few extraneous chars ('a'..'j'), exercising every
// CharProperty mutator: AddChar, AddCharRange, AddCharSpec (with a RANGE),
// and AddCharProperty.
DEFINE_CHAR_PROPERTY(test_punctuation_plus, prop) {
  prop->AddChar('a');
  prop->AddCharRange('b', 'b');
  prop->AddCharRange('c', 'e');
  static const int kUnicodes[] = {'f', RANGE('g', 'i'), 'j'};
  prop->AddCharSpec(kUnicodes, arraysize(kUnicodes));
  prop->AddCharProperty("punctuation");
}
//====================================================================
// Another form of the character sets above -- for verification
//
// Expected contents of test_digit.
const char32 kTestDigit[] = {
  '0', '1', '2', '3', '4', '5', '6', '7', '8', '9'
};
// Expected contents of test_wavy_dash.
const char32 kTestWavyDash[] = {
  '~',
  0x301C,  // wave dash
  0x3030,  // wavy dash
};
// The non-punctuation extras added by test_punctuation_plus.
const char32 kTestPunctuationPlusExtras[] = {
  'a',
  'b',
  'c',
  'd',
  'e',
  'f',
  'g',
  'h',
  'i',
  'j',
};
// ====================================================================
// Tests
//
TEST_F(CharPropertiesTest, TestDigit) {
  // Exact-set check for the set-based test_digit property.
  CollectArray(kTestDigit, arraysize(kTestDigit));
  ExpectCharPropertyEqualsCollectedSet("test_digit");
}
TEST_F(CharPropertiesTest, TestWavyDash) {
  // Exact-set check: '~' plus the two CJK wavy dash code points.
  CollectArray(kTestWavyDash, arraysize(kTestWavyDash));
  ExpectCharPropertyEqualsCollectedSet("test_wavy_dash");
}
TEST_F(CharPropertiesTest, TestDigitOrWavyDash) {
  // Exact-set check for the function-based union of the two properties.
  CollectArray(kTestDigit, arraysize(kTestDigit));
  CollectArray(kTestWavyDash, arraysize(kTestWavyDash));
  ExpectCharPropertyEqualsCollectedSet("test_digit_or_wavy_dash");
}
TEST_F(CharPropertiesTest, TestPunctuationPlus) {
  // Exact-set check: everything in "punctuation" plus the extras 'a'..'j'.
  CollectCharProperty("punctuation");
  CollectArray(kTestPunctuationPlusExtras,
               arraysize(kTestPunctuationPlusExtras));
  ExpectCharPropertyEqualsCollectedSet("test_punctuation_plus");
}
// ====================================================================
// Spot-check predicates in char_properties.cc
//
TEST_F(CharPropertiesTest, StartSentencePunc) {
  // Spot-check: inverted exclamation (U+00A1) and question (U+00BF) marks.
  CollectChars({0x00A1, 0x00BF});
  ExpectCharPropertyContainsCollectedSet("start_sentence_punc");
}
TEST_F(CharPropertiesTest, EndSentencePunc) {
  // Spot-check: the ASCII sentence terminators.
  CollectChars({'.', '!', '?'});
  ExpectCharPropertyContainsCollectedSet("end_sentence_punc");
}
TEST_F(CharPropertiesTest, OpenExprPunc) {
  // Spot-check: ASCII opening brackets.
  CollectChars({'(', '['});
  ExpectCharPropertyContainsCollectedSet("open_expr_punc");
}
TEST_F(CharPropertiesTest, CloseExprPunc) {
  // Spot-check: ASCII closing brackets.
  CollectChars({')', ']'});
  ExpectCharPropertyContainsCollectedSet("close_expr_punc");
}
TEST_F(CharPropertiesTest, OpenQuote) {
  // Spot-check: the undirected ASCII quotes belong to both quote sets.
  CollectChars({'\'', '"'});
  ExpectCharPropertyContainsCollectedSet("open_quote");
}
TEST_F(CharPropertiesTest, CloseQuote) {
  // Spot-check: the undirected ASCII quotes belong to both quote sets.
  CollectChars({'\'', '"'});
  ExpectCharPropertyContainsCollectedSet("close_quote");
}
TEST_F(CharPropertiesTest, OpenBookquote) {
  // Spot-check: CJK left double angle bracket (U+300A).
  CollectChars({0x300A});
  ExpectCharPropertyContainsCollectedSet("open_bookquote");
}
TEST_F(CharPropertiesTest, CloseBookquote) {
  // Spot-check: CJK right double angle bracket (U+300B).
  CollectChars({0x300B});
  ExpectCharPropertyContainsCollectedSet("close_bookquote");
}
TEST_F(CharPropertiesTest, OpenPunc) {
  // Spot-check: open_punc should contain the open-expr and open-quote chars.
  CollectChars({'(', '['});
  CollectChars({'\'', '"'});
  ExpectCharPropertyContainsCollectedSet("open_punc");
}
TEST_F(CharPropertiesTest, ClosePunc) {
  // Spot-check: close_punc should contain the close-expr and close-quote
  // chars.
  CollectChars({')', ']'});
  CollectChars({'\'', '"'});
  ExpectCharPropertyContainsCollectedSet("close_punc");
}
TEST_F(CharPropertiesTest, LeadingSentencePunc) {
  // Spot-check: opening punctuation plus the inverted Spanish marks.
  CollectChars({'(', '['});
  CollectChars({'\'', '"'});
  CollectChars({0x00A1, 0x00BF});
  ExpectCharPropertyContainsCollectedSet("leading_sentence_punc");
}
TEST_F(CharPropertiesTest, TrailingSentencePunc) {
  // Spot-check: closing punctuation plus the sentence terminators.
  CollectChars({')', ']'});
  CollectChars({'\'', '"'});
  CollectChars({'.', '!', '?'});
  ExpectCharPropertyContainsCollectedSet("trailing_sentence_punc");
}
TEST_F(CharPropertiesTest, NoncurrencyTokenPrefixSymbol) {
  // Spot-check: '#' is a token prefix that is not a currency symbol.
  CollectChars({'#'});
  ExpectCharPropertyContainsCollectedSet("noncurrency_token_prefix_symbol");
}
TEST_F(CharPropertiesTest, TokenSuffixSymbol) {
  // Spot-check: '%', trademark (U+2122), copyright (U+00A9), degree (U+00B0).
  CollectChars({'%', 0x2122, 0x00A9, 0x00B0});
  ExpectCharPropertyContainsCollectedSet("token_suffix_symbol");
}
TEST_F(CharPropertiesTest, TokenPrefixSymbol) {
  // Spot-check: '#' plus the currency symbols '$', yen (U+00A5),
  // euro (U+20AC).
  CollectChars({'#'});
  CollectChars({'$', 0x00A5, 0x20AC});
  ExpectCharPropertyContainsCollectedSet("token_prefix_symbol");
}
TEST_F(CharPropertiesTest, SubscriptSymbol) {
  // Spot-check: subscript two (U+2082) and three (U+2083).
  CollectChars({0x2082, 0x2083});
  ExpectCharPropertyContainsCollectedSet("subscript_symbol");
}
TEST_F(CharPropertiesTest, SuperscriptSymbol) {
  // Spot-check: superscript two (U+00B2) and three (U+00B3).
  CollectChars({0x00B2, 0x00B3});
  ExpectCharPropertyContainsCollectedSet("superscript_symbol");
}
TEST_F(CharPropertiesTest, CurrencySymbol) {
  // Spot-check: dollar, yen (U+00A5), euro (U+20AC).
  CollectChars({'$', 0x00A5, 0x20AC});
  ExpectCharPropertyContainsCollectedSet("currency_symbol");
}
TEST_F(CharPropertiesTest, DirectionalFormattingCode) {
  // Spot-check: LRM, RLM, LRE, RLE, PDF, LRO, RLO bidi formatting codes.
  CollectChars({0x200E, 0x200F, 0x202A, 0x202B, 0x202C, 0x202D, 0x202E});
  ExpectCharPropertyContainsCollectedSet("directional_formatting_code");
}
TEST_F(CharPropertiesTest, Punctuation) {
  // Containment check: every ASCII ispunct() char must be punctuation.
  CollectAsciiPredicate(ispunct);
  ExpectCharPropertyContainsCollectedSet("punctuation");
}
TEST_F(CharPropertiesTest, Separator) {
  // Containment check: every ASCII isspace() char must be a separator.
  CollectAsciiPredicate(isspace);
  ExpectCharPropertyContainsCollectedSet("separator");
}
} // namespace syntaxnet
...@@ -77,7 +77,8 @@ class DocumentSource : public OpKernel { ...@@ -77,7 +77,8 @@ class DocumentSource : public OpKernel {
OP_REQUIRES_OK(context, context->GetAttr("batch_size", &batch_size_)); OP_REQUIRES_OK(context, context->GetAttr("batch_size", &batch_size_));
OP_REQUIRES(context, batch_size_ > 0, OP_REQUIRES(context, batch_size_ > 0,
InvalidArgument("invalid batch_size provided")); InvalidArgument("invalid batch_size provided"));
corpus_.reset(new TextReader(*task_context_.GetInput(corpus_name))); corpus_.reset(
new TextReader(*task_context_.GetInput(corpus_name), &task_context_));
} }
void Compute(OpKernelContext *context) override { void Compute(OpKernelContext *context) override {
...@@ -124,7 +125,8 @@ class DocumentSink : public OpKernel { ...@@ -124,7 +125,8 @@ class DocumentSink : public OpKernel {
GetTaskContext(context, &task_context_); GetTaskContext(context, &task_context_);
string corpus_name; string corpus_name;
OP_REQUIRES_OK(context, context->GetAttr("corpus_name", &corpus_name)); OP_REQUIRES_OK(context, context->GetAttr("corpus_name", &corpus_name));
writer_.reset(new TextWriter(*task_context_.GetInput(corpus_name))); writer_.reset(
new TextWriter(*task_context_.GetInput(corpus_name), &task_context_));
} }
void Compute(OpKernelContext *context) override { void Compute(OpKernelContext *context) override {
......
...@@ -38,6 +38,8 @@ class DocumentFormat : public RegisterableClass<DocumentFormat> { ...@@ -38,6 +38,8 @@ class DocumentFormat : public RegisterableClass<DocumentFormat> {
DocumentFormat() {} DocumentFormat() {}
virtual ~DocumentFormat() {} virtual ~DocumentFormat() {}
virtual void Setup(TaskContext *context) {}
// Reads a record from the given input buffer with format specific logic. // Reads a record from the given input buffer with format specific logic.
// Returns false if no record could be read because we reached end of file. // Returns false if no record could be read because we reached end of file.
virtual bool ReadRecord(tensorflow::io::InputBuffer *buffer, virtual bool ReadRecord(tensorflow::io::InputBuffer *buffer,
......
...@@ -19,6 +19,7 @@ limitations under the License. ...@@ -19,6 +19,7 @@ limitations under the License.
#include "syntaxnet/affix.h" #include "syntaxnet/affix.h"
#include "syntaxnet/dictionary.pb.h" #include "syntaxnet/dictionary.pb.h"
#include "syntaxnet/feature_extractor.h" #include "syntaxnet/feature_extractor.h"
#include "syntaxnet/segmenter_utils.h"
#include "syntaxnet/sentence.pb.h" #include "syntaxnet/sentence.pb.h"
#include "syntaxnet/sentence_batch.h" #include "syntaxnet/sentence_batch.h"
#include "syntaxnet/term_frequency_map.h" #include "syntaxnet/term_frequency_map.h"
...@@ -75,6 +76,7 @@ class LexiconBuilder : public OpKernel { ...@@ -75,6 +76,7 @@ class LexiconBuilder : public OpKernel {
TermFrequencyMap tags; TermFrequencyMap tags;
TermFrequencyMap categories; TermFrequencyMap categories;
TermFrequencyMap labels; TermFrequencyMap labels;
TermFrequencyMap chars;
// Affix tables to be populated by the corpus. // Affix tables to be populated by the corpus.
AffixTable prefixes(AffixTable::PREFIX, max_prefix_length_); AffixTable prefixes(AffixTable::PREFIX, max_prefix_length_);
...@@ -87,7 +89,7 @@ class LexiconBuilder : public OpKernel { ...@@ -87,7 +89,7 @@ class LexiconBuilder : public OpKernel {
int64 num_tokens = 0; int64 num_tokens = 0;
int64 num_documents = 0; int64 num_documents = 0;
Sentence *document; Sentence *document;
TextReader corpus(*task_context_.GetInput(corpus_name_)); TextReader corpus(*task_context_.GetInput(corpus_name_), &task_context_);
while ((document = corpus.Read()) != nullptr) { while ((document = corpus.Read()) != nullptr) {
// Gather token information. // Gather token information.
for (int t = 0; t < document->token_size(); ++t) { for (int t = 0; t < document->token_size(); ++t) {
...@@ -114,6 +116,14 @@ class LexiconBuilder : public OpKernel { ...@@ -114,6 +116,14 @@ class LexiconBuilder : public OpKernel {
// Add mapping from tag to category. // Add mapping from tag to category.
tag_to_category.SetCategory(token.tag(), token.category()); tag_to_category.SetCategory(token.tag(), token.category());
// Add characters.
vector<tensorflow::StringPiece> char_sp;
SegmenterUtils::GetUTF8Chars(word, &char_sp);
for (const auto &c : char_sp) {
const string c_str = c.ToString();
if (!c_str.empty() && !HasSpaces(c_str)) chars.Increment(c_str);
}
// Update the number of processed tokens. // Update the number of processed tokens.
++num_tokens; ++num_tokens;
} }
...@@ -131,6 +141,7 @@ class LexiconBuilder : public OpKernel { ...@@ -131,6 +141,7 @@ class LexiconBuilder : public OpKernel {
categories.Save( categories.Save(
TaskContext::InputFile(*task_context_.GetInput("category-map"))); TaskContext::InputFile(*task_context_.GetInput("category-map")));
labels.Save(TaskContext::InputFile(*task_context_.GetInput("label-map"))); labels.Save(TaskContext::InputFile(*task_context_.GetInput("label-map")));
chars.Save(TaskContext::InputFile(*task_context_.GetInput("char-map")));
// Write affixes to disk. // Write affixes to disk.
WriteAffixTable(prefixes, TaskContext::InputFile( WriteAffixTable(prefixes, TaskContext::InputFile(
......
...@@ -69,6 +69,8 @@ TOKENIZED_DOCS = u'''बात गलत हो तो गुस्सा से ...@@ -69,6 +69,8 @@ TOKENIZED_DOCS = u'''बात गलत हो तो गुस्सा से
लेकिन अभिनेत्री के इस कदम से वहां रंग में भंग पड़ गया । लेकिन अभिनेत्री के इस कदम से वहां रंग में भंग पड़ गया ।
''' '''
CHARS = u'''अ इ आ क ग ज ट त द न प भ ब य म र ल व ह स ि ा ु ी े ै ो ् ड़ । ं'''
COMMENTS = u'# Line with fake comments.' COMMENTS = u'# Line with fake comments.'
...@@ -93,7 +95,7 @@ class LexiconBuilderTest(test_util.TensorFlowTestCase): ...@@ -93,7 +95,7 @@ class LexiconBuilderTest(test_util.TensorFlowTestCase):
self.AddInput('documents', self.corpus_file, corpus_format, context) self.AddInput('documents', self.corpus_file, corpus_format, context)
for name in ('word-map', 'lcword-map', 'tag-map', for name in ('word-map', 'lcword-map', 'tag-map',
'category-map', 'label-map', 'prefix-table', 'category-map', 'label-map', 'prefix-table',
'suffix-table', 'tag-to-category'): 'suffix-table', 'tag-to-category', 'char-map'):
self.AddInput(name, os.path.join(FLAGS.test_tmpdir, name), '', context) self.AddInput(name, os.path.join(FLAGS.test_tmpdir, name), '', context)
logging.info('Writing context to: %s', self.context_file) logging.info('Writing context to: %s', self.context_file)
with open(self.context_file, 'w') as f: with open(self.context_file, 'w') as f:
...@@ -133,6 +135,26 @@ class LexiconBuilderTest(test_util.TensorFlowTestCase): ...@@ -133,6 +135,26 @@ class LexiconBuilderTest(test_util.TensorFlowTestCase):
self.assertIn(tag, TAGS) self.assertIn(tag, TAGS)
self.assertIn(category, CATEGORIES) self.assertIn(category, CATEGORIES)
def LoadMap(self, map_name):
loaded_map = {}
with file(os.path.join(FLAGS.test_tmpdir, map_name), 'r') as f:
for line in f:
entries = line.strip().split(' ')
if len(entries) == 2:
loaded_map[entries[0]] = entries[1]
return loaded_map
def ValidateCharMap(self):
char_map = self.LoadMap('char-map')
self.assertEqual(len(char_map), len(CHARS.split(' ')))
for char in CHARS.split(' '):
self.assertIn(char.encode('utf-8'), char_map)
def ValidateWordMap(self):
word_map = self.LoadMap('word-map')
for word in filter(None, TOKENIZED_DOCS.replace('\n', ' ').split(' ')):
self.assertIn(word.encode('utf-8'), word_map)
def BuildLexicon(self): def BuildLexicon(self):
with self.test_session(): with self.test_session():
gen_parser_ops.lexicon_builder(task_context=self.context_file).run() gen_parser_ops.lexicon_builder(task_context=self.context_file).run()
...@@ -146,6 +168,8 @@ class LexiconBuilderTest(test_util.TensorFlowTestCase): ...@@ -146,6 +168,8 @@ class LexiconBuilderTest(test_util.TensorFlowTestCase):
self.ValidateDocuments() self.ValidateDocuments()
self.BuildLexicon() self.BuildLexicon()
self.ValidateTagToCategoryMap() self.ValidateTagToCategoryMap()
self.ValidateCharMap()
self.ValidateWordMap()
def testCoNLLFormatExtraNewlinesAndComments(self): def testCoNLLFormatExtraNewlinesAndComments(self):
self.WriteContext('conll-sentence') self.WriteContext('conll-sentence')
......
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Morpher transition system.
//
// This transition system has one type of actions:
// - The SHIFT action pushes the next input token to the stack and
// advances to the next input token, assigning a part-of-speech tag to the
// token that was shifted.
//
// The transition system operates with parser actions encoded as integers:
// - A SHIFT action is encoded as number starting from 0.
#include <string>
#include "syntaxnet/morphology_label_set.h"
#include "syntaxnet/parser_features.h"
#include "syntaxnet/parser_state.h"
#include "syntaxnet/parser_transitions.h"
#include "syntaxnet/sentence_features.h"
#include "syntaxnet/shared_store.h"
#include "syntaxnet/task_context.h"
#include "syntaxnet/term_frequency_map.h"
#include "syntaxnet/utils.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace syntaxnet {
class MorphologyTransitionState : public ParserTransitionState {
public:
explicit MorphologyTransitionState(const MorphologyLabelSet *label_set)
: label_set_(label_set) {}
explicit MorphologyTransitionState(const MorphologyTransitionState *state)
: MorphologyTransitionState(state->label_set_) {
tag_ = state->tag_;
gold_tag_ = state->gold_tag_;
}
// Clones the transition state by returning a new object.
ParserTransitionState *Clone() const override {
return new MorphologyTransitionState(this);
}
// Reads gold tags for each token.
void Init(ParserState *state) override {
tag_.resize(state->sentence().token_size(), -1);
gold_tag_.resize(state->sentence().token_size(), -1);
for (int pos = 0; pos < state->sentence().token_size(); ++pos) {
const Token &token = state->GetToken(pos);
// NOTE: we allow token to not have a TokenMorphology extension or for the
// TokenMorphology to be absent from the label_set_ because this can
// happen at test time.
gold_tag_[pos] = label_set_->LookupExisting(
token.GetExtension(TokenMorphology::morphology));
}
}
// Returns the tag assigned to a given token.
int Tag(int index) const {
DCHECK_GE(index, 0);
DCHECK_LT(index, tag_.size());
return index == -1 ? -1 : tag_[index];
}
// Sets this tag on the token at index.
void SetTag(int index, int tag) {
DCHECK_GE(index, 0);
DCHECK_LT(index, tag_.size());
tag_[index] = tag;
}
// Returns the gold tag for a given token.
int GoldTag(int index) const {
DCHECK_GE(index, -1);
DCHECK_LT(index, gold_tag_.size());
return index == -1 ? -1 : gold_tag_[index];
}
// Returns the proto corresponding to the tag, or an empty proto if the tag is
// not found.
const TokenMorphology &TagAsProto(int tag) const {
if (tag >= 0 && tag < label_set_->Size()) {
return label_set_->Lookup(tag);
}
return TokenMorphology::default_instance();
}
// Adds transition state specific annotations to the document.
void AddParseToDocument(const ParserState &state, bool rewrite_root_labels,
Sentence *sentence) const override {
for (int i = 0; i < tag_.size(); ++i) {
Token *token = sentence->mutable_token(i);
*token->MutableExtension(TokenMorphology::morphology) =
TagAsProto(Tag(i));
}
}
// Whether a parsed token should be considered correct for evaluation.
bool IsTokenCorrect(const ParserState &state, int index) const override {
return GoldTag(index) == Tag(index);
}
  // Returns a human readable string representation of this state: each stacked
  // token with its assigned analysis, followed by the remaining input tokens.
  string ToString(const ParserState &state) const override {
    string str;
    for (int i = state.StackSize(); i > 0; --i) {
      const string &word = state.GetToken(state.Stack(i - 1)).word();
      if (i != state.StackSize() - 1) str.append(" ");
      // NOTE(review): the word is fetched at token index state.Stack(i - 1)
      // but the tag at index state.StackSize() - i — these generally differ;
      // looks like the tag lookup should also use state.Stack(i - 1). Confirm
      // (debug-output only, so behavior of parsing is unaffected).
      tensorflow::strings::StrAppend(
          &str, word, "[",
          TagAsProto(Tag(state.StackSize() - i)).ShortDebugString(), "]");
    }
    for (int i = state.Next(); i < state.NumTokens(); ++i) {
      tensorflow::strings::StrAppend(&str, " ", state.GetToken(i).word());
    }
    return str;
  }
 private:
  // Currently assigned morphological analysis for each token in this sentence;
  // -1 means "not yet assigned".
  vector<int> tag_;

  // Gold morphological analysis from the input document; -1 when absent.
  vector<int> gold_tag_;

  // Label set used for conversions between integer and proto representations
  // of morphological analyses. Not owned.
  const MorphologyLabelSet *label_set_ = nullptr;

  TF_DISALLOW_COPY_AND_ASSIGN(MorphologyTransitionState);
};
class MorphologyTransitionSystem : public ParserTransitionSystem {
public:
~MorphologyTransitionSystem() override { SharedStore::Release(label_set_); }
// Determines tag map location.
void Setup(TaskContext *context) override {
context->GetInput("morph-label-set");
}
// Reads tag map and tag to category map.
void Init(TaskContext *context) override {
const string fname =
TaskContext::InputFile(*context->GetInput("morph-label-set"));
label_set_ =
SharedStoreUtils::GetWithDefaultName<MorphologyLabelSet>(fname);
}
// The SHIFT action uses the same value as the corresponding action type.
static ParserAction ShiftAction(int tag) { return tag; }
// The morpher transition system doesn't look at the dependency tree, so it
// allows non-projective trees.
bool AllowsNonProjective() const override { return true; }
// Returns the number of action types.
int NumActionTypes() const override { return 1; }
// Returns the number of possible actions.
int NumActions(int num_labels) const override { return label_set_->Size(); }
// The default action for a given state is assigning the most frequent tag.
ParserAction GetDefaultAction(const ParserState &state) const override {
return ShiftAction(0);
}
// Returns the next gold action for a given state according to the
// underlying annotated sentence.
ParserAction GetNextGoldAction(const ParserState &state) const override {
if (!state.EndOfInput()) {
return ShiftAction(TransitionState(state).GoldTag(state.Next()));
}
return ShiftAction(0);
}
// Checks if the action is allowed in a given parser state.
bool IsAllowedAction(ParserAction action,
const ParserState &state) const override {
return !state.EndOfInput();
}
// Makes a shift by pushing the next input token on the stack and moving to
// the next position.
void PerformActionWithoutHistory(ParserAction action,
ParserState *state) const override {
DCHECK(!state->EndOfInput());
if (!state->EndOfInput()) {
MutableTransitionState(state)->SetTag(state->Next(), action);
state->Push(state->Next());
state->Advance();
}
}
// We are in a final state when we reached the end of the input and the stack
// is empty.
bool IsFinalState(const ParserState &state) const override {
return state.EndOfInput();
}
// Returns a string representation of a parser action.
string ActionAsString(ParserAction action,
const ParserState &state) const override {
return tensorflow::strings::StrCat(
"SHIFT(", label_set_->Lookup(action).ShortDebugString(), ")");
}
// No state is deterministic in this transition system.
bool IsDeterministicState(const ParserState &state) const override {
return false;
}
// Returns a new transition state to be used to enhance the parser state.
ParserTransitionState *NewTransitionState(bool training_mode) const override {
return new MorphologyTransitionState(label_set_);
}
// Downcasts the const ParserTransitionState in ParserState to a const
// MorphologyTransitionState.
static const MorphologyTransitionState &TransitionState(
const ParserState &state) {
return *static_cast<const MorphologyTransitionState *>(
state.transition_state());
}
// Downcasts the ParserTransitionState in ParserState to an
// MorphologyTransitionState.
static MorphologyTransitionState *MutableTransitionState(ParserState *state) {
return static_cast<MorphologyTransitionState *>(
state->mutable_transition_state());
}
// Input for the tag map. Not owned.
TaskInput *input_label_set_ = nullptr;
// Tag map used for conversions between integer and string representations
// morphology labels. Owned through SharedStore.
const MorphologyLabelSet *label_set_;
};
// Registers the morpher transition system under the name used in task specs.
REGISTER_TRANSITION_SYSTEM("morpher", MorphologyTransitionSystem);
// Feature function for retrieving the morphology tag assigned to a token by
// the morpher transition system.
class PredictedMorphTagFeatureFunction : public ParserIndexFeatureFunction {
 public:
  PredictedMorphTagFeatureFunction() {}

  // Determines the morphology label set location.
  void Setup(TaskContext *context) override {
    context->GetInput("morph-label-set", "recordio", "token-morphology");
  }

  // Reads the shared morphology label set and declares the feature type.
  void Init(TaskContext *context) override {
    const string fname =
        TaskContext::InputFile(*context->GetInput("morph-label-set"));
    label_set_ = SharedStore::Get<MorphologyLabelSet>(fname, fname);
    set_feature_type(new FullLabelFeatureType(name(), label_set_));
  }

  // Gets the MorphologyTransitionState from the parser state and reads the
  // assigned tag at the focus index. Returns -1 if the focus is not within
  // the sentence.
  FeatureValue Compute(const WorkspaceSet &workspaces, const ParserState &state,
                       int focus, const FeatureVector *result) const override {
    if (focus < 0 || focus >= state.sentence().token_size()) return -1;
    return static_cast<const MorphologyTransitionState *>(
               state.transition_state())
        ->Tag(focus);
  }

 private:
  // Label set used for conversions between integer and proto representations
  // of morphology labels. Owned through SharedStore; nullptr until Init()
  // runs (previously left uninitialized).
  // NOTE(review): unlike MorphologyTransitionSystem, there is no destructor
  // releasing the SharedStore reference — confirm whether that is intentional.
  const MorphologyLabelSet *label_set_ = nullptr;

  TF_DISALLOW_COPY_AND_ASSIGN(PredictedMorphTagFeatureFunction);
};
// Registers the predicted-morphology feature under the name used in specs.
REGISTER_PARSER_IDX_FEATURE_FUNCTION("pred-morph-tag",
                                     PredictedMorphTagFeatureFunction);
} // namespace syntaxnet
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "syntaxnet/morphology_label_set.h"
namespace syntaxnet {
// Separator used both within and between attribute pairs by StringForMatch();
// must not occur in attribute names.
const char MorphologyLabelSet::kSeparator[] = "\t";
int MorphologyLabelSet::Add(const TokenMorphology &morph) {
string repr = StringForMatch(morph);
auto it = fast_lookup_.find(repr);
if (it != fast_lookup_.end()) return it->second;
fast_lookup_[repr] = label_set_.size();
label_set_.push_back(morph);
return label_set_.size() - 1;
}
// Look up an existing TokenMorphology. If it is not present, return -1.
int MorphologyLabelSet::LookupExisting(const TokenMorphology &morph) const {
string repr = StringForMatch(morph);
auto it = fast_lookup_.find(repr);
if (it != fast_lookup_.end()) return it->second;
return -1;
}
// Returns the TokenMorphology at position i. CHECK-fails unless i is in the
// range [0, Size()).
const TokenMorphology &MorphologyLabelSet::Lookup(int i) const {
  CHECK_GE(i, 0);
  CHECK_LT(i, label_set_.size());
  return label_set_[i];
}
// Reads the label set from a record file previously produced by Write().
void MorphologyLabelSet::Read(const string &filename) {
  ProtoRecordReader reader(filename);
  Read(&reader);
}
// Reads TokenMorphology records until the stream is exhausted, adding each to
// the set. CHECK-fails if the stream contains a duplicate analysis.
void MorphologyLabelSet::Read(ProtoRecordReader *reader) {
  TokenMorphology morph;
  while (reader->Read(&morph).ok()) {
    CHECK_EQ(-1, LookupExisting(morph));
    Add(morph);
  }
}
// Writes the label set to `filename` in the format expected by Read().
void MorphologyLabelSet::Write(const string &filename) const {
  ProtoRecordWriter writer(filename);
  Write(&writer);
}
// Writes all analyses in index order, so reading them back reproduces
// identical indices.
void MorphologyLabelSet::Write(ProtoRecordWriter *writer) const {
  for (const TokenMorphology &morph : label_set_) {
    writer->Write(morph);
  }
}
// Builds a canonical key for `morph`: each (name, value) attribute pair is
// joined with kSeparator, the pairs are sorted, then joined again with
// kSeparator. Sorting makes the key independent of attribute order.
// NOTE(review): the same separator is used within and between pairs, so keys
// could in principle collide if attribute values contained kSeparator —
// assumed never to happen for tab.
string MorphologyLabelSet::StringForMatch(const TokenMorphology &morph) const {
  vector<string> attributes;
  for (const auto &a : morph.attribute()) {
    attributes.push_back(
        tensorflow::strings::StrCat(a.name(), kSeparator, a.value()));
  }
  std::sort(attributes.begin(), attributes.end());
  return utils::Join(attributes, kSeparator);
}
// Renders the analysis at `value` as sorted "name:value" pairs joined by
// commas. Meant to be human-readable rather than guaranteed unique.
string FullLabelFeatureType::GetFeatureValueName(FeatureValue value) const {
  const TokenMorphology &analysis = label_set_->Lookup(value);
  vector<string> parts;
  for (const auto &attribute : analysis.attribute()) {
    parts.push_back(
        tensorflow::strings::StrCat(attribute.name(), ":", attribute.value()));
  }
  std::sort(parts.begin(), parts.end());
  return utils::Join(parts, ",");
}
} // namespace syntaxnet
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// A class to store the set of possible TokenMorphology objects. This includes
// lookup, iteration and serialization.
#ifndef SYNTAXNET_MORPHOLOGY_LABEL_SET_H_
#define SYNTAXNET_MORPHOLOGY_LABEL_SET_H_
#include <unordered_map>
#include <string>
#include <vector>
#include "syntaxnet/proto_io.h"
#include "syntaxnet/sentence.pb.h"
namespace syntaxnet {
class MorphologyLabelSet {
 public:
  // Initialize as an empty morphology.
  MorphologyLabelSet() {}

  // Initializes by reading the given file, which has been saved by Write().
  // This makes using the shared store easier.
  explicit MorphologyLabelSet(const string &fname) { Read(fname); }

  // Adds a TokenMorphology to the set if it is not present. In any case,
  // returns its position in the list. Note: This is slow, and should not be
  // called outside of training or init.
  int Add(const TokenMorphology &morph);

  // Looks up an existing TokenMorphology. If it is not present, returns -1.
  // Note: This is slow, and should not be called outside of the training
  // workflow or init.
  int LookupExisting(const TokenMorphology &morph) const;

  // Returns the TokenMorphology at position i. The input i should be in the
  // range 0..size(). Note: this will be called at inference time and needs to
  // be kept fast.
  const TokenMorphology &Lookup(int i) const;

  // Returns the number of elements.
  int Size() const { return label_set_.size(); }

  // Deserialization and serialization.
  void Read(const string &filename);
  void Write(const string &filename) const;

 private:
  // Returns a canonical string key for `morph` (fixed the typo'd parameter
  // name "morhp" in the declaration).
  string StringForMatch(const TokenMorphology &morph) const;

  // Deserialization and serialization implementation.
  void Read(ProtoRecordReader *reader);
  void Write(ProtoRecordWriter *writer) const;

  // List of all possible annotations. This is a unique list, where equality is
  // defined as follows:
  //
  //   a == b iff the set of attribute pairs (attribute, value) is identical.
  vector<TokenMorphology> label_set_;

  // Because protocol buffer equality is complicated, we implement our own
  // equality operator based on strings. This unordered_map allows us to do the
  // lookup more quickly.
  unordered_map<string, int> fast_lookup_;

  // A separator string that should not occur in any of the attribute names.
  // This should never be serialized, so that it can be changed in the code if
  // we change attribute names and it occurs in the new names.
  static const char kSeparator[];
};
// A feature type with one value for each complete morphological analysis
// (analogous to the fulltag analyzer).
class FullLabelFeatureType : public FeatureType {
 public:
  // `label_set` provides the value domain and value names; it is not owned
  // and must outlive this feature type.
  FullLabelFeatureType(const string &name, const MorphologyLabelSet *label_set)
      : FeatureType(name), label_set_(label_set) {}
  ~FullLabelFeatureType() override {}
  // Converts a feature value to a name. We don't use StringForMatch, since the
  // goal of these is to be readable, even if they might occasionally be
  // non-unique.
  string GetFeatureValueName(FeatureValue value) const override;
  // Returns the size of the feature values domain.
  FeatureValue GetDomainSize() const override { return label_set_->Size(); }
 private:
  // The set of known analyses. Not owned.
  const MorphologyLabelSet *label_set_ = nullptr;
};
} // namespace syntaxnet
#endif // SYNTAXNET_MORPHOLOGY_LABEL_SET_H_
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "syntaxnet/morphology_label_set.h"
#include "syntaxnet/sentence.pb.h"
#include <gmock/gmock.h>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
// Test fixture providing a fresh, empty MorphologyLabelSet for each test.
class MorphologyLabelSetTest : public ::testing::Test {
 protected:
  MorphologyLabelSet label_set_;
};
// Test that Add and LookupExisting work as expected: indices are assigned in
// insertion order, duplicates are not re-added, and attribute order does not
// affect identity.
TEST_F(MorphologyLabelSetTest, AddLookupExisting) {
  TokenMorphology si1, si2;  // singular, imperative
  TokenMorphology pi;        // plural, imperative
  TokenMorphology six;       // singular, imperative with extra value
  TextFormat::ParseFromString(R"(
      attribute {name: "Number" value: "Singular"}
      attribute {name: "POS" value: "IMP"})",
                              &si1);
  TextFormat::ParseFromString(R"(
      attribute {name: "POS" value: "IMP"}
      attribute {name: "Number" value: "Singular"})",
                              &si2);
  TextFormat::ParseFromString(R"(
      attribute {name: "Number" value: "Plural"}
      attribute {name: "POS" value: "IMP"})",
                              &pi);
  TextFormat::ParseFromString(R"(
      attribute {name: "Number" value: "Plural"}
      attribute {name: "POS" value: "IMP"}
      attribute {name: "x" value: "x"})",
                              &six);

  // Check LookupExisting returns -1 for non-existing entries.
  EXPECT_EQ(-1, label_set_.LookupExisting(si1));
  EXPECT_EQ(-1, label_set_.LookupExisting(si2));
  EXPECT_EQ(0, label_set_.Size());

  // Check that adding returns 0 (this is the only possibility given Size()).
  EXPECT_EQ(0, label_set_.Add(si1));
  EXPECT_EQ(0, label_set_.Add(si1));  // calling Add twice adds only once
  EXPECT_EQ(1, label_set_.Size());

  // Check that order of attributes does not matter.
  EXPECT_EQ(0, label_set_.LookupExisting(si2));

  // Check that un-added entries still are not present.
  EXPECT_EQ(-1, label_set_.LookupExisting(pi));
  EXPECT_EQ(-1, label_set_.LookupExisting(six));

  // Check that we can add them.
  EXPECT_EQ(1, label_set_.Add(pi));
  EXPECT_EQ(2, label_set_.Add(six));
  EXPECT_EQ(3, label_set_.Size());
}
// Test Write() and the deserializing constructor: indices and contents must
// survive a round trip through a temp file.
TEST_F(MorphologyLabelSetTest, Serialization) {
  TokenMorphology si;  // singular, imperative
  TokenMorphology pi;  // plural, imperative
  TextFormat::ParseFromString(R"(
      attribute {name: "Number" value: "Singular"}
      attribute {name: "POS" value: "IMP"})",
                              &si);
  TextFormat::ParseFromString(R"(
      attribute {name: "Number" value: "Plural"}
      attribute {name: "POS" value: "IMP"})",
                              &pi);
  EXPECT_EQ(0, label_set_.Add(si));
  EXPECT_EQ(1, label_set_.Add(pi));

  // Serialize to a temp file and deserialize into a second set.
  string fname = utils::JoinPath({tensorflow::testing::TmpDir(), "label-set"});
  label_set_.Write(fname);
  MorphologyLabelSet label_set2(fname);
  EXPECT_EQ(0, label_set2.LookupExisting(si));
  EXPECT_EQ(1, label_set2.LookupExisting(pi));
  EXPECT_EQ(2, label_set2.Size());
}
} // namespace syntaxnet
...@@ -22,7 +22,6 @@ import time ...@@ -22,7 +22,6 @@ import time
import tensorflow as tf import tensorflow as tf
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from syntaxnet import sentence_pb2 from syntaxnet import sentence_pb2
from syntaxnet import graph_builder from syntaxnet import graph_builder
......
...@@ -166,6 +166,9 @@ REGISTER_PARSER_IDX_FEATURE_FUNCTION("label", LabelFeatureFunction); ...@@ -166,6 +166,9 @@ REGISTER_PARSER_IDX_FEATURE_FUNCTION("label", LabelFeatureFunction);
typedef BasicParserSentenceFeatureFunction<Word> WordFeatureFunction; typedef BasicParserSentenceFeatureFunction<Word> WordFeatureFunction;
REGISTER_PARSER_IDX_FEATURE_FUNCTION("word", WordFeatureFunction); REGISTER_PARSER_IDX_FEATURE_FUNCTION("word", WordFeatureFunction);
typedef BasicParserSentenceFeatureFunction<Char> CharFeatureFunction;
REGISTER_PARSER_IDX_FEATURE_FUNCTION("char", CharFeatureFunction);
typedef BasicParserSentenceFeatureFunction<Tag> TagFeatureFunction; typedef BasicParserSentenceFeatureFunction<Tag> TagFeatureFunction;
REGISTER_PARSER_IDX_FEATURE_FUNCTION("tag", TagFeatureFunction); REGISTER_PARSER_IDX_FEATURE_FUNCTION("tag", TagFeatureFunction);
...@@ -175,6 +178,21 @@ REGISTER_PARSER_IDX_FEATURE_FUNCTION("digit", DigitFeatureFunction); ...@@ -175,6 +178,21 @@ REGISTER_PARSER_IDX_FEATURE_FUNCTION("digit", DigitFeatureFunction);
typedef BasicParserSentenceFeatureFunction<Hyphen> HyphenFeatureFunction; typedef BasicParserSentenceFeatureFunction<Hyphen> HyphenFeatureFunction;
REGISTER_PARSER_IDX_FEATURE_FUNCTION("hyphen", HyphenFeatureFunction); REGISTER_PARSER_IDX_FEATURE_FUNCTION("hyphen", HyphenFeatureFunction);
typedef BasicParserSentenceFeatureFunction<Capitalization>
CapitalizationFeatureFunction;
REGISTER_PARSER_IDX_FEATURE_FUNCTION("capitalization",
CapitalizationFeatureFunction);
typedef BasicParserSentenceFeatureFunction<PunctuationAmount>
PunctuationAmountFeatureFunction;
REGISTER_PARSER_IDX_FEATURE_FUNCTION("punctuation-amount",
PunctuationAmountFeatureFunction);
typedef BasicParserSentenceFeatureFunction<Quote>
QuoteFeatureFunction;
REGISTER_PARSER_IDX_FEATURE_FUNCTION("quote",
QuoteFeatureFunction);
typedef BasicParserSentenceFeatureFunction<PrefixFeature> PrefixFeatureFunction; typedef BasicParserSentenceFeatureFunction<PrefixFeature> PrefixFeatureFunction;
REGISTER_PARSER_IDX_FEATURE_FUNCTION("prefix", PrefixFeatureFunction); REGISTER_PARSER_IDX_FEATURE_FUNCTION("prefix", PrefixFeatureFunction);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment