Commit 32ab5a58 authored by calberti, committed by Martin Wicke

Adding SyntaxNet to tensorflow/models (#63)

parent 148a15fb
autoencoder/MNIST_data/*
*.pyc
[submodule "tensorflow"]
path = syntaxnet/tensorflow
url = https://github.com/tensorflow/tensorflow.git
/bazel-bin
/bazel-genfiles
/bazel-out
/bazel-tensorflow
/bazel-testlogs
/bazel-tf
/bazel-syntaxnet
local_repository(
name = "tf",
path = __workspace_dir__ + "/tensorflow",
)
load('//tensorflow/tensorflow:workspace.bzl', 'tf_workspace')
tf_workspace("tensorflow/", "@tf")
# Specify the minimum required Bazel version.
load("@tf//tensorflow:tensorflow.bzl", "check_version")
check_version("0.2.0")
# ===== gRPC dependencies =====
bind(
name = "libssl",
actual = "@boringssl_git//:ssl",
)
git_repository(
name = "boringssl_git",
commit = "436432d849b83ab90f18773e4ae1c7a8f148f48d",
init_submodules = True,
remote = "https://github.com/mdsteele/boringssl-bazel.git",
)
bind(
name = "zlib",
actual = "@zlib_archive//:zlib",
)
new_http_archive(
name = "zlib_archive",
build_file = "zlib.BUILD",
sha256 = "879d73d8cd4d155f31c1f04838ecd567d34bebda780156f0e82a20721b3973d5",
strip_prefix = "zlib-1.2.8",
url = "http://zlib.net/zlib128.zip",
)
# Description:
# A syntactic parser and part-of-speech tagger in TensorFlow.
package(
default_visibility = ["//visibility:private"],
features = ["-layering_check"],
)
licenses(["notice"]) # Apache 2.0
load(
"syntaxnet",
"tf_proto_library",
"tf_proto_library_py",
"tf_gen_op_libs",
"tf_gen_op_wrapper_py",
)
# proto libraries
tf_proto_library(
name = "feature_extractor_proto",
srcs = ["feature_extractor.proto"],
)
tf_proto_library(
name = "sentence_proto",
srcs = ["sentence.proto"],
)
tf_proto_library_py(
name = "sentence_py_pb2",
srcs = ["sentence.proto"],
)
tf_proto_library(
name = "dictionary_proto",
srcs = ["dictionary.proto"],
)
tf_proto_library_py(
name = "dictionary_py_pb2",
srcs = ["dictionary.proto"],
)
tf_proto_library(
name = "kbest_syntax_proto",
srcs = ["kbest_syntax.proto"],
deps = [":sentence_proto"],
)
tf_proto_library(
name = "task_spec_proto",
srcs = ["task_spec.proto"],
)
tf_proto_library_py(
name = "task_spec_py_pb2",
srcs = ["task_spec.proto"],
)
tf_proto_library(
name = "sparse_proto",
srcs = ["sparse.proto"],
)
tf_proto_library_py(
name = "sparse_py_pb2",
srcs = ["sparse.proto"],
)
# cc libraries for feature extraction and parsing
cc_library(
name = "base",
hdrs = ["base.h"],
visibility = ["//visibility:public"],
deps = [
"@re2//:re2",
"@tf//google/protobuf",
"@tf//third_party/eigen3",
] + select({
"//conditions:default": [
"@tf//tensorflow/core:framework",
"@tf//tensorflow/core:lib",
],
"@tf//tensorflow:darwin": [
"@tf//tensorflow/core:framework_headers_lib",
],
}),
)
cc_library(
name = "utils",
srcs = ["utils.cc"],
hdrs = [
"utils.h",
],
deps = [
":base",
"//util/utf8:unicodetext",
],
)
cc_library(
name = "test_main",
testonly = 1,
srcs = ["test_main.cc"],
linkopts = ["-lm"],
deps = [
"@tf//tensorflow/core:lib",
"@tf//tensorflow/core:testlib",
"//external:gtest",
],
)
cc_library(
name = "document_format",
srcs = ["document_format.cc"],
hdrs = ["document_format.h"],
deps = [
":registry",
":sentence_proto",
":task_context",
],
)
cc_library(
name = "text_formats",
srcs = ["text_formats.cc"],
deps = [
":document_format",
],
alwayslink = 1,
)
cc_library(
name = "fml_parser",
srcs = ["fml_parser.cc"],
hdrs = ["fml_parser.h"],
deps = [
":feature_extractor_proto",
":utils",
],
)
cc_library(
name = "proto_io",
hdrs = ["proto_io.h"],
deps = [
":feature_extractor_proto",
":fml_parser",
":kbest_syntax_proto",
":sentence_proto",
":task_context",
],
)
cc_library(
name = "feature_extractor",
srcs = ["feature_extractor.cc"],
hdrs = [
"feature_extractor.h",
"feature_types.h",
],
deps = [
":document_format",
":feature_extractor_proto",
":kbest_syntax_proto",
":proto_io",
":sentence_proto",
":task_context",
":utils",
":workspace",
],
)
cc_library(
name = "affix",
srcs = ["affix.cc"],
hdrs = ["affix.h"],
deps = [
":dictionary_proto",
":feature_extractor",
":shared_store",
":term_frequency_map",
":utils",
":workspace",
],
)
cc_library(
name = "sentence_features",
srcs = ["sentence_features.cc"],
hdrs = ["sentence_features.h"],
deps = [
":affix",
":feature_extractor",
":registry",
],
)
cc_library(
name = "shared_store",
srcs = ["shared_store.cc"],
hdrs = ["shared_store.h"],
deps = [
":utils",
],
)
cc_library(
name = "registry",
srcs = ["registry.cc"],
hdrs = ["registry.h"],
deps = [
":utils",
],
)
cc_library(
name = "workspace",
srcs = ["workspace.cc"],
hdrs = ["workspace.h"],
deps = [
":utils",
],
)
cc_library(
name = "task_context",
srcs = ["task_context.cc"],
hdrs = ["task_context.h"],
deps = [
":task_spec_proto",
":utils",
],
)
cc_library(
name = "term_frequency_map",
srcs = ["term_frequency_map.cc"],
hdrs = ["term_frequency_map.h"],
visibility = ["//visibility:public"],
deps = [
":utils",
],
alwayslink = 1,
)
cc_library(
name = "parser_transitions",
srcs = [
"arc_standard_transitions.cc",
"parser_state.cc",
"parser_transitions.cc",
"tagger_transitions.cc",
],
hdrs = [
"parser_state.h",
"parser_transitions.h",
],
deps = [
":kbest_syntax_proto",
":registry",
":shared_store",
":task_context",
":term_frequency_map",
],
alwayslink = 1,
)
cc_library(
name = "populate_test_inputs",
testonly = 1,
srcs = ["populate_test_inputs.cc"],
hdrs = ["populate_test_inputs.h"],
deps = [
":dictionary_proto",
":sentence_proto",
":task_context",
":term_frequency_map",
":test_main",
],
)
cc_library(
name = "parser_features",
srcs = ["parser_features.cc"],
hdrs = ["parser_features.h"],
deps = [
":affix",
":feature_extractor",
":parser_transitions",
":registry",
":sentence_features",
":sentence_proto",
":task_context",
":term_frequency_map",
":workspace",
],
alwayslink = 1,
)
cc_library(
name = "embedding_feature_extractor",
srcs = ["embedding_feature_extractor.cc"],
hdrs = ["embedding_feature_extractor.h"],
deps = [
":feature_extractor",
":parser_features",
":parser_transitions",
":sparse_proto",
":task_context",
":workspace",
],
)
cc_library(
name = "sentence_batch",
srcs = ["sentence_batch.cc"],
hdrs = ["sentence_batch.h"],
deps = [
":embedding_feature_extractor",
":feature_extractor",
":parser_features",
":parser_transitions",
":sparse_proto",
":task_context",
":task_spec_proto",
":term_frequency_map",
":workspace",
],
)
cc_library(
name = "reader_ops",
srcs = [
"beam_reader_ops.cc",
"reader_ops.cc",
],
deps = [
":parser_features",
":parser_transitions",
":sentence_batch",
":sentence_proto",
":task_context",
":task_spec_proto",
],
alwayslink = 1,
)
cc_library(
name = "document_filters",
srcs = ["document_filters.cc"],
deps = [
":document_format",
":parser_features",
":parser_transitions",
":sentence_batch",
":sentence_proto",
":task_context",
":task_spec_proto",
":text_formats",
],
alwayslink = 1,
)
cc_library(
name = "lexicon_builder",
srcs = ["lexicon_builder.cc"],
deps = [
":document_format",
":parser_features",
":parser_transitions",
":sentence_batch",
":sentence_proto",
":task_context",
":task_spec_proto",
":text_formats",
],
alwayslink = 1,
)
cc_library(
name = "unpack_sparse_features",
srcs = ["unpack_sparse_features.cc"],
deps = [
":sparse_proto",
":utils",
],
)
cc_library(
name = "parser_ops_cc",
srcs = ["ops/parser_ops.cc"],
deps = [
":base",
":document_filters",
":lexicon_builder",
":reader_ops",
":unpack_sparse_features",
],
alwayslink = 1,
)
cc_binary(
name = "parser_ops.so",
linkopts = select({
"//conditions:default": ["-lm"],
"@tf//tensorflow:darwin": [],
}),
linkshared = 1,
linkstatic = 1,
deps = [
":parser_ops_cc",
],
)
# cc tests
filegroup(
name = "testdata",
srcs = [
"testdata/context.pbtxt",
"testdata/document",
"testdata/mini-training-set",
],
)
cc_test(
name = "shared_store_test",
size = "small",
srcs = ["shared_store_test.cc"],
deps = [
":shared_store",
":test_main",
],
)
cc_test(
name = "sentence_features_test",
size = "medium",
srcs = ["sentence_features_test.cc"],
deps = [
":feature_extractor",
":populate_test_inputs",
":sentence_features",
":sentence_proto",
":task_context",
":task_spec_proto",
":term_frequency_map",
":test_main",
":workspace",
],
)
cc_test(
name = "arc_standard_transitions_test",
size = "small",
srcs = ["arc_standard_transitions_test.cc"],
data = [":testdata"],
deps = [
":parser_transitions",
":populate_test_inputs",
":test_main",
],
)
cc_test(
name = "tagger_transitions_test",
size = "small",
srcs = ["tagger_transitions_test.cc"],
data = [":testdata"],
deps = [
":parser_transitions",
":populate_test_inputs",
":test_main",
],
)
cc_test(
name = "parser_features_test",
size = "small",
srcs = ["parser_features_test.cc"],
deps = [
":feature_extractor",
":parser_features",
":parser_transitions",
":populate_test_inputs",
":sentence_proto",
":task_context",
":task_spec_proto",
":term_frequency_map",
":test_main",
":workspace",
],
)
# py graph builder and trainer
tf_gen_op_libs(
op_lib_names = ["parser_ops"],
)
tf_gen_op_wrapper_py(
name = "parser_ops",
deps = [":parser_ops_op_lib"],
)
py_library(
name = "load_parser_ops_py",
srcs = ["load_parser_ops.py"],
data = [":parser_ops.so"],
)
py_library(
name = "graph_builder",
srcs = ["graph_builder.py"],
deps = [
"@tf//tensorflow:tensorflow_py",
"@tf//tensorflow/core:protos_all_py",
":load_parser_ops_py",
":parser_ops",
],
)
py_library(
name = "structured_graph_builder",
srcs = ["structured_graph_builder.py"],
deps = [
":graph_builder",
],
)
py_binary(
name = "parser_trainer",
srcs = ["parser_trainer.py"],
deps = [
":graph_builder",
":structured_graph_builder",
":task_spec_py_pb2",
],
)
py_binary(
name = "parser_eval",
srcs = ["parser_eval.py"],
deps = [
":graph_builder",
":sentence_py_pb2",
":structured_graph_builder",
],
)
py_binary(
name = "conll2tree",
srcs = ["conll2tree.py"],
deps = [
":graph_builder",
":sentence_py_pb2",
],
)
# py tests
py_test(
name = "lexicon_builder_test",
size = "small",
srcs = ["lexicon_builder_test.py"],
deps = [
":graph_builder",
":sentence_py_pb2",
":task_spec_py_pb2",
],
)
py_test(
name = "text_formats_test",
size = "small",
srcs = ["text_formats_test.py"],
deps = [
":graph_builder",
":sentence_py_pb2",
":task_spec_py_pb2",
],
)
py_test(
name = "reader_ops_test",
size = "medium",
srcs = ["reader_ops_test.py"],
data = [":testdata"],
tags = ["notsan"],
deps = [
":dictionary_py_pb2",
":graph_builder",
":sparse_py_pb2",
],
)
py_test(
name = "beam_reader_ops_test",
size = "medium",
srcs = ["beam_reader_ops_test.py"],
data = [":testdata"],
tags = ["notsan"],
deps = [
":structured_graph_builder",
],
)
py_test(
name = "graph_builder_test",
size = "medium",
srcs = ["graph_builder_test.py"],
data = [
":testdata",
],
tags = ["notsan"],
deps = [
":graph_builder",
":sparse_py_pb2",
],
)
sh_test(
name = "parser_trainer_test",
size = "medium",
srcs = ["parser_trainer_test.sh"],
data = [
":parser_eval",
":parser_trainer",
":testdata",
],
tags = ["notsan"],
)
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "syntaxnet/affix.h"
#include <ctype.h>
#include <string.h>
#include <functional>
#include <string>
#include "syntaxnet/shared_store.h"
#include "syntaxnet/task_context.h"
#include "syntaxnet/term_frequency_map.h"
#include "syntaxnet/utils.h"
#include "syntaxnet/workspace.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/regexp.h"
#include "util/utf8/unicodetext.h"
namespace syntaxnet {
// Initial number of buckets in term and affix hash maps. This must be a power
// of two.
static const int kInitialBuckets = 1024;
// Fill factor for term and affix hash maps.
static const int kFillFactor = 2;
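// Illustrative note (not part of the original source): because the bucket
// count is always a power of two, the bucket index can be computed with a
// bitwise AND instead of a modulo, e.g. with the initial 1024 buckets
// TermHash(term) & 1023 selects the bucket; the bucket array is grown
// (roughly) once the number of affixes exceeds kFillFactor times the number
// of buckets.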
int TermHash(string term) {
return utils::Hash32(term.data(), term.size(), 0xDECAF);
}
// Copies a substring of a Unicode text to a string.
static void UnicodeSubstring(UnicodeText::const_iterator start,
UnicodeText::const_iterator end, string *result) {
result->clear();
result->append(start.utf8_data(), end.utf8_data() - start.utf8_data());
}
AffixTable::AffixTable(Type type, int max_length) {
type_ = type;
max_length_ = max_length;
Resize(0);
}
AffixTable::~AffixTable() { Reset(0); }
void AffixTable::Reset(int max_length) {
// Save new maximum affix length.
max_length_ = max_length;
// Delete all data.
for (size_t i = 0; i < affixes_.size(); ++i) delete affixes_[i];
affixes_.clear();
buckets_.clear();
Resize(0);
}
void AffixTable::Read(const AffixTableEntry &table_entry) {
CHECK_EQ(table_entry.type(), type_ == PREFIX ? "PREFIX" : "SUFFIX");
CHECK_GE(table_entry.max_length(), 0);
Reset(table_entry.max_length());
// First, create all affixes.
for (int affix_id = 0; affix_id < table_entry.affix_size(); ++affix_id) {
const auto &affix_entry = table_entry.affix(affix_id);
CHECK_GE(affix_entry.length(), 0);
CHECK_LE(affix_entry.length(), max_length_);
CHECK(FindAffix(affix_entry.form()) == NULL); // forbid duplicates
Affix *affix = AddNewAffix(affix_entry.form(), affix_entry.length());
CHECK_EQ(affix->id(), affix_id);
}
CHECK_EQ(affixes_.size(), table_entry.affix_size());
// Next, link the shorter affixes.
for (int affix_id = 0; affix_id < table_entry.affix_size(); ++affix_id) {
const auto &affix_entry = table_entry.affix(affix_id);
if (affix_entry.shorter_id() == -1) {
CHECK_EQ(affix_entry.length(), 1);
continue;
}
CHECK_GT(affix_entry.length(), 1);
CHECK_GE(affix_entry.shorter_id(), 0);
CHECK_LT(affix_entry.shorter_id(), affixes_.size());
Affix *affix = affixes_[affix_id];
Affix *shorter = affixes_[affix_entry.shorter_id()];
CHECK_EQ(affix->length(), shorter->length() + 1);
affix->set_shorter(shorter);
}
}
void AffixTable::Read(ProtoRecordReader *reader) {
AffixTableEntry table_entry;
TF_CHECK_OK(reader->Read(&table_entry));
Read(table_entry);
}
void AffixTable::Write(AffixTableEntry *table_entry) const {
table_entry->Clear();
table_entry->set_type(type_ == PREFIX ? "PREFIX" : "SUFFIX");
table_entry->set_max_length(max_length_);
for (const Affix *affix : affixes_) {
auto *affix_entry = table_entry->add_affix();
affix_entry->set_form(affix->form());
affix_entry->set_length(affix->length());
affix_entry->set_shorter_id(
affix->shorter() == NULL ? -1 : affix->shorter()->id());
}
}
void AffixTable::Write(ProtoRecordWriter *writer) const {
AffixTableEntry table_entry;
Write(&table_entry);
writer->Write(table_entry);
}
Affix *AffixTable::AddAffixesForWord(const char *word, size_t size) {
// The affix length is measured in characters and not bytes so we need to
// determine the length in characters.
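// (Illustrative example, not in the original comment: the UTF-8 string
// "café" occupies 5 bytes but only 4 characters, so text.size() below is 4.)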
UnicodeText text;
text.PointToUTF8(word, size);
int length = text.size();
// Determine longest affix.
int affix_len = length;
if (affix_len > max_length_) affix_len = max_length_;
if (affix_len == 0) return NULL;
// Find start and end of longest affix.
UnicodeText::const_iterator start, end;
if (type_ == PREFIX) {
start = end = text.begin();
for (int i = 0; i < affix_len; ++i) ++end;
} else {
start = end = text.end();
for (int i = 0; i < affix_len; ++i) --start;
}
// Try to find successively shorter affixes.
Affix *top = NULL;
Affix *ancestor = NULL;
string s;
while (affix_len > 0) {
// Try to find affix in table.
UnicodeSubstring(start, end, &s);
Affix *affix = FindAffix(s);
if (affix == NULL) {
// Affix not found, add new one to table.
affix = AddNewAffix(s, affix_len);
// Update ancestor chain.
if (ancestor != NULL) ancestor->set_shorter(affix);
ancestor = affix;
if (top == NULL) top = affix;
} else {
// Affix found. Update ancestor if needed and return match.
if (ancestor != NULL) ancestor->set_shorter(affix);
if (top == NULL) top = affix;
break;
}
// Next affix.
if (type_ == PREFIX) {
--end;
} else {
++start;
}
affix_len--;
}
return top;
}
Affix *AffixTable::GetAffix(int id) const {
if (id < 0 || id >= static_cast<int>(affixes_.size())) {
return NULL;
} else {
return affixes_[id];
}
}
string AffixTable::AffixForm(int id) const {
Affix *affix = GetAffix(id);
if (affix == NULL) {
return "";
} else {
return affix->form();
}
}
int AffixTable::AffixId(const string &form) const {
Affix *affix = FindAffix(form);
if (affix == NULL) {
return -1;
} else {
return affix->id();
}
}
Affix *AffixTable::AddNewAffix(const string &form, int length) {
int hash = TermHash(form);
int id = affixes_.size();
if (id > static_cast<int>(buckets_.size()) * kFillFactor) Resize(id);
int b = hash & (buckets_.size() - 1);
// Create new affix object.
Affix *affix = new Affix(id, form.c_str(), length);
affixes_.push_back(affix);
// Insert affix in bucket chain.
affix->next_ = buckets_[b];
buckets_[b] = affix;
return affix;
}
Affix *AffixTable::FindAffix(const string &form) const {
// Compute hash value for word.
int hash = TermHash(form);
// Try to find affix in hash table.
Affix *affix = buckets_[hash & (buckets_.size() - 1)];
while (affix != NULL) {
if (strcmp(affix->form_.c_str(), form.c_str()) == 0) return affix;
affix = affix->next_;
}
return NULL;
}
void AffixTable::Resize(int size_hint) {
// Compute new size for bucket array.
int new_size = kInitialBuckets;
while (new_size < size_hint) new_size *= 2;
int mask = new_size - 1;
// Distribute affixes in new buckets.
buckets_.resize(new_size);
for (size_t i = 0; i < buckets_.size(); ++i) {
buckets_[i] = NULL;
}
for (size_t i = 0; i < affixes_.size(); ++i) {
Affix *affix = affixes_[i];
int b = TermHash(affix->form_) & mask;
affix->next_ = buckets_[b];
buckets_[b] = affix;
}
}
} // namespace syntaxnet
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef $TARGETDIR_AFFIX_H_
#define $TARGETDIR_AFFIX_H_
#include <stddef.h>
#include <string>
#include <vector>
#include "syntaxnet/utils.h"
#include "syntaxnet/dictionary.pb.h"
#include "syntaxnet/feature_extractor.h"
#include "syntaxnet/proto_io.h"
#include "syntaxnet/sentence.pb.h"
#include "syntaxnet/task_context.h"
#include "syntaxnet/term_frequency_map.h"
#include "syntaxnet/workspace.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace syntaxnet {
// An affix represents a prefix or suffix of a word of a certain length. Each
// affix has a unique id and a textual form. An affix also has a pointer to the
// affix that is one character shorter. This creates a chain of affixes that are
// successively shorter.
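// For example (illustrative, not part of the original comment), in a SUFFIX
// table the affix "ing" points to the shorter affix "ng", which in turn
// points to "g".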
class Affix {
private:
friend class AffixTable;
Affix(int id, const char *form, int length)
: id_(id), length_(length), form_(form), shorter_(NULL), next_(NULL) {}
public:
// Returns unique id of affix.
int id() const { return id_; }
// Returns the textual representation of the affix.
string form() const { return form_; }
// Returns the length of the affix.
int length() const { return length_; }
// Gets/sets the affix that is one character shorter.
Affix *shorter() const { return shorter_; }
void set_shorter(Affix *next) { shorter_ = next; }
private:
// Affix id.
int id_;
// Length (in characters) of affix.
int length_;
// Text form of affix.
string form_;
// Pointer to affix that is one character shorter.
Affix *shorter_;
// Next affix in bucket chain.
Affix *next_;
TF_DISALLOW_COPY_AND_ASSIGN(Affix);
};
// An affix table holds all prefixes/suffixes of all the words added to the
// table up to a maximum length. The affixes are chained together to enable
// fast lookup of all affixes for a word.
class AffixTable {
public:
// Affix table type.
enum Type { PREFIX, SUFFIX };
AffixTable(Type type, int max_length);
~AffixTable();
// Resets the affix table and initializes it for affixes up to the specified
// maximum length.
void Reset(int max_length);
// De-serializes this from the given proto.
void Read(const AffixTableEntry &table_entry);
// De-serializes this from the given records.
void Read(ProtoRecordReader *reader);
// Serializes this to the given proto.
void Write(AffixTableEntry *table_entry) const;
// Serializes this to the given records.
void Write(ProtoRecordWriter *writer) const;
// Adds all prefixes/suffixes of the word up to the maximum length to the
// table. The longest affix is returned. The pointers in the affix can be
// used for getting shorter affixes.
Affix *AddAffixesForWord(const char *word, size_t size);
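// Illustrative usage sketch (hypothetical variable names, assuming a SUFFIX
// table with max_length 3):
//   AffixTable table(AffixTable::SUFFIX, 3);
//   Affix *longest = table.AddAffixesForWord("cats", 4);
//   // longest->form() == "ats"; longest->shorter()->form() == "ts".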
// Gets the affix information for the affix with a certain id. Returns NULL if
// there is no affix in the table with this id.
Affix *GetAffix(int id) const;
// Gets affix form from id. If the affix does not exist in the table, an empty
// string is returned.
string AffixForm(int id) const;
// Gets affix id for affix. If the affix does not exist in the table, -1 is
// returned.
int AffixId(const string &form) const;
// Returns size of the affix table.
int size() const { return affixes_.size(); }
// Returns the maximum affix length.
int max_length() const { return max_length_; }
private:
// Adds a new affix to table.
Affix *AddNewAffix(const string &form, int length);
// Finds existing affix in table.
Affix *FindAffix(const string &form) const;
// Resizes bucket array.
void Resize(int size_hint);
// Affix type (prefix or suffix).
Type type_;
// Maximum length of affix.
int max_length_;
// Index from affix ids to affix items.
vector<Affix *> affixes_;
// Buckets for word-to-affix hash map.
vector<Affix *> buckets_;
TF_DISALLOW_COPY_AND_ASSIGN(AffixTable);
};
} // namespace syntaxnet
#endif // $TARGETDIR_AFFIX_H_
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Arc-standard transition system.
//
// This transition system has three types of actions:
// - The SHIFT action pushes the next input token to the stack and
// advances to the next input token.
// - The LEFT_ARC action adds a dependency relation from the first to the second
// token on the stack and removes the second one.
// - The RIGHT_ARC action adds a dependency relation from the second to the first
// token on the stack and removes the first one.
//
// The transition system operates with parser actions encoded as integers:
// - A SHIFT action is encoded as 0.
// - A LEFT_ARC action is encoded as an odd number starting from 1.
// - A RIGHT_ARC action is encoded as an even number starting from 2.
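//
// Worked example (added for clarity, not in the original comment): for a
// dependency label with index 3, LeftArcAction(3) = 1 + (3 << 1) = 7 (odd)
// and RightArcAction(3) = 1 + ((3 << 1) | 1) = 8 (even); ActionType() maps 7
// back to LEFT_ARC and 8 back to RIGHT_ARC, and Label() recovers 3 from
// either action via (action - 1) >> 1.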
#include <string>
#include "syntaxnet/utils.h"
#include "syntaxnet/parser_state.h"
#include "syntaxnet/parser_transitions.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace syntaxnet {
class ArcStandardTransitionState : public ParserTransitionState {
public:
// Clones the transition state by returning a new object.
ParserTransitionState *Clone() const override {
return new ArcStandardTransitionState();
}
// Pushes the root on the stack before using the parser state in parsing.
void Init(ParserState *state) override { state->Push(-1); }
// Adds transition state specific annotations to the document.
void AddParseToDocument(const ParserState &state, bool rewrite_root_labels,
Sentence *sentence) const override {
for (int i = 0; i < state.NumTokens(); ++i) {
Token *token = sentence->mutable_token(i);
token->set_label(state.LabelAsString(state.Label(i)));
if (state.Head(i) != -1) {
token->set_head(state.Head(i));
} else {
token->clear_head();
if (rewrite_root_labels) {
token->set_label(state.LabelAsString(state.RootLabel()));
}
}
}
}
// Whether a parsed token should be considered correct for evaluation.
bool IsTokenCorrect(const ParserState &state, int index) const override {
return state.GoldHead(index) == state.Head(index);
}
// Returns a human readable string representation of this state.
string ToString(const ParserState &state) const override {
string str;
str.append("[");
for (int i = state.StackSize() - 1; i >= 0; --i) {
const string &word = state.GetToken(state.Stack(i)).word();
if (i != state.StackSize() - 1) str.append(" ");
if (word == "") {
str.append(ParserState::kRootLabel);
} else {
str.append(word);
}
}
str.append("]");
for (int i = state.Next(); i < state.NumTokens(); ++i) {
tensorflow::strings::StrAppend(&str, " ", state.GetToken(i).word());
}
return str;
}
};
class ArcStandardTransitionSystem : public ParserTransitionSystem {
public:
// Action types for the arc-standard transition system.
enum ParserActionType {
SHIFT = 0,
LEFT_ARC = 1,
RIGHT_ARC = 2,
};
// The SHIFT action uses the same value as the corresponding action type.
static ParserAction ShiftAction() { return SHIFT; }
// The LEFT_ARC action converts the label to an odd number greater than or
// equal to 1.
static ParserAction LeftArcAction(int label) { return 1 + (label << 1); }
// The RIGHT_ARC action converts the label to an even number greater than or
// equal to 2.
static ParserAction RightArcAction(int label) {
return 1 + ((label << 1) | 1);
}
// Extracts the action type from a given parser action.
static ParserActionType ActionType(ParserAction action) {
return static_cast<ParserActionType>(action < 1 ? action
: 1 + (~action & 1));
}
// Extracts the label from a given parser action. If the action is SHIFT,
// returns -1.
static int Label(ParserAction action) {
return action < 1 ? -1 : (action - 1) >> 1;
}
// Returns the number of action types.
int NumActionTypes() const override { return 3; }
// Returns the number of possible actions.
int NumActions(int num_labels) const override { return 1 + 2 * num_labels; }
// Returns the default action for a given state.
ParserAction GetDefaultAction(const ParserState &state) const override {
// If there are further tokens available in the input then Shift.
if (!state.EndOfInput()) return ShiftAction();
// Do a "reduce".
return RightArcAction(2);
}
// Returns the next gold action for a given state according to the
// underlying annotated sentence.
ParserAction GetNextGoldAction(const ParserState &state) const override {
// If the stack contains less than 2 tokens, the only valid parser action is
// shift.
if (state.StackSize() < 2) {
DCHECK(!state.EndOfInput());
return ShiftAction();
}
// If the second token on the stack is the head of the first one,
// return a right arc action.
if (state.GoldHead(state.Stack(0)) == state.Stack(1) &&
DoneChildrenRightOf(state, state.Stack(0))) {
const int gold_label = state.GoldLabel(state.Stack(0));
return RightArcAction(gold_label);
}
// If the first token on the stack is the head of the second one,
// return a left arc action.
if (state.GoldHead(state.Stack(1)) == state.Top()) {
const int gold_label = state.GoldLabel(state.Stack(1));
return LeftArcAction(gold_label);
}
// Otherwise, shift.
return ShiftAction();
}
// Determines if a token has any children to the right in the sentence.
// Arc standard is a bottom-up parsing method and has to finish all sub-trees
// first.
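// (Illustrative note, not in the original comment: if the token on top of
// the stack still has an unattached gold dependent further right in the
// buffer, popping it now with a RIGHT_ARC would make that dependent
// unreachable, so the gold oracle keeps shifting instead.)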
static bool DoneChildrenRightOf(const ParserState &state, int head) {
int index = state.Next();
int num_tokens = state.sentence().token_size();
while (index < num_tokens) {
// Check if the token at index is the child of head.
int actual_head = state.GoldHead(index);
if (actual_head == head) return false;
// If the head of the token at index is to the right of it there cannot be
// any children in-between, so we can skip forward to the head. Note this
// is only true for projective trees.
if (actual_head > index) {
index = actual_head;
} else {
++index;
}
}
return true;
}
// Checks if the action is allowed in a given parser state.
bool IsAllowedAction(ParserAction action,
const ParserState &state) const override {
switch (ActionType(action)) {
case SHIFT:
return IsAllowedShift(state);
case LEFT_ARC:
return IsAllowedLeftArc(state);
case RIGHT_ARC:
return IsAllowedRightArc(state);
}
return false;
}
// Returns true if a shift is allowed in the given parser state.
bool IsAllowedShift(const ParserState &state) const {
// We can shift if there are more input tokens.
return !state.EndOfInput();
}
// Returns true if a left-arc is allowed in the given parser state.
bool IsAllowedLeftArc(const ParserState &state) const {
// Left-arc requires two or more tokens on the stack, but the bottom-most
// element is the root and we do not want a left arc to the root.
return state.StackSize() > 2;
}
// Returns true if a right-arc is allowed in the given parser state.
bool IsAllowedRightArc(const ParserState &state) const {
// Right arc requires two or more elements on the stack, counting the root.
return state.StackSize() > 1;
}
// Performs the specified action on a given parser state, without adding the
// action to the state's history.
void PerformActionWithoutHistory(ParserAction action,
ParserState *state) const override {
switch (ActionType(action)) {
case SHIFT:
PerformShift(state);
break;
case LEFT_ARC:
PerformLeftArc(state, Label(action));
break;
case RIGHT_ARC:
PerformRightArc(state, Label(action));
break;
}
}
// Makes a shift by pushing the next input token on the stack and moving to
// the next position.
void PerformShift(ParserState *state) const {
DCHECK(IsAllowedShift(*state));
state->Push(state->Next());
state->Advance();
}
// Makes a left-arc between the two top tokens on stack and pops the second
// token on stack.
void PerformLeftArc(ParserState *state, int label) const {
DCHECK(IsAllowedLeftArc(*state));
int s0 = state->Pop();
state->AddArc(state->Pop(), s0, label);
state->Push(s0);
}
// Makes a right-arc between the two top tokens on stack and pops the stack.
void PerformRightArc(ParserState *state, int label) const {
DCHECK(IsAllowedRightArc(*state));
int s0 = state->Pop();
int s1 = state->Pop();
state->AddArc(s0, s1, label);
state->Push(s1);
}
// We are in a deterministic state when the stack holds fewer than two
// elements and there is still input left, so the only allowed action is a shift.
bool IsDeterministicState(const ParserState &state) const override {
return state.StackSize() < 2 && !state.EndOfInput();
}
// We are in a final state when we reached the end of the input and the stack
// is empty.
bool IsFinalState(const ParserState &state) const override {
return state.EndOfInput() && state.StackSize() < 2;
}
// Returns a string representation of a parser action.
string ActionAsString(ParserAction action,
const ParserState &state) const override {
switch (ActionType(action)) {
case SHIFT:
return "SHIFT";
case LEFT_ARC:
return "LEFT_ARC(" + state.LabelAsString(Label(action)) + ")";
case RIGHT_ARC:
return "RIGHT_ARC(" + state.LabelAsString(Label(action)) + ")";
}
return "UNKNOWN";
}
// Returns a new transition state to be used to enhance the parser state.
ParserTransitionState *NewTransitionState(bool training_mode) const override {
return new ArcStandardTransitionState();
}
};
REGISTER_TRANSITION_SYSTEM("arc-standard", ArcStandardTransitionSystem);
} // namespace syntaxnet
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <string>
#include <gmock/gmock.h>
#include "syntaxnet/utils.h"
#include "syntaxnet/parser_state.h"
#include "syntaxnet/parser_transitions.h"
#include "syntaxnet/populate_test_inputs.h"
#include "syntaxnet/sentence.pb.h"
#include "syntaxnet/task_context.h"
#include "syntaxnet/task_spec.pb.h"
#include "syntaxnet/term_frequency_map.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
class ArcStandardTransitionTest : public ::testing::Test {
public:
ArcStandardTransitionTest()
: transition_system_(ParserTransitionSystem::Create("arc-standard")) {}
protected:
// Creates a label map and a tag map for testing based on the given
// document and initializes the transition system appropriately.
void SetUpForDocument(const Sentence &document) {
input_label_map_ = context_.GetInput("label-map", "text", "");
transition_system_->Setup(&context_);
PopulateTestInputs::Defaults(document).Populate(&context_);
label_map_.Load(TaskContext::InputFile(*input_label_map_),
0 /* minimum frequency */,
-1 /* maximum number of terms */);
transition_system_->Init(&context_);
}
// Creates a cloned state from a sentence in order to test that cloning
// works correctly for the new parser states.
ParserState *NewClonedState(Sentence *sentence) {
ParserState state(sentence, transition_system_->NewTransitionState(
true /* training mode */),
&label_map_);
return state.Clone();
}
// Performs gold transitions and checks that the labels and heads recorded
// in the parser state match the gold heads and labels.
void GoldParse(Sentence *sentence) {
ParserState *state = NewClonedState(sentence);
LOG(INFO) << "Initial parser state: " << state->ToString();
while (!transition_system_->IsFinalState(*state)) {
ParserAction action = transition_system_->GetNextGoldAction(*state);
EXPECT_TRUE(transition_system_->IsAllowedAction(action, *state));
LOG(INFO) << "Performing action: "
<< transition_system_->ActionAsString(action, *state);
transition_system_->PerformActionWithoutHistory(action, state);
LOG(INFO) << "Parser state: " << state->ToString();
}
for (int i = 0; i < sentence->token_size(); ++i) {
EXPECT_EQ(state->GoldLabel(i), state->Label(i));
EXPECT_EQ(state->GoldHead(i), state->Head(i));
}
delete state;
}
// Always takes the default action, and verifies that this leads to
// a final state through a sequence of allowed actions.
void DefaultParse(Sentence *sentence) {
ParserState *state = NewClonedState(sentence);
LOG(INFO) << "Initial parser state: " << state->ToString();
while (!transition_system_->IsFinalState(*state)) {
ParserAction action = transition_system_->GetDefaultAction(*state);
EXPECT_TRUE(transition_system_->IsAllowedAction(action, *state));
LOG(INFO) << "Performing action: "
<< transition_system_->ActionAsString(action, *state);
transition_system_->PerformActionWithoutHistory(action, state);
LOG(INFO) << "Parser state: " << state->ToString();
}
delete state;
}
TaskContext context_;
TaskInput *input_label_map_ = nullptr;
TermFrequencyMap label_map_;
std::unique_ptr<ParserTransitionSystem> transition_system_;
};
TEST_F(ArcStandardTransitionTest, SingleSentenceDocumentTest) {
string document_text;
Sentence document;
TF_CHECK_OK(ReadFileToString(
tensorflow::Env::Default(),
"syntaxnet/testdata/document",
&document_text));
LOG(INFO) << "see doc\n:" << document_text;
CHECK(TextFormat::ParseFromString(document_text, &document));
SetUpForDocument(document);
GoldParse(&document);
DefaultParse(&document);
}
} // namespace syntaxnet
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef $TARGETDIR_BASE_H_
#define $TARGETDIR_BASE_H_
#include <functional>
#include <string>
#include <vector>
#include <unordered_map>
#include <unordered_set>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/default/integral_types.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
using tensorflow::int32;
using tensorflow::int64;
using tensorflow::uint64;
using tensorflow::uint32;
using tensorflow::protobuf::TextFormat;
using tensorflow::mutex_lock;
using tensorflow::mutex;
using std::map;
using std::pair;
using std::vector;
using std::unordered_map;
using std::unordered_set;
typedef signed int char32;
using tensorflow::StringPiece;
using std::string;
// namespace syntaxnet
#endif // $TARGETDIR_BASE_H_
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for beam_reader_ops."""
import os.path
import time
import tensorflow as tf
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
from tensorflow.python.platform import logging
from syntaxnet import structured_graph_builder
from syntaxnet.ops import gen_parser_ops
FLAGS = tf.app.flags.FLAGS
if not hasattr(FLAGS, 'test_srcdir'):
FLAGS.test_srcdir = ''
if not hasattr(FLAGS, 'test_tmpdir'):
FLAGS.test_tmpdir = tf.test.get_temp_dir()
class ParsingReaderOpsTest(test_util.TensorFlowTestCase):
def setUp(self):
# Creates a task context with the correct testing paths.
initial_task_context = os.path.join(
FLAGS.test_srcdir,
'syntaxnet/'
'testdata/context.pbtxt')
self._task_context = os.path.join(FLAGS.test_tmpdir, 'context.pbtxt')
with open(initial_task_context, 'r') as fin:
with open(self._task_context, 'w') as fout:
fout.write(fin.read().replace('SRCDIR', FLAGS.test_srcdir)
.replace('OUTPATH', FLAGS.test_tmpdir))
# Creates necessary term maps.
with self.test_session() as sess:
gen_parser_ops.lexicon_builder(task_context=self._task_context,
corpus_name='training-corpus').run()
self._num_features, self._num_feature_ids, _, self._num_actions = (
sess.run(gen_parser_ops.feature_size(task_context=self._task_context,
arg_prefix='brain_parser')))
def MakeGraph(self,
max_steps=10,
beam_size=2,
batch_size=1,
**kwargs):
"""Constructs a structured learning graph."""
assert max_steps > 0, 'Empty network not supported.'
logging.info('MakeGraph + %s', kwargs)
with self.test_session(graph=tf.Graph()) as sess:
feature_sizes, domain_sizes, embedding_dims, num_actions = sess.run(
gen_parser_ops.feature_size(task_context=self._task_context))
embedding_dims = [8, 8, 8]
hidden_layer_sizes = []
learning_rate = 0.01
builder = structured_graph_builder.StructuredGraphBuilder(
num_actions,
feature_sizes,
domain_sizes,
embedding_dims,
hidden_layer_sizes,
seed=1,
max_steps=max_steps,
beam_size=beam_size,
gate_gradients=True,
use_locking=True,
use_averaging=False,
check_parameters=False,
**kwargs)
builder.AddTraining(self._task_context,
batch_size,
learning_rate=learning_rate,
decay_steps=1000,
momentum=0.9,
corpus_name='training-corpus')
builder.AddEvaluation(self._task_context,
batch_size,
evaluation_max_steps=25,
corpus_name=None)
builder.training['inits'] = tf.group(*builder.inits.values(), name='inits')
return builder
def Train(self, **kwargs):
with self.test_session(graph=tf.Graph()) as sess:
max_steps = 3
batch_size = 3
beam_size = 3
builder = (
self.MakeGraph(
max_steps=max_steps, beam_size=beam_size,
batch_size=batch_size, **kwargs))
logging.info('params: %s', builder.params.keys())
logging.info('variables: %s', builder.variables.keys())
t = builder.training
sess.run(t['inits'])
costs = []
gold_slots = []
alive_steps_vector = []
every_n = 5
walltime = time.time()
for step in range(10):
if step > 0 and step % every_n == 0:
new_walltime = time.time()
logging.info(
'Step: %d <cost>: %f <gold_slot>: %f <alive_steps>: %f <iter '
'time>: %f ms',
step, sum(costs[-every_n:]) / float(every_n),
sum(gold_slots[-every_n:]) / float(every_n),
sum(alive_steps_vector[-every_n:]) / float(every_n),
1000 * (new_walltime - walltime) / float(every_n))
walltime = new_walltime
cost, gold_slot, alive_steps, _ = sess.run(
[t['cost'], t['gold_slot'], t['alive_steps'], t['train_op']])
costs.append(cost)
gold_slots.append(gold_slot.mean())
alive_steps_vector.append(alive_steps.mean())
if builder._only_train:
trainable_param_names = [
k for k in builder.params if k in builder._only_train]
else:
trainable_param_names = builder.params.keys()
if builder._use_averaging:
for v in trainable_param_names:
avg = builder.variables['%s_avg_var' % v].eval()
tf.assign(builder.params[v], avg).eval()
# Reset for pseudo eval.
costs = []
gold_slots = []
alive_stepss = []
for step in range(10):
cost, gold_slot, alive_steps = sess.run(
[t['cost'], t['gold_slot'], t['alive_steps']])
costs.append(cost)
gold_slots.append(gold_slot.mean())
alive_stepss.append(alive_steps.mean())
logging.info(
'Pseudo eval: <cost>: %f <gold_slot>: %f <alive_steps>: %f',
sum(costs[-every_n:]) / float(every_n),
sum(gold_slots[-every_n:]) / float(every_n),
sum(alive_stepss[-every_n:]) / float(every_n))
def PathScores(self, iterations, beam_size, max_steps, batch_size):
with self.test_session(graph=tf.Graph()) as sess:
t = self.MakeGraph(beam_size=beam_size, max_steps=max_steps,
batch_size=batch_size).training
sess.run(t['inits'])
all_path_scores = []
beam_path_scores = []
for i in range(iterations):
logging.info('run %d', i)
tensors = (
sess.run(
[t['alive_steps'], t['concat_scores'],
t['all_path_scores'], t['beam_path_scores'],
t['indices'], t['path_ids']]))
logging.info('alive for %s, all_path_scores and beam_path_scores, '
'indices and path_ids:'
'\n%s\n%s\n%s\n%s',
tensors[0], tensors[2], tensors[3], tensors[4], tensors[5])
logging.info('diff:\n%s', tensors[2] - tensors[3])
all_path_scores.append(tensors[2])
beam_path_scores.append(tensors[3])
return all_path_scores, beam_path_scores
def testParseUntilNotAlive(self):
"""Ensures that the 'alive' condition works in the Cond ops."""
with self.test_session(graph=tf.Graph()) as sess:
t = self.MakeGraph(batch_size=3, beam_size=2, max_steps=5).training
sess.run(t['inits'])
for i in range(5):
logging.info('run %d', i)
tf_alive = t['alive'].eval()
self.assertFalse(any(tf_alive))
def testParseMomentum(self):
"""Ensures that Momentum training can be done using the gradients."""
self.Train()
self.Train(model_cost='perceptron_loss')
self.Train(model_cost='perceptron_loss',
only_train='softmax_weight,softmax_bias', softmax_init=0)
self.Train(only_train='softmax_weight,softmax_bias', softmax_init=0)
def testPathScoresAgree(self):
"""Ensures that path scores computed in the beam are same in the net."""
all_path_scores, beam_path_scores = self.PathScores(
iterations=1, beam_size=130, max_steps=5, batch_size=1)
self.assertArrayNear(all_path_scores[0], beam_path_scores[0], 1e-6)
def testBatchPathScoresAgree(self):
"""Ensures that path scores computed in the beam are same in the net."""
all_path_scores, beam_path_scores = self.PathScores(
iterations=1, beam_size=130, max_steps=5, batch_size=22)
self.assertArrayNear(all_path_scores[0], beam_path_scores[0], 1e-6)
def testBatchOneStepPathScoresAgree(self):
"""Ensures that path scores computed in the beam are same in the net."""
all_path_scores, beam_path_scores = self.PathScores(
iterations=1, beam_size=130, max_steps=1, batch_size=22)
self.assertArrayNear(all_path_scores[0], beam_path_scores[0], 1e-6)
if __name__ == '__main__':
googletest.main()
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A program to generate ASCII trees from conll files."""
import collections
import asciitree
import tensorflow as tf
import syntaxnet.load_parser_ops
from tensorflow.python.platform import logging
from syntaxnet import sentence_pb2
from syntaxnet.ops import gen_parser_ops
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('task_context',
'syntaxnet/models/parsey_mcparseface/context.pbtxt',
'Path to a task context with inputs and parameters for '
'feature extractors.')
flags.DEFINE_string('corpus_name', 'stdin-conll',
'Name of the corpus input in the task context from which '
'documents are read.')
def to_dict(sentence):
"""Builds a dictionary representing the parse tree of a sentence.
Args:
sentence: Sentence protocol buffer to represent.
Returns:
Dictionary mapping tokens to children.
"""
token_str = ['%s %s %s' % (token.word, token.tag, token.label)
for token in sentence.token]
children = [[] for token in sentence.token]
root = -1
for i in range(0, len(sentence.token)):
token = sentence.token[i]
if token.head == -1:
root = i
else:
children[token.head].append(i)
def _get_dict(i):
d = collections.OrderedDict()
for c in children[i]:
d[token_str[c]] = _get_dict(c)
return d
tree = collections.OrderedDict()
tree[token_str[root]] = _get_dict(root)
return tree
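# Illustrative (hypothetical) example of the structure built above: for a
# two-token sentence in which token 1 ('saw', tag VBD, label ROOT) is the root
# and heads token 0 ('John', tag NNP, label nsubj), to_dict returns roughly
#   {'saw VBD ROOT': {'John NNP nsubj': {}}}
# which asciitree.LeftAligned() in main() renders as an indented ASCII tree.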
def main(unused_argv):
logging.set_verbosity(logging.INFO)
with tf.Session() as sess:
src = gen_parser_ops.document_source(batch_size=32,
corpus_name=FLAGS.corpus_name,
task_context=FLAGS.task_context)
sentence = sentence_pb2.Sentence()
while True:
documents, finished = sess.run(src)
logging.info('Read %d documents', len(documents))
for d in documents:
sentence.ParseFromString(d)
tr = asciitree.LeftAligned()
d = to_dict(sentence)
print 'Input: %s' % sentence.text
print 'Parse:'
print tr(d)
if finished:
break
if __name__ == '__main__':
tf.app.run()
Parameter {
name: 'brain_parser_embedding_dims'
value: '64;32;32'
}
Parameter {
name: 'brain_parser_features'
value: 'input.word input(1).word input(2).word input(3).word stack.word stack(1).word stack(2).word stack(3).word stack.child(1).word stack.child(1).sibling(-1).word stack.child(-1).word stack.child(-1).sibling(1).word stack(1).child(1).word stack(1).child(1).sibling(-1).word stack(1).child(-1).word stack(1).child(-1).sibling(1).word stack.child(2).word stack.child(-2).word stack(1).child(2).word stack(1).child(-2).word;input.tag input(1).tag input(2).tag input(3).tag stack.tag stack(1).tag stack(2).tag stack(3).tag stack.child(1).tag stack.child(1).sibling(-1).tag stack.child(-1).tag stack.child(-1).sibling(1).tag stack(1).child(1).tag stack(1).child(1).sibling(-1).tag stack(1).child(-1).tag stack(1).child(-1).sibling(1).tag stack.child(2).tag stack.child(-2).tag stack(1).child(2).tag stack(1).child(-2).tag;stack.child(1).label stack.child(1).sibling(-1).label stack.child(-1).label stack.child(-1).sibling(1).label stack(1).child(1).label stack(1).child(1).sibling(-1).label stack(1).child(-1).label stack(1).child(-1).sibling(1).label stack.child(2).label stack.child(-2).label stack(1).child(2).label stack(1).child(-2).label'
}
Parameter {
name: 'brain_parser_embedding_names'
value: 'words;tags;labels'
}
Parameter {
name: 'brain_parser_scoring'
value: 'default'
}
Parameter {
name: 'brain_pos_transition_system'
value: 'tagger'
}
Parameter {
name: 'brain_pos_embedding_dims'
value: '64;4;8;8'
}
Parameter {
name: 'brain_pos_features'
value: 'stack(3).word stack(2).word stack(1).word stack.word input.word input(1).word input(2).word input(3).word;input.digit input.hyphen;stack.suffix(length=2) input.suffix(length=2) input(1).suffix(length=2);stack.prefix(length=2) input.prefix(length=2) input(1).prefix(length=2)'
}
Parameter {
name: 'brain_pos_embedding_names'
value: 'words;other;suffix;prefix'
}
input {
name: 'training-corpus'
record_format: 'conll-sentence'
Part {
file_pattern: '<your-dataset>/treebank-train.trees.conll'
}
}
input {
name: 'tuning-corpus'
record_format: 'conll-sentence'
Part {
file_pattern: '<your-dataset>/dev.conll'
}
}
input {
name: 'dev-corpus'
record_format: 'conll-sentence'
Part {
file_pattern: '<your-dataset>/test.conll'
}
}
input {
name: 'tagged-training-corpus'
creator: 'brain_pos/greedy'
record_format: 'conll-sentence'
}
input {
name: 'tagged-tuning-corpus'
creator: 'brain_pos/greedy'
record_format: 'conll-sentence'
}
input {
name: 'tagged-dev-corpus'
creator: 'brain_pos/greedy'
record_format: 'conll-sentence'
}
input {
name: 'label-map'
creator: 'brain_pos/greedy'
}
input {
name: 'word-map'
creator: 'brain_pos/greedy'
}
input {
name: 'lcword-map'
creator: 'brain_pos/greedy'
}
input {
name: 'tag-map'
creator: 'brain_pos/greedy'
}
input {
name: 'category-map'
creator: 'brain_pos/greedy'
}
input {
name: 'prefix-table'
creator: 'brain_pos/greedy'
}
input {
name: 'suffix-table'
creator: 'brain_pos/greedy'
}
input {
name: 'tag-to-category'
creator: 'brain_pos/greedy'
}
input {
name: 'projectivized-training-corpus'
creator: 'brain_parser/greedy'
record_format: 'conll-sentence'
}
input {
name: 'parsed-training-corpus'
creator: 'brain_parser/greedy'
record_format: 'conll-sentence'
}
input {
name: 'parsed-tuning-corpus'
creator: 'brain_parser/greedy'
record_format: 'conll-sentence'
}
input {
name: 'parsed-dev-corpus'
creator: 'brain_parser/greedy'
record_format: 'conll-sentence'
}
input {
name: 'beam-parsed-training-corpus'
creator: 'brain_parser/structured'
record_format: 'conll-sentence'
}
input {
name: 'beam-parsed-tuning-corpus'
creator: 'brain_parser/structured'
record_format: 'conll-sentence'
}
input {
name: 'beam-parsed-dev-corpus'
creator: 'brain_parser/structured'
record_format: 'conll-sentence'
}
input {
name: 'stdin'
record_format: 'english-text'
Part {
file_pattern: '-'
}
}
input {
name: 'stdin-conll'
record_format: 'conll-sentence'
Part {
file_pattern: '-'
}
}
input {
name: 'stdout-conll'
record_format: 'conll-sentence'
Part {
file_pattern: '-'
}
}
#!/bin/bash
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# A script that runs a tokenizer, a part-of-speech tagger and a dependency
# parser on an English text file, with one sentence per line.
#
# Example usage:
# echo "Parsey McParseface is my favorite parser!" | syntaxnet/demo.sh
# To run on a conll formatted file, add the --conll command line argument.
#
PARSER_EVAL=bazel-bin/syntaxnet/parser_eval
MODEL_DIR=syntaxnet/models/parsey_mcparseface
[[ "$1" == "--conll" ]] && INPUT_FORMAT=stdin-conll || INPUT_FORMAT=stdin
$PARSER_EVAL \
--input=$INPUT_FORMAT \
--output=stdout-conll \
--hidden_layer_sizes=64 \
--arg_prefix=brain_tagger \
--graph_builder=structured \
--task_context=$MODEL_DIR/context.pbtxt \
--model_path=$MODEL_DIR/tagger-params \
--slim_model \
--batch_size=1024 \
--alsologtostderr \
| \
$PARSER_EVAL \
--input=stdin-conll \
--output=stdout-conll \
--hidden_layer_sizes=512,512 \
--arg_prefix=brain_parser \
--graph_builder=structured \
--task_context=$MODEL_DIR/context.pbtxt \
--model_path=$MODEL_DIR/parser-params \
--slim_model \
--batch_size=1024 \
--alsologtostderr \
| \
bazel-bin/syntaxnet/conll2tree \
--task_context=$MODEL_DIR/context.pbtxt \
--alsologtostderr