Commit 6f6a4539 authored by calberti, committed by GitHub

New features, locator and text format for tokenization (#306)

* Adding:
  - an offset feature locator,
  - a last-word feature function,
  - an untokenized text format.
parent c4d943b2
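Before the diff, a minimal usage sketch of the new 'untokenized-text' document format, mirroring the new testUntokenized case further down. The corpus path and context file name are illustrative assumptions; the task context is assumed to declare an input whose part sets file_format: 'untokenized-text'.

# Usage sketch (not part of the diff): read a raw, one-sentence-per-line
# corpus with the new 'untokenized-text' format; each UTF-8 character of a
# line becomes one token, with byte offsets as start/end positions.
import tensorflow as tf

from syntaxnet import sentence_pb2
from syntaxnet.ops import gen_parser_ops

# Write a raw corpus (path is illustrative).
with open('/tmp/raw_corpus.txt', 'w') as f:
  f.write('一个测试')

with tf.Session() as sess:
  # The context file (illustrative name) points its input at the corpus above
  # with file_format: 'untokenized-text'.
  docs, _ = gen_parser_ops.document_source('/tmp/context.pbtxt', batch_size=1)
  serialized = sess.run(docs)
  sentence = sentence_pb2.Sentence()
  sentence.ParseFromString(serialized[0])
  # Expected: one token per character, i.e. 一 个 测 试
  print(' '.join(t.word for t in sentence.token))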
......@@ -132,6 +132,7 @@ cc_library(
srcs = ["text_formats.cc"],
deps = [
":document_format",
":segmenter_utils",
":sentence_proto",
],
alwayslink = 1,
......@@ -558,7 +559,9 @@ cc_test(
deps = [
":parser_transitions",
":sentence_proto",
":task_context",
":test_main",
":workspace",
],
)
......@@ -646,6 +649,7 @@ py_binary(
":graph_builder",
":sentence_py_pb2",
":structured_graph_builder",
":task_spec_py_pb2",
],
)
......
......@@ -14,8 +14,10 @@ limitations under the License.
==============================================================================*/
#include "syntaxnet/binary_segment_state.h"
#include "syntaxnet/parser_features.h"
#include "syntaxnet/parser_state.h"
#include "syntaxnet/parser_transitions.h"
#include "syntaxnet/term_frequency_map.h"
namespace syntaxnet {
......@@ -118,4 +120,102 @@ class BinarySegmentTransitionSystem : public ParserTransitionSystem {
REGISTER_TRANSITION_SYSTEM("binary-segment-transitions",
BinarySegmentTransitionSystem);
// Parser feature locator that returns the token in the sentence that is
// argument() positions from the provided focus token.
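// For example, with argument() == -1, a focus of i is moved to i - 1; any
// focus that falls outside the range of the sentence is mapped to -2 (see
// UpdateArgs below).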
class OffsetFeatureLocator : public ParserIndexLocator<OffsetFeatureLocator> {
public:
// Updates the current focus to a new location. If the initial focus or the
// new focus is outside the range of the sentence, *focus is set to -2.
void UpdateArgs(const WorkspaceSet &workspaces, const ParserState &state,
int *focus) const {
if (*focus < -1 || *focus >= state.sentence().token_size()) {
*focus = -2;
return;
}
int new_focus = *focus + argument();
if (new_focus < -1 || new_focus >= state.sentence().token_size()) {
*focus = -2;
return;
}
*focus = new_focus;
}
};
REGISTER_PARSER_IDX_FEATURE_FUNCTION("offset", OffsetFeatureLocator);
// Feature function that returns the id of the n-th most recently constructed
// word. Note that the argument n must be larger than 0; a value of 0 would
// refer to the word that is still under construction.
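// For example, last-word(1) returns the id of the most recently completed
// word.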
class LastWordFeatureFunction : public ParserFeatureFunction {
public:
void Setup(TaskContext *context) override {
input_word_map_ = context->GetInput("word-map", "text", "");
}
void Init(TaskContext *context) override {
min_freq_ = GetIntParameter("min-freq", 0);
max_num_terms_ = GetIntParameter("max-num-terms", 0);
word_map_.Load(
TaskContext::InputFile(*input_word_map_), min_freq_, max_num_terms_);
unk_id_ = word_map_.Size();
outside_id_ = unk_id_ + 1;
set_feature_type(
new ResourceBasedFeatureType<LastWordFeatureFunction>(
name(), this, {}));
}
int64 NumValues() const {
return outside_id_ + 1;
}
// Returns the string representation of the given feature value.
string GetFeatureValueName(FeatureValue value) const {
if (value == outside_id_) return "<OUTSIDE>";
if (value == unk_id_) return "<UNKNOWN>";
DCHECK_GE(value, 0);
DCHECK_LT(value, word_map_.Size());
return word_map_.GetTerm(value);
}
FeatureValue Compute(const WorkspaceSet &workspaces, const ParserState &state,
const FeatureVector *result) const override {
// n must be larger than 0, since n = 0 would refer to the word that is still
// under construction.
const int n = argument();
CHECK_GT(n, 0);
const auto *segment_state = static_cast<const BinarySegmentState *>(
state.transition_state());
if (n >= segment_state->NumStarts(state)) {
return outside_id_;
}
const auto &sentence = state.sentence();
const int start = segment_state->LastStart(n, state);
const int end = segment_state->LastStart(n - 1, state) - 1;
CHECK_GE(end, start);
const int start_offset = state.GetToken(start).start();
const int length = state.GetToken(end).end() - start_offset + 1;
const auto *data = sentence.text().data() + start_offset;
return word_map_.LookupIndex(string(data, length), unk_id_);
}
private:
// Task input for the word to id map. Not owned.
TaskInput *input_word_map_ = nullptr;
TermFrequencyMap word_map_;
// Special ids for unknown words and for out-of-range positions.
int unk_id_ = 0;
int outside_id_ = 0;
// Minimum frequency for term map.
int min_freq_;
// Maximum number of terms for term map.
int max_num_terms_;
};
REGISTER_PARSER_FEATURE_FUNCTION("last-word", LastWordFeatureFunction);
} // namespace syntaxnet
......@@ -14,9 +14,12 @@ limitations under the License.
==============================================================================*/
#include "syntaxnet/binary_segment_state.h"
#include "syntaxnet/parser_features.h"
#include "syntaxnet/parser_state.h"
#include "syntaxnet/parser_transitions.h"
#include "syntaxnet/task_context.h"
#include "syntaxnet/term_frequency_map.h"
#include "syntaxnet/workspace.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
......@@ -38,6 +41,58 @@ class SegmentationTransitionTest : public ::testing::Test {
"token { word: '样' start: 14 end: 16 break_level: NO_BREAK } ";
sentence_ = std::unique_ptr<Sentence>(new Sentence());
TextFormat::ParseFromString(str_sentence, sentence_.get());
word_map_.Increment("因为");
word_map_.Increment("因为");
word_map_.Increment("有");
word_map_.Increment("这");
word_map_.Increment("这");
word_map_.Increment("样");
word_map_.Increment("样");
word_map_.Increment("这样");
word_map_.Increment("这样");
string filename = tensorflow::strings::StrCat(
tensorflow::testing::TmpDir(), "word-map");
word_map_.Save(filename);
// Re-load in sorted order, ignoring words that occur only once.
word_map_.Load(filename, 2, -1);
// Prepare task context.
context_ = std::unique_ptr<TaskContext>(new TaskContext());
AddInputToContext("word-map", filename, "text", "");
registry_ = std::unique_ptr<WorkspaceRegistry>(new WorkspaceRegistry());
}
// Adds an input to the task context.
void AddInputToContext(const string &name,
const string &file_pattern,
const string &file_format,
const string &record_format) {
TaskInput *input = context_->GetInput(name);
TaskInput::Part *part = input->add_part();
part->set_file_pattern(file_pattern);
part->set_file_format(file_format);
part->set_record_format(record_format);
}
// Prepares a feature for computation.
void PrepareFeature(const string &feature_name, ParserState *state) {
feature_extractor_ = std::unique_ptr<ParserFeatureExtractor>(
new ParserFeatureExtractor());
feature_extractor_->Parse(feature_name);
feature_extractor_->Setup(context_.get());
feature_extractor_->Init(context_.get());
feature_extractor_->RequestWorkspaces(registry_.get());
workspace_.Reset(*registry_);
feature_extractor_->Preprocess(&workspace_, state);
}
// Computes the feature value for the parser state.
FeatureValue ComputeFeature(const ParserState &state) const {
FeatureVector result;
feature_extractor_->ExtractFeatures(workspace_, state, &result);
return result.size() > 0 ? result.value(0) : -1;
}
void CheckStarts(const ParserState &state, const vector<int> &target) {
......@@ -48,10 +103,18 @@ class SegmentationTransitionTest : public ::testing::Test {
}
}
// The test document, parse tree, and sentence with tags and partial parses.
// The test sentence.
std::unique_ptr<Sentence> sentence_;
// Members for testing features.
std::unique_ptr<ParserFeatureExtractor> feature_extractor_;
std::unique_ptr<TaskContext> context_;
std::unique_ptr<WorkspaceRegistry> registry_;
WorkspaceSet workspace_;
std::unique_ptr<ParserTransitionSystem> transition_system_;
TermFrequencyMap label_map_;
TermFrequencyMap word_map_;
};
TEST_F(SegmentationTransitionTest, GoldNextActionTest) {
......@@ -108,4 +171,62 @@ TEST_F(SegmentationTransitionTest, DefaultActionTest) {
EXPECT_EQ(sentence_->token(4).word(), "样");
}
TEST_F(SegmentationTransitionTest, LastWordFeatureTest) {
const int unk_id = word_map_.Size();
const int outside_id = unk_id + 1;
// Prepare a parser state.
BinarySegmentState *segment_state = new BinarySegmentState();
auto state = std::unique_ptr<ParserState>(new ParserState(
sentence_.get(), segment_state, &label_map_));
// Test initial state which contains no words.
PrepareFeature("last-word(1,min-freq=2)", state.get());
EXPECT_EQ(outside_id, ComputeFeature(*state));
PrepareFeature("last-word(2,min-freq=2)", state.get());
EXPECT_EQ(outside_id, ComputeFeature(*state));
PrepareFeature("last-word(3,min-freq=2)", state.get());
EXPECT_EQ(outside_id, ComputeFeature(*state));
// Test when the state contains only one start.
segment_state->AddStart(0, state.get());
PrepareFeature("last-word(1,min-freq=2)", state.get());
EXPECT_EQ(outside_id, ComputeFeature(*state));
PrepareFeature("last-word(2,min-freq=2)", state.get());
EXPECT_EQ(outside_id, ComputeFeature(*state));
// Test when the state contains two starts, which form a complete word and
// the start of another new word.
segment_state->AddStart(2, state.get());
EXPECT_NE(word_map_.LookupIndex("因为", unk_id), unk_id);
PrepareFeature("last-word(1)", state.get());
EXPECT_EQ(word_map_.LookupIndex("因为", unk_id), ComputeFeature(*state));
// last-word(2) still points to outside, since only one word has been completed.
PrepareFeature("last-word(2,min-freq=2)", state.get());
EXPECT_EQ(outside_id, ComputeFeature(*state));
// Adding more starts leads to the following words:
// 因为 ‘ ’ 有 ‘ ’
segment_state->AddStart(3, state.get());
segment_state->AddStart(4, state.get());
// Note 有 is pruned from the map since its frequency is less than 2.
EXPECT_EQ(word_map_.LookupIndex("有", unk_id), unk_id);
PrepareFeature("last-word(1,min-freq=2)", state.get());
EXPECT_EQ(unk_id, ComputeFeature(*state));
// Note that last-word(2) points to ' ', which is also an unknown word.
PrepareFeature("last-word(2,min-freq=2)", state.get());
EXPECT_EQ(unk_id, ComputeFeature(*state));
PrepareFeature("last-word(3,min-freq=2)", state.get());
EXPECT_EQ(word_map_.LookupIndex("因为", unk_id), ComputeFeature(*state));
// Adding two words: "这" and "样".
segment_state->AddStart(5, state.get());
segment_state->AddStart(6, state.get());
PrepareFeature("last-word(1,min-freq=2)", state.get());
EXPECT_EQ(word_map_.LookupIndex("这", unk_id), ComputeFeature(*state));
}
} // namespace syntaxnet
......@@ -795,6 +795,7 @@ DEFINE_CHAR_PROPERTY(separator, prop) {
DEFINE_CHAR_PROPERTY_AS_SET(digit,
RANGE('0', '9'),
RANGE(0x0660, 0x0669), // Arabic-Indic digits
RANGE(0x06F0, 0x06F9), // Eastern Arabic-Indic digits
)
......
......@@ -20,13 +20,19 @@ import os
import os.path
import time
import tempfile
import tensorflow as tf
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from google.protobuf import text_format
from syntaxnet import sentence_pb2
from syntaxnet import graph_builder
from syntaxnet import structured_graph_builder
from syntaxnet.ops import gen_parser_ops
from syntaxnet import task_spec_pb2
flags = tf.app.flags
FLAGS = flags.FLAGS
......@@ -35,6 +41,8 @@ FLAGS = flags.FLAGS
flags.DEFINE_string('task_context', '',
'Path to a task context with inputs and parameters for '
'feature extractors.')
flags.DEFINE_string('resource_dir', '',
'Optional base directory for task context resources.')
flags.DEFINE_string('model_path', '', 'Path to model parameters.')
flags.DEFINE_string('arg_prefix', None, 'Prefix for context parameters.')
flags.DEFINE_string('graph_builder', 'greedy',
......@@ -53,16 +61,28 @@ flags.DEFINE_bool('slim_model', False,
'Whether to expect only averaged variables.')
def Eval(sess, num_actions, feature_sizes, domain_sizes, embedding_dims):
"""Builds and evaluates a network.
def RewriteContext(task_context):
context = task_spec_pb2.TaskSpec()
with gfile.FastGFile(task_context) as fin:
text_format.Merge(fin.read(), context)
for resource in context.input:
for part in resource.part:
if part.file_pattern != '-':
part.file_pattern = os.path.join(FLAGS.resource_dir, part.file_pattern)
with tempfile.NamedTemporaryFile(delete=False) as fout:
fout.write(str(context))
return fout.name
def Eval(sess):
"""Builds and evaluates a network."""
task_context = FLAGS.task_context
if FLAGS.resource_dir:
task_context = RewriteContext(task_context)
feature_sizes, domain_sizes, embedding_dims, num_actions = sess.run(
gen_parser_ops.feature_size(task_context=task_context,
arg_prefix=FLAGS.arg_prefix))
Args:
sess: tensorflow session to use
num_actions: number of possible golden actions
feature_sizes: size of each feature vector
domain_sizes: number of possible feature ids in each feature vector
embedding_dims: embedding dimension for each feature group
"""
t = time.time()
hidden_layer_sizes = map(int, FLAGS.hidden_layer_sizes.split(','))
logging.info('Building training network with parameters: feature_sizes: %s '
......@@ -86,7 +106,6 @@ def Eval(sess, num_actions, feature_sizes, domain_sizes, embedding_dims):
arg_prefix=FLAGS.arg_prefix,
beam_size=FLAGS.beam_size,
max_steps=FLAGS.max_steps)
task_context = FLAGS.task_context
parser.AddEvaluation(task_context,
FLAGS.batch_size,
corpus_name=FLAGS.input,
......@@ -98,7 +117,7 @@ def Eval(sess, num_actions, feature_sizes, domain_sizes, embedding_dims):
sink_documents = tf.placeholder(tf.string)
sink = gen_parser_ops.document_sink(sink_documents,
task_context=FLAGS.task_context,
task_context=task_context,
corpus_name=FLAGS.output)
t = time.time()
num_epochs = None
......@@ -136,12 +155,7 @@ def Eval(sess, num_actions, feature_sizes, domain_sizes, embedding_dims):
def main(unused_argv):
logging.set_verbosity(logging.INFO)
with tf.Session() as sess:
feature_sizes, domain_sizes, embedding_dims, num_actions = sess.run(
gen_parser_ops.feature_size(task_context=FLAGS.task_context,
arg_prefix=FLAGS.arg_prefix))
with tf.Session() as sess:
Eval(sess, num_actions, feature_sizes, domain_sizes, embedding_dims)
Eval(sess)
if __name__ == '__main__':
......
......@@ -243,6 +243,7 @@ class TermFrequencyMapSetFeature : public TokenLookupSetFeature {
void Init(TaskContext *context) override;
// Number of unique values.
int64 NumValues() const override { return term_map_->Size(); }
// Special value for strings not in the map.
......
......@@ -19,6 +19,7 @@ limitations under the License.
#include "syntaxnet/document_format.h"
#include "syntaxnet/sentence.pb.h"
#include "syntaxnet/segmenter_utils.h"
#include "syntaxnet/utils.h"
#include "tensorflow/core/lib/io/inputbuffer.h"
#include "tensorflow/core/lib/strings/strcat.h"
......@@ -172,13 +173,13 @@ class CoNLLSyntaxFormat : public DocumentFormat {
if (add_pos_as_attribute_) RemovePosFromAttributes(&token);
vector<string> fields(10);
fields[0] = tensorflow::strings::Printf("%d", i + 1);
fields[1] = token.word();
fields[1] = UnderscoreIfEmpty(token.word());
fields[2] = "_";
fields[3] = token.category();
fields[4] = token.tag();
fields[3] = UnderscoreIfEmpty(token.category());
fields[4] = UnderscoreIfEmpty(token.tag());
fields[5] = GetMorphAttributes(token);
fields[6] = tensorflow::strings::Printf("%d", token.head() + 1);
fields[7] = token.label();
fields[7] = UnderscoreIfEmpty(token.label());
fields[8] = "_";
fields[9] = "_";
lines.push_back(utils::Join(fields, "\t"));
......@@ -187,6 +188,11 @@ class CoNLLSyntaxFormat : public DocumentFormat {
}
private:
// Replaces empty fields with an underscore.
string UnderscoreIfEmpty(const string &field) {
return field.empty() ? "_" : field;
}
// Creates a TokenMorphology object out of a list of attribute values of the
// form: a1=v1|a2=v2|... or v1|v2|...
void AddMorphAttributes(const string &attributes, Token *token) {
......@@ -194,11 +200,7 @@ class CoNLLSyntaxFormat : public DocumentFormat {
token->MutableExtension(TokenMorphology::morphology);
vector<string> att_vals = utils::Split(attributes, '|');
for (int i = 0; i < att_vals.size(); ++i) {
vector<string> att_val = utils::Split(att_vals[i], '=');
CHECK_LE(att_val.size(), 2)
<< "Error parsing morphology features "
<< "column, must be of format "
<< "a1=v1|a2=v2|... or v1|v2|... <field>: " << attributes;
vector<string> att_val = utils::SplitOne(att_vals[i], '=');
// Format is either:
// 1) a1=v1|a2=v2..., e.g., Czech CoNLL data, or,
......@@ -268,7 +270,8 @@ class CoNLLSyntaxFormat : public DocumentFormat {
// Assumes the "fPOS" attribute, if present, is the last one.
TokenMorphology *morph =
token->MutableExtension(TokenMorphology::morphology);
if (morph->attribute().rbegin()->name() == "fPOS") {
if (morph->attribute_size() > 0 &&
morph->attribute().rbegin()->name() == "fPOS") {
morph->mutable_attribute()->RemoveLast();
}
}
......@@ -346,6 +349,45 @@ class TokenizedTextFormat : public DocumentFormat {
REGISTER_DOCUMENT_FORMAT("tokenized-text", TokenizedTextFormat);
// Reader for untokenized text. This reader expects every sentence to be on a
// single line. For each line in the input, a Sentence proto is created whose
// tokens are the UTF-8 characters of that line.
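// For example, the input line "一个测试" produces a four-token sentence, one
// token per character, with byte offsets as start/end positions.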
//
class UntokenizedTextFormat : public TokenizedTextFormat {
public:
UntokenizedTextFormat() {}
void ConvertFromString(const string &key, const string &value,
vector<Sentence *> *sentences) override {
Sentence *sentence = new Sentence();
vector<tensorflow::StringPiece> chars;
SegmenterUtils::GetUTF8Chars(value, &chars);
int start = 0;
for (auto utf8char : chars) {
Token *token = sentence->add_token();
token->set_word(utf8char.ToString());
token->set_start(start);
start += utf8char.size();
token->set_end(start - 1);
}
if (sentence->token_size() > 0) {
sentence->set_docid(key);
sentence->set_text(value);
sentences->push_back(sentence);
} else {
// If the sentence was empty (e.g., blank lines at the beginning of a
// file), then don't save it.
delete sentence;
}
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(UntokenizedTextFormat);
};
REGISTER_DOCUMENT_FORMAT("untokenized-text", UntokenizedTextFormat);
// Text reader that attempts to perform Penn Treebank tokenization on arbitrary
// raw text. Adapted from https://www.cis.upenn.edu/~treebank/tokenizer.sed
// by Robert MacIntyre, University of Pennsylvania, late 1995.
......
......@@ -83,6 +83,29 @@ class TextFormatsTest(test_util.TensorFlowTestCase):
self.assertEqual(' '.join([t.word for t in sentence_doc.token]),
tokenization)
def CheckUntokenizedDoc(self, sentence, words, starts, ends):
self.WriteContext('untokenized-text')
logging.info('Writing text file to: %s', self.corpus_file)
with open(self.corpus_file, 'w') as f:
f.write(sentence)
sentence, _ = gen_parser_ops.document_source(
self.context_file, batch_size=1)
with self.test_session() as sess:
sentence_doc = self.ReadNextDocument(sess, sentence)
self.assertEqual(len(sentence_doc.token), len(words))
self.assertEqual(len(sentence_doc.token), len(starts))
self.assertEqual(len(sentence_doc.token), len(ends))
for i, token in enumerate(sentence_doc.token):
self.assertEqual(token.word.encode('utf-8'), words[i])
self.assertEqual(token.start, starts[i])
self.assertEqual(token.end, ends[i])
def testUntokenized(self):
self.CheckUntokenizedDoc('一个测试', ['一', '个', '测', '试'],
[0, 3, 6, 9], [2, 5, 8, 11])
self.CheckUntokenizedDoc('Hello ', ['H', 'e', 'l', 'l', 'o', ' '],
[0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5])
def testSimple(self):
self.CheckTokenization('Hello, world!', 'Hello , world !')
self.CheckTokenization('"Hello"', "`` Hello ''")
......
......@@ -95,6 +95,16 @@ std::vector<string> Split(const string &text, char delim) {
return result;
}
std::vector<string> SplitOne(const string &text, char delim) {
std::vector<string> result;
size_t split = text.find_first_of(delim);
result.push_back(text.substr(0, split));
if (split != string::npos) {
result.push_back(text.substr(split + 1));
}
return result;
}
bool IsAbsolutePath(tensorflow::StringPiece path) {
return !path.empty() && path[0] == '/';
}
......
......@@ -49,8 +49,13 @@ T ParseUsing(const string &str, T defval,
string CEscape(const string &src);
// Splits the given string on every occurrence of the given delimiter char.
std::vector<string> Split(const string &text, char delim);
// Splits the given string on the first occurrence of the given delimiter char.
// If the delimiter is not found, a one-element vector containing the given
// string is returned.
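// For example, SplitOne("a1=v1=extra", '=') yields {"a1", "v1=extra"}, while
// SplitOne("v1", '=') yields just {"v1"}.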
std::vector<string> SplitOne(const string &text, char delim);
template <typename T>
string Join(const std::vector<T> &s, const char *sep) {
string result;
......