Commit 32ab5a58 authored by calberti, committed by Martin Wicke

Adding SyntaxNet to tensorflow/models (#63)

parent 148a15fb
// Protocol buffers for serializing string<=>index dictionaries.
syntax = "proto2";
package syntaxnet;
// Serializable representation of a string=>string pair.
message StringToStringPair {
// String representing the key.
required string key = 1;
// String representing the value.
required string value = 2;
}
// Serializable representation of a string=>string mapping.
message StringToStringMap {
// Key=>value pairs.
repeated StringToStringPair pair = 1;
}
// Affix table entry, for serialization of the affix tables.
message AffixTableEntry {
// Nested message for serializing a single affix.
message AffixEntry {
// The affix as a string.
required string form = 1;
// The length of the affix (this is non-trivial to compute due to UTF-8).
required int32 length = 2;
// The ID of the affix that is one character shorter, or -1 if none exists.
required int32 shorter_id = 3;
}
// The type of affix table, as a string.
required string type = 1;
// The maximum affix length.
required int32 max_length = 2;
// The list of affixes, in order of affix ID.
repeated AffixEntry affix = 3;
}
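// As an illustration (values hypothetical), a serialized suffix table for the
// suffixes of "walking", up to length 3, could look like this in text format;
// affixes are listed in order of affix ID, and each shorter_id points at the
// affix one character shorter:
//
//   type: "suffix"
//   max_length: 3
//   affix { form: "g"   length: 1 shorter_id: -1 }
//   affix { form: "ng"  length: 2 shorter_id: 0 }
//   affix { form: "ing" length: 3 shorter_id: 1 }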
// A light-weight proto to store vectors in binary format.
message TokenEmbedding {
required bytes token = 1;  // Can be a word, phrase, URL, etc.
// If available, raw count of this token in the training corpus.
optional int64 count = 3;
message Vector {
repeated float values = 1 [packed = true];
}
optional Vector vector = 2;
};
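// A minimal C++ sketch (assuming the standard generated proto2 API) of
// building and serializing a TokenEmbedding:
//
//   TokenEmbedding embedding;
//   embedding.set_token("the");
//   embedding.set_count(1045);  // hypothetical corpus count
//   for (float v : {0.12f, -0.03f, 0.57f}) {
//     embedding.mutable_vector()->add_values(v);
//   }
//   const string serialized = embedding.SerializeAsString();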
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Various utilities for handling documents.
#include <stddef.h>
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "syntaxnet/base.h"
#include "syntaxnet/feature_extractor.h"
#include "syntaxnet/sentence.pb.h"
#include "syntaxnet/utils.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"
using tensorflow::DEVICE_CPU;
using tensorflow::OpKernel;
using tensorflow::OpKernelConstruction;
using tensorflow::OpKernelContext;
using tensorflow::Tensor;
using tensorflow::TensorShape;
using tensorflow::errors::InvalidArgument;
namespace syntaxnet {
namespace {
void GetTaskContext(OpKernelConstruction *context, TaskContext *task_context) {
string file_path, data;
OP_REQUIRES_OK(context, context->GetAttr("task_context", &file_path));
OP_REQUIRES_OK(
context, ReadFileToString(tensorflow::Env::Default(), file_path, &data));
OP_REQUIRES(context,
TextFormat::ParseFromString(data, task_context->mutable_spec()),
InvalidArgument("Could not parse task context at ", file_path));
}
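// For reference, the file named by the "task_context" attr holds a
// text-format TaskSpec. A minimal sketch (field names as defined in
// syntaxnet's task_spec.proto; values are illustrative):
//
//   input {
//     name: 'training-corpus'
//     record_format: 'conll-sentence'
//     Part {
//       file_pattern: './corpus.conll'
//     }
//   }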
// Outputs the given batch of sentences as a tensor and deletes them.
void OutputDocuments(OpKernelContext *context,
vector<Sentence *> *document_batch) {
const int64 size = document_batch->size();
Tensor *output;
OP_REQUIRES_OK(context,
context->allocate_output(0, TensorShape({size}), &output));
for (int64 i = 0; i < size; ++i) {
output->vec<string>()(i) = (*document_batch)[i]->SerializeAsString();
}
utils::STLDeleteElements(document_batch);
}
} // namespace
class DocumentSource : public OpKernel {
public:
explicit DocumentSource(OpKernelConstruction *context) : OpKernel(context) {
GetTaskContext(context, &task_context_);
string corpus_name;
OP_REQUIRES_OK(context, context->GetAttr("corpus_name", &corpus_name));
OP_REQUIRES_OK(context, context->GetAttr("batch_size", &batch_size_));
OP_REQUIRES(context, batch_size_ > 0,
InvalidArgument("invalid batch_size provided"));
corpus_.reset(new TextReader(*task_context_.GetInput(corpus_name)));
}
void Compute(OpKernelContext *context) override {
mutex_lock lock(mu_);
Sentence *document;
vector<Sentence *> document_batch;
while ((document = corpus_->Read()) != nullptr) {
document_batch.push_back(document);
if (static_cast<int>(document_batch.size()) == batch_size_) {
OutputDocuments(context, &document_batch);
OutputLast(context, false);
return;
}
}
OutputDocuments(context, &document_batch);
OutputLast(context, true);
}
private:
void OutputLast(OpKernelContext *context, bool last) {
Tensor *output;
OP_REQUIRES_OK(context,
context->allocate_output(1, TensorShape({}), &output));
output->scalar<bool>()() = last;
}
// Task context used to configure this op.
TaskContext task_context_;
// mutex to synchronize access to Compute.
mutex mu_;
std::unique_ptr<TextReader> corpus_;
string documents_path_;
int batch_size_;
};
REGISTER_KERNEL_BUILDER(Name("DocumentSource").Device(DEVICE_CPU),
DocumentSource);
class DocumentSink : public OpKernel {
public:
explicit DocumentSink(OpKernelConstruction *context) : OpKernel(context) {
GetTaskContext(context, &task_context_);
string corpus_name;
OP_REQUIRES_OK(context, context->GetAttr("corpus_name", &corpus_name));
writer_.reset(new TextWriter(*task_context_.GetInput(corpus_name)));
}
void Compute(OpKernelContext *context) override {
mutex_lock lock(mu_);
auto documents = context->input(0).vec<string>();
for (int i = 0; i < documents.size(); ++i) {
Sentence document;
OP_REQUIRES(context, document.ParseFromString(documents(i)),
InvalidArgument("failed to parse sentence"));
writer_->Write(document);
}
}
private:
// Task context used to configure this op.
TaskContext task_context_;
// mutex to synchronize access to Compute.
mutex mu_;
string documents_path_;
std::unique_ptr<TextWriter> writer_;
};
REGISTER_KERNEL_BUILDER(Name("DocumentSink").Device(DEVICE_CPU),
DocumentSink);
// Sentence filter for filtering out documents where the parse trees are not
// well-formed, i.e. they contain cycles.
class WellFormedFilter : public OpKernel {
public:
explicit WellFormedFilter(OpKernelConstruction *context) : OpKernel(context) {
GetTaskContext(context, &task_context_);
OP_REQUIRES_OK(context, context->GetAttr("keep_malformed_documents",
&keep_malformed_));
}
void Compute(OpKernelContext *context) override {
auto documents = context->input(0).vec<string>();
vector<Sentence *> output_documents;
for (int i = 0; i < documents.size(); ++i) {
Sentence *document = new Sentence;
OP_REQUIRES(context, document->ParseFromString(documents(i)),
InvalidArgument("failed to parse sentence"));
if (ShouldKeep(*document)) {
output_documents.push_back(document);
} else {
delete document;
}
}
OutputDocuments(context, &output_documents);
}
private:
bool ShouldKeep(const Sentence &doc) {
vector<int> visited(doc.token_size(), -1);
for (int i = 0; i < doc.token_size(); ++i) {
// Already visited node.
if (visited[i] != -1) continue;
int t = i;
while (t != -1) {
if (visited[t] == -1) {
// If it is not visited yet, mark it.
visited[t] = i;
} else if (visited[t] < i) {
// Token was already visited during an earlier, cycle-free traversal, so
// the rest of this path is known to be well-formed.
break;
} else {
// Loop detected.
LOG(ERROR) << "Loop detected in document " << doc.DebugString();
return keep_malformed_;
}
t = doc.token(t).head();
}
}
return true;
}
private:
// Task context used to configure this op.
TaskContext task_context_;
bool keep_malformed_;
};
REGISTER_KERNEL_BUILDER(Name("WellFormedFilter").Device(DEVICE_CPU),
WellFormedFilter);
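// Worked example of the cycle check in WellFormedFilter::ShouldKeep: for a
// two-token document with heads [1, 0], the traversal from token 0 marks
// visited[0] = 0 and visited[1] = 0, then follows token 1's head back to
// token 0, which already carries the current root index. That is the "loop
// detected" branch, so the document is dropped unless
// keep_malformed_documents is set.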
// Sentence filter that modifies dependency trees to make them projective. This
// could be made more efficient by looping over sentences instead of the entire
// document. Assumes that the document is well-formed in the sense of having
// no looping dependencies.
//
// Task arguments:
// bool discard_non_projective (false) : If true, discards documents with
// non-projective trees instead of projectivizing them.
class ProjectivizeFilter : public OpKernel {
public:
explicit ProjectivizeFilter(OpKernelConstruction *context)
: OpKernel(context) {
GetTaskContext(context, &task_context_);
OP_REQUIRES_OK(context, context->GetAttr("discard_non_projective",
&discard_non_projective_));
}
void Compute(OpKernelContext *context) override {
auto documents = context->input(0).vec<string>();
vector<Sentence *> output_documents;
for (int i = 0; i < documents.size(); ++i) {
Sentence *document = new Sentence;
OP_REQUIRES(context, document->ParseFromString(documents(i)),
InvalidArgument("failed to parse sentence"));
if (Process(document)) {
output_documents.push_back(document);
} else {
delete document;
}
}
OutputDocuments(context, &output_documents);
}
bool Process(Sentence *doc) {
const int num_tokens = doc->token_size();
// Left and right boundaries for arcs. The left and right ends of an arc are
// bounded by the arcs that pass over it. If an arc exceeds these bounds it
// will cross an arc passing over it, making it a non-projective arc.
vector<int> left(num_tokens);
vector<int> right(num_tokens);
// Lift the shortest non-projective arc until the document is projective.
while (true) {
// Initialize boundaries to the whole document for all arcs.
for (int i = 0; i < num_tokens; ++i) {
left[i] = -1;
right[i] = num_tokens - 1;
}
// Find left and right bounds for each token.
for (int i = 0; i < num_tokens; ++i) {
int head_index = doc->token(i).head();
// Find left and right end of arc.
int l = std::min(i, head_index);
int r = std::max(i, head_index);
// Bound all tokens under the arc.
for (int j = l + 1; j < r; ++j) {
if (left[j] < l) left[j] = l;
if (right[j] > r) right[j] = r;
}
}
// Find deepest non-projective arc.
int deepest_arc = -1;
int max_depth = -1;
// The non-projective arcs are those that exceed their bounds.
for (int i = 0; i < num_tokens; ++i) {
int head_index = doc->token(i).head();
if (head_index == -1) continue;  // Skip roots; any arc crossing a root arc is deeper and is found instead.
int l = std::min(i, head_index);
int r = std::max(i, head_index);
int left_bound = std::max(left[l], left[r]);
int right_bound = std::min(right[l], right[r]);
if (l < left_bound || r > right_bound) {
// Found non-projective arc.
if (discard_non_projective_) return false;
// Pick the deepest as the best candidate for lifting.
int depth = 0;
int j = i;
while (j != -1) {
++depth;
j = doc->token(j).head();
}
if (depth > max_depth) {
deepest_arc = i;
max_depth = depth;
}
}
}
// If there are no more non-projective arcs we are done.
if (deepest_arc == -1) return true;
// Lift non-projective arc.
int lifted_head = doc->token(doc->token(deepest_arc).head()).head();
doc->mutable_token(deepest_arc)->set_head(lifted_head);
}
}
private:
// Task context used to configure this op.
TaskContext task_context_;
// Whether or not to throw away non-projective documents.
bool discard_non_projective_;
};
REGISTER_KERNEL_BUILDER(Name("ProjectivizeFilter").Device(DEVICE_CPU),
ProjectivizeFilter);
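// Worked example of the lifting loop in Process: with four tokens and heads
// [2, -1, 1, 1], the arcs (0,2) and (1,3) cross, so both are flagged as
// non-projective. The arc into token 0 is deeper (depth 3 via 0 -> 2 -> 1)
// than the arc into token 3 (depth 2), so it is lifted: token 0's head
// becomes head(2) = 1, yielding heads [1, -1, 1, 1], which is projective.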
} // namespace syntaxnet
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "syntaxnet/document_format.h"
namespace syntaxnet {
// Component registry for document formatters.
REGISTER_CLASS_REGISTRY("document format", DocumentFormat);
} // namespace syntaxnet
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// An interface for document formats.
#ifndef SYNTAXNET_DOCUMENT_FORMAT_H_
#define SYNTAXNET_DOCUMENT_FORMAT_H_
#include <string>
#include <vector>
#include "syntaxnet/utils.h"
#include "syntaxnet/registry.h"
#include "syntaxnet/sentence.pb.h"
#include "syntaxnet/task_context.h"
#include "tensorflow/core/lib/io/inputbuffer.h"
namespace syntaxnet {
// A document format component converts a key/value pair from a record to one or
// more documents. The record format is used for selecting the document format
// component. A document format component can be registered with the
// REGISTER_DOCUMENT_FORMAT macro.
class DocumentFormat : public RegisterableClass<DocumentFormat> {
public:
DocumentFormat() {}
virtual ~DocumentFormat() {}
// Reads a record from the given input buffer with format specific logic.
// Returns false if no record could be read because we reached end of file.
virtual bool ReadRecord(tensorflow::io::InputBuffer *buffer,
string *record) = 0;
// Converts a key/value pair to one or more documents.
virtual void ConvertFromString(const string &key, const string &value,
vector<Sentence *> *documents) = 0;
// Converts a document to a key/value pair.
virtual void ConvertToString(const Sentence &document,
string *key, string *value) = 0;
private:
TF_DISALLOW_COPY_AND_ASSIGN(DocumentFormat);
};
#define REGISTER_DOCUMENT_FORMAT(type, component) \
REGISTER_CLASS_COMPONENT(DocumentFormat, type, component)
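// A hypothetical registration sketch: a concrete format implements the three
// virtual methods and registers itself under a record-format name.
//
//   class MyLineFormat : public DocumentFormat {
//     // ReadRecord, ConvertFromString, ConvertToString overrides...
//   };
//   REGISTER_DOCUMENT_FORMAT("my-line-format", MyLineFormat);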
} // namespace syntaxnet
#endif  // SYNTAXNET_DOCUMENT_FORMAT_H_
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "syntaxnet/embedding_feature_extractor.h"
#include <vector>
#include "syntaxnet/feature_extractor.h"
#include "syntaxnet/parser_features.h"
#include "syntaxnet/task_context.h"
#include "syntaxnet/utils.h"
namespace syntaxnet {
void GenericEmbeddingFeatureExtractor::Setup(TaskContext *context) {
// Don't use version to determine how to get feature FML.
const string features = context->Get(
tensorflow::strings::StrCat(ArgPrefix(), "_", "features"), "");
const string embedding_names =
context->Get(GetParamName("embedding_names"), "");
const string embedding_dims =
context->Get(GetParamName("embedding_dims"), "");
LOG(INFO) << "Features: " << features;
LOG(INFO) << "Embedding names: " << embedding_names;
LOG(INFO) << "Embedding dims: " << embedding_dims;
embedding_fml_ = utils::Split(features, ';');
add_strings_ = context->Get(GetParamName("add_varlen_strings"), false);
embedding_names_ = utils::Split(embedding_names, ';');
for (const string &dim : utils::Split(embedding_dims, ';')) {
embedding_dims_.push_back(utils::ParseUsing<int>(dim, utils::ParseInt32));
}
}
void GenericEmbeddingFeatureExtractor::Init(TaskContext *context) {
}
vector<vector<SparseFeatures>> GenericEmbeddingFeatureExtractor::ConvertExample(
const vector<FeatureVector> &feature_vectors) const {
// Extract the features.
vector<vector<SparseFeatures>> sparse_features(feature_vectors.size());
for (size_t i = 0; i < feature_vectors.size(); ++i) {
// Convert the FeatureVector to the DistBelief SparseFeatures format.
sparse_features[i] =
vector<SparseFeatures>(generic_feature_extractor(i).feature_types());
for (int j = 0; j < feature_vectors[i].size(); ++j) {
const FeatureType &feature_type = *feature_vectors[i].type(j);
const FeatureValue value = feature_vectors[i].value(j);
const bool is_continuous = feature_type.name().find("continuous") == 0;
const int64 id = is_continuous ? FloatFeatureValue(value).id : value;
const int base = feature_type.base();
if (id >= 0) {
sparse_features[i][base].add_id(id);
if (is_continuous) {
sparse_features[i][base].add_weight(FloatFeatureValue(value).weight);
}
if (add_strings_) {
sparse_features[i][base].add_description(tensorflow::strings::StrCat(
feature_type.name(), "=", feature_type.GetFeatureValueName(id)));
}
}
}
}
return sparse_features;
}
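// Note on the "continuous" branch above: FloatFeatureValue is assumed to pack
// a 32-bit predicate id together with a float weight into a single 64-bit
// FeatureValue, which is why continuous features contribute both an id and a
// weight to the SparseFeatures proto, while discrete features contribute only
// an id.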
} // namespace syntaxnet
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SYNTAXNET_EMBEDDING_FEATURE_EXTRACTOR_H_
#define SYNTAXNET_EMBEDDING_FEATURE_EXTRACTOR_H_
#include <functional>
#include <memory>
#include <string>
#include <vector>
#include "syntaxnet/utils.h"
#include "syntaxnet/feature_extractor.h"
#include "syntaxnet/feature_types.h"
#include "syntaxnet/parser_features.h"
#include "syntaxnet/sentence_features.h"
#include "syntaxnet/sparse.pb.h"
#include "syntaxnet/task_context.h"
#include "syntaxnet/workspace.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace syntaxnet {
// An EmbeddingFeatureExtractor manages the extraction of features for
// embedding-based models. It wraps a sequence of underlying classes of feature
// extractors, along with associated predicate maps. Each class of feature
// extractors is associated with a name, e.g., "words", "labels", "tags".
//
// The class is split between a generic abstract version,
// GenericEmbeddingFeatureExtractor (that can be initialized without knowing the
// signature of the ExtractFeatures method) and a typed version.
//
// The predicate maps must be initialized before use: they can be loaded using
// Read() or updated via UpdateMapsForExample.
class GenericEmbeddingFeatureExtractor {
public:
virtual ~GenericEmbeddingFeatureExtractor() {}
// Get the prefix string to put in front of all arguments, so they don't
// conflict with other embedding models.
virtual const string ArgPrefix() const = 0;
// Sets up predicate maps and embedding space names that are common for all
// embedding based feature extractors.
virtual void Setup(TaskContext *context);
virtual void Init(TaskContext *context);
// Requests workspace for the underlying feature extractors. This is
// implemented in the typed class.
virtual void RequestWorkspaces(WorkspaceRegistry *registry) = 0;
// Number of predicates for the embedding at a given index (vocabulary size).
int EmbeddingSize(int index) const {
return generic_feature_extractor(index).GetDomainSize();
}
// Returns number of embedding spaces.
int NumEmbeddings() const { return embedding_dims_.size(); }
// Returns the number of features in the embedding space.
int FeatureSize(int idx) const {
return generic_feature_extractor(idx).feature_types();
}
// Returns the dimensionality of the embedding space.
int EmbeddingDims(int index) const { return embedding_dims_[index]; }
// Accessor for embedding dims (dimensions of the embedding spaces).
const vector<int> &embedding_dims() const { return embedding_dims_; }
const vector<string> &embedding_fml() const { return embedding_fml_; }
// Get parameter name by concatenating the prefix and the original name.
string GetParamName(const string &param_name) const {
return tensorflow::strings::StrCat(ArgPrefix(), "_", param_name);
}
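// For example, with ArgPrefix() == "parser" (prefix hypothetical), Setup()
// reads its configuration from the "parser_features", "parser_embedding_names"
// and "parser_embedding_dims" task parameters.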
protected:
// Provides the generic class with access to the templated extractors. This is
// used to get the type information out of the feature extractor without
// knowing the specific calling arguments of the extractor itself.
virtual const GenericFeatureExtractor &generic_feature_extractor(
int idx) const = 0;
// Converts a vector of extracted features into SparseFeatures protos. Each
// feature in each feature vector becomes a single SparseFeatures entry,
// keyed by the base index of its feature type. No predicate mapping is
// applied.
vector<vector<SparseFeatures>> ConvertExample(
const vector<FeatureVector> &feature_vectors) const;
private:
// Embedding space names for parameter sharing.
vector<string> embedding_names_;
// FML strings for each feature extractor.
vector<string> embedding_fml_;
// Size of each of the embedding spaces (maximum predicate id).
vector<int> embedding_sizes_;
// Embedding dimensions of the embedding spaces (i.e. 32, 64 etc.)
vector<int> embedding_dims_;
// Whether or not to add string descriptions to converted examples.
bool add_strings_;
};
// Templated, object-specific implementation of the
// EmbeddingFeatureExtractor. EXTRACTOR should be a FeatureExtractor<OBJ,
// ARGS...> class that has the appropriate FeatureTraits() to ensure that
// locator type features work.
//
// Note: for backwards compatibility purposes, this always reads the FML spec
// from "<prefix>_features".
template <class EXTRACTOR, class OBJ, class... ARGS>
class EmbeddingFeatureExtractor : public GenericEmbeddingFeatureExtractor {
public:
// Sets up all predicate maps, feature extractors, and flags.
void Setup(TaskContext *context) override {
GenericEmbeddingFeatureExtractor::Setup(context);
feature_extractors_.resize(embedding_fml().size());
for (int i = 0; i < embedding_fml().size(); ++i) {
feature_extractors_[i].Parse(embedding_fml()[i]);
feature_extractors_[i].Setup(context);
}
}
// Initializes resources needed by the feature extractors.
void Init(TaskContext *context) override {
GenericEmbeddingFeatureExtractor::Init(context);
for (auto &feature_extractor : feature_extractors_) {
feature_extractor.Init(context);
}
}
// Requests workspaces from the registry. Must be called after Init(), and
// before Preprocess().
void RequestWorkspaces(WorkspaceRegistry *registry) override {
for (auto &feature_extractor : feature_extractors_) {
feature_extractor.RequestWorkspaces(registry);
}
}
// Must be called once per object (e.g., once for each sentence's state),
// before any feature extraction (e.g., UpdateMapsForExample,
// ExtractSparseFeatures).
void Preprocess(WorkspaceSet *workspaces, OBJ *obj) const {
for (auto &feature_extractor : feature_extractors_) {
feature_extractor.Preprocess(workspaces, obj);
}
}
// Returns a ragged array of SparseFeatures, for 1) each feature extractor
// class e, and 2) each feature f extracted by e. Underlying predicate maps
// will not be updated and so unrecognized predicates may occur. In such a
// case the SparseFeatures object associated with a given extractor class and
// feature will be empty.
vector<vector<SparseFeatures>> ExtractSparseFeatures(
const WorkspaceSet &workspaces, const OBJ &obj, ARGS... args) const {
vector<FeatureVector> features(feature_extractors_.size());
ExtractFeatures(workspaces, obj, args..., &features);
return ConvertExample(features);
}
// Extracts features using the extractors. Note that features must already
// be initialized to the correct number of feature extractors. No predicate
// mapping is applied.
void ExtractFeatures(const WorkspaceSet &workspaces, const OBJ &obj,
ARGS... args,
vector<FeatureVector> *features) const {
DCHECK(features != nullptr);
DCHECK_EQ(features->size(), feature_extractors_.size());
for (int i = 0; i < feature_extractors_.size(); ++i) {
(*features)[i].clear();
feature_extractors_[i].ExtractFeatures(workspaces, obj, args...,
&(*features)[i]);
}
}
protected:
// Provides generic access to the feature extractors.
const GenericFeatureExtractor &generic_feature_extractor(
int idx) const override {
DCHECK_LT(idx, feature_extractors_.size());
DCHECK_GE(idx, 0);
return feature_extractors_[idx];
}
private:
// Templated feature extractor class.
vector<EXTRACTOR> feature_extractors_;
};
class ParserEmbeddingFeatureExtractor
: public EmbeddingFeatureExtractor<ParserFeatureExtractor, ParserState> {
public:
explicit ParserEmbeddingFeatureExtractor(const string &arg_prefix)
: arg_prefix_(arg_prefix) {}
private:
const string ArgPrefix() const override { return arg_prefix_; }
// Prefix for context parameters.
string arg_prefix_;
};
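// A minimal end-to-end sketch (surrounding objects hypothetical; workspace-set
// initialization elided):
//
//   ParserEmbeddingFeatureExtractor features("parser");
//   features.Setup(&task_context);
//   features.Init(&task_context);
//   WorkspaceRegistry registry;
//   features.RequestWorkspaces(&registry);
//   ...
//   features.Preprocess(&workspaces, &state);
//   const auto sparse = features.ExtractSparseFeatures(workspaces, state);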
} // namespace syntaxnet
#endif  // SYNTAXNET_EMBEDDING_FEATURE_EXTRACTOR_H_
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "syntaxnet/feature_extractor.h"
#include "syntaxnet/feature_types.h"
#include "syntaxnet/fml_parser.h"
namespace syntaxnet {
constexpr FeatureValue GenericFeatureFunction::kNone;
GenericFeatureExtractor::GenericFeatureExtractor() {}
GenericFeatureExtractor::~GenericFeatureExtractor() {}
void GenericFeatureExtractor::Parse(const string &source) {
// Parse feature specification into descriptor.
FMLParser parser;
parser.Parse(source, mutable_descriptor());
// Initialize feature extractor from descriptor.
InitializeFeatureFunctions();
}
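// The source string is an FML spec, e.g. "input.word stack(1).tag" (feature
// names hypothetical); the grammar is documented in fml_parser.h.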
void GenericFeatureExtractor::InitializeFeatureTypes() {
// Register all feature types.
GetFeatureTypes(&feature_types_);
for (size_t i = 0; i < feature_types_.size(); ++i) {
FeatureType *ft = feature_types_[i];
ft->set_base(i);
// Check for feature space overflow.
double domain_size = ft->GetDomainSize();
if (domain_size < 0) {
LOG(FATAL) << "Illegal domain size for feature " << ft->name() << ": "
<< domain_size;
}
}
vector<string> types_names;
GetFeatureTypeNames(&types_names);
CHECK_EQ(feature_types_.size(), types_names.size());
}
void GenericFeatureExtractor::GetFeatureTypeNames(
vector<string> *type_names) const {
for (size_t i = 0; i < feature_types_.size(); ++i) {
FeatureType *ft = feature_types_[i];
type_names->push_back(ft->name());
}
}
FeatureValue GenericFeatureExtractor::GetDomainSize() const {
// The domain size of the feature extractor is taken to be the largest
// domain size of any of its feature types.
FeatureValue max_feature_type_dsize = 0;
for (size_t i = 0; i < feature_types_.size(); ++i) {
FeatureType *ft = feature_types_[i];
const FeatureValue feature_type_dsize = ft->GetDomainSize();
if (feature_type_dsize > max_feature_type_dsize) {
max_feature_type_dsize = feature_type_dsize;
}
}
return max_feature_type_dsize;
}
string GenericFeatureFunction::GetParameter(const string &name) const {
// Find named parameter in feature descriptor.
for (int i = 0; i < descriptor_->parameter_size(); ++i) {
if (name == descriptor_->parameter(i).name()) {
return descriptor_->parameter(i).value();
}
}
return "";
}
GenericFeatureFunction::GenericFeatureFunction() {}
GenericFeatureFunction::~GenericFeatureFunction() {
delete feature_type_;
}
int GenericFeatureFunction::GetIntParameter(const string &name,
int default_value) const {
string value = GetParameter(name);
return utils::ParseUsing<int>(value, default_value,
tensorflow::strings::safe_strto32);
}
void GenericFeatureFunction::GetFeatureTypes(
vector<FeatureType *> *types) const {
if (feature_type_ != nullptr) types->push_back(feature_type_);
}
FeatureType *GenericFeatureFunction::GetFeatureType() const {
// If a single feature type has been registered return it.
if (feature_type_ != nullptr) return feature_type_;
// Get feature types for function.
vector<FeatureType *> types;
GetFeatureTypes(&types);
// If there is exactly one feature type return this, else return null.
if (types.size() == 1) return types[0];
return nullptr;
}
} // namespace syntaxnet
// Protocol buffers for feature extractor.
syntax = "proto2";
package syntaxnet;
message Parameter {
optional string name = 1;
optional string value = 2;
}
// Descriptor for feature function.
message FeatureFunctionDescriptor {
// Feature function type.
required string type = 1;
// Feature function name.
optional string name = 2;
// Default argument for feature function.
optional int32 argument = 3 [default = 0];
// Named parameters for feature descriptor.
repeated Parameter parameter = 4;
// Nested sub-feature function descriptors.
repeated FeatureFunctionDescriptor feature = 7;
};
// Descriptor for feature extractor.
message FeatureExtractorDescriptor {
// Top-level feature function for extractor.
repeated FeatureFunctionDescriptor feature = 1;
};
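// As an illustration, an FML feature such as "stack(1).tag" (names
// hypothetical) could serialize to the following text-format descriptor:
//
//   feature {
//     type: "stack"
//     argument: 1
//     feature { type: "tag" }
//   }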
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Common feature types for parser components.
#ifndef SYNTAXNET_FEATURE_TYPES_H_
#define SYNTAXNET_FEATURE_TYPES_H_
#include <algorithm>
#include <map>
#include <string>
#include <utility>
#include "syntaxnet/utils.h"
namespace syntaxnet {
// Use the same type for feature values as is used for predicates.
typedef int64 Predicate;
typedef Predicate FeatureValue;
// Each feature value in a feature vector has a feature type. The feature type
// is used for converting feature type and value pairs to predicate values. The
// feature type can also return names for feature values and calculate the size
// of the feature value domain. The FeatureType class is abstract and must be
// specialized for the concrete feature types.
class FeatureType {
public:
// Initializes a feature type.
explicit FeatureType(const string &name)
: name_(name), base_(0) {}
virtual ~FeatureType() {}
// Converts a feature value to a name.
virtual string GetFeatureValueName(FeatureValue value) const = 0;
// Returns the size of the feature values domain.
virtual int64 GetDomainSize() const = 0;
// Returns the feature type name.
const string &name() const { return name_; }
Predicate base() const { return base_; }
void set_base(Predicate base) { base_ = base; }
private:
// Feature type name.
string name_;
// "Base" feature value: i.e. a "slot" in a global ordering of features.
Predicate base_;
};
// Templated generic resource based feature type. This feature type delegates
// look up of feature value names to an unknown resource class, which is not
// owned. Optionally, this type can also store a mapping of extra values which
// are not in the resource.
//
// Note: this class assumes that Resource->GetFeatureValueName() will return
// successfully ONLY for values in the range [0, Resource->NumValues()). Any
// feature value not in the extra value map and not in the above range will
// result in an ERROR and a return value of "<INVALID>".
template<class Resource>
class ResourceBasedFeatureType : public FeatureType {
public:
// Creates a new type with the given name, resource object, and a mapping of
// special values. The keys of the mapping must be greater than or equal to
// resource->NumValues() so as to avoid collisions; this is verified with a
// CHECK at creation.
ResourceBasedFeatureType(const string &name, const Resource *resource,
const map<FeatureValue, string> &values)
: FeatureType(name), resource_(resource), values_(values) {
max_value_ = resource->NumValues() - 1;
for (const auto &pair : values) {
CHECK_GE(pair.first, resource->NumValues()) << "Invalid extra value: "
<< pair.first << "," << pair.second;
max_value_ = pair.first > max_value_ ? pair.first : max_value_;
}
}
// Creates a new type with no special values.
ResourceBasedFeatureType(const string &name, const Resource *resource)
: ResourceBasedFeatureType(name, resource, {}) {}
// Returns the feature name for a given feature value. First checks the values
// map, then checks the resource to look up the name.
string GetFeatureValueName(FeatureValue value) const override {
if (values_.find(value) != values_.end()) {
return values_.find(value)->second;
}
if (value >= 0 && value < resource_->NumValues()) {
return resource_->GetFeatureValueName(value);
} else {
LOG(ERROR) << "Invalid feature value " << value << " for " << name();
return "<INVALID>";
}
}
// Returns the number of possible values for this feature type. This is
// based on the largest value observed in the resource or the extra values.
FeatureValue GetDomainSize() const override { return max_value_ + 1; }
protected:
// Shared resource. Not owned.
const Resource *resource_ = nullptr;
// Maximum possible value this feature could take.
FeatureValue max_value_;
// Mapping for extra feature values not in the resource.
map<FeatureValue, string> values_;
};
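// Example usage with a hypothetical TermResource whose own values occupy
// [0, NumValues()); extra values must start at NumValues() to avoid
// collisions:
//
//   ResourceBasedFeatureType<TermResource> type(
//       "words", &term_resource,
//       {{term_resource.NumValues(), "<UNKNOWN>"}});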
// Feature type that is defined using an explicit map from FeatureValue to
// string values. This can reduce some of the boilerplate when defining
// features that generate enum values. Example usage:
//
// class BeverageSizeFeature : public FeatureFunction<Beverage> {
// enum FeatureValue { SMALL, MEDIUM, LARGE }; // values for this feature
// void Init(TaskContext *context) override {
// set_feature_type(new EnumFeatureType("beverage_size",
// {{SMALL, "SMALL"}, {MEDIUM, "MEDIUM"}, {LARGE, "LARGE"}}));
// }
// [...]
// };
class EnumFeatureType : public FeatureType {
public:
EnumFeatureType(const string &name,
const map<FeatureValue, string> &value_names)
: FeatureType(name), value_names_(value_names) {
for (const auto &pair : value_names) {
CHECK_GE(pair.first, 0)
<< "Invalid feature value: " << pair.first << ", " << pair.second;
domain_size_ = std::max(domain_size_, pair.first + 1);
}
}
// Returns the feature name for a given feature value.
string GetFeatureValueName(FeatureValue value) const override {
auto it = value_names_.find(value);
if (it == value_names_.end()) {
LOG(ERROR)
<< "Invalid feature value " << value << " for " << name();
return "<INVALID>";
}
return it->second;
}
// Returns the number of possible values for this feature type. This is one
// greater than the largest value in the value_names map.
FeatureValue GetDomainSize() const override { return domain_size_; }
protected:
// Maximum possible value this feature could take.
FeatureValue domain_size_ = 0;
// Names of feature values.
map<FeatureValue, string> value_names_;
};
} // namespace syntaxnet
#endif  // SYNTAXNET_FEATURE_TYPES_H_
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Feature modeling language (FML) parser.
//
// BNF grammar for fml:
//
// <feature model> ::= { <feature extractor> }
//
// <feature extractor> ::= <extractor spec> |
// <extractor spec> '.' <feature extractor> |
// <extractor spec> '{' { <feature extractor> } '}'
//
// <extractor spec> ::= <extractor type>
// [ '(' <parameter list> ')' ]
// [ ':' <extractor name> ]
//
// <parameter list> ::= ( <parameter> | <argument> ) { ',' <parameter> }
//
// <parameter> ::= <parameter name> '=' <parameter value>
//
// <extractor type> ::= NAME
// <extractor name> ::= NAME | STRING
// <argument> ::= NUMBER
// <parameter name> ::= NAME
// <parameter value> ::= NUMBER | STRING | NAME
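// For example (feature and parameter names hypothetical), the spec
//
//   input.token.word stack(1).tag label(min_freq=5):deplabel
//
// defines three top-level extractors: a dotted chain of nested feature
// functions, an extractor with a default argument, and an extractor with a
// named parameter and an explicit name.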
#ifndef SYNTAXNET_FML_PARSER_H_
#define SYNTAXNET_FML_PARSER_H_
#include <string>
#include "syntaxnet/utils.h"
#include "syntaxnet/feature_extractor.pb.h"
namespace syntaxnet {
class FMLParser {
public:
// Parses an FML specification into a feature extractor descriptor.
void Parse(const string &source, FeatureExtractorDescriptor *result);
private:
// Initializes the parser with the source text.
void Initialize(const string &source);
// Outputs error message and exits.
void Error(const string &error_message);
// Moves to the next input character.
void Next();
// Moves to the next input item.
void NextItem();
// Parses a feature descriptor.
void ParseFeature(FeatureFunctionDescriptor *result);
// Parses a parameter specification.
void ParseParameter(FeatureFunctionDescriptor *result);
// Returns true if end of source input has been reached.
bool eos() { return current_ == source_.end(); }
// Item types.
enum ItemTypes {
END = 0,
NAME = -1,
NUMBER = -2,
STRING = -3,
};
// Source text.
string source_;
// Current input position.
string::iterator current_;
// Line number for current input position.
int line_number_;
// Start position for current item.
string::iterator item_start_;
// Start position for current line.
string::iterator line_start_;
// Line number for current item.
int item_line_number_;
// Item type for current item. If this is positive it is interpreted as a
// character. If it is negative it is interpreted as an item type.
int item_type_;
// Text for current item.
string item_text_;
};
} // namespace syntaxnet
#endif  // SYNTAXNET_FML_PARSER_H_
// K-best part-of-speech and dependency annotations for tokens.
syntax = "proto2";
import "syntaxnet/sentence.proto";
package syntaxnet;
// A list of alternative (k-best) syntax analyses, grouped by sentences.
message KBestSyntaxAnalyses {
extend Sentence {
optional KBestSyntaxAnalyses extension = 60366242;
}
// Alternative analyses for each sentence. Sentences are listed in the
// order visited by a SentenceIterator.
repeated KBestSyntaxAnalysesForSentence sentence = 1;
// Alternative analyses for each token.
repeated KBestSyntaxAnalysesForToken token = 2;
}
// A list of alternative (k-best) analyses for a sentence spanning from a start
// token index to an end token index. The alternative analyses are ordered by
// decreasing model score from best to worst. The first analysis is the 1-best
// analysis, which is typically also stored in the document tokens.
message KBestSyntaxAnalysesForSentence {
// First token of sentence.
optional int32 start = 1 [default = -1];
// Last token of sentence.
optional int32 end = 2 [default = -1];
// K-best analyses for the tokens in this sentence. All of the analyses in
// the list have the same "type"; e.g., k-best taggings,
// k-best {tagging+parse}s, etc.
// Note also that the type of analysis stored in this list can change
// depending on where we are in the document processing pipeline; e.g., it
// may initially hold taggings, and then switch to parses. The first
// token_analysis would be the 1-best analysis, which is typically also stored
// in the document. Note: some post-processors will update the document's
// syntax trees, but will leave these unchanged.
repeated AlternativeTokenAnalysis token_analysis = 3;
}
// A list of scored alternative (k-best) analyses for a particular token. These
// are all distinct from each other and ordered by decreasing model score. The
// first is the 1-best analysis, which may or may not match the document tokens
// depending on how the k-best analyses are selected.
message KBestSyntaxAnalysesForToken {
// All token analyses in this repeated field refer to the same token.
// Each alternative analysis will contain a single entry for repeated fields
// such as head, tag, category and label.
repeated AlternativeTokenAnalysis token_analysis = 3;
}
// An alternative analysis of tokens in the document. The repeated fields
// are indexed relative to the beginning of a sentence. Fields not
// represented in the alternative analysis are assumed to be unchanged.
// Currently only alternatives for tags, categories and (labeled) dependency
// heads are supported.
// Each repeated field should either have length=0 or length=number of tokens.
message AlternativeTokenAnalysis {
// Head of this token in the dependency tree: the id of the token which has
// an arc going to this one. If it is the root token of a sentence, then it
// is set to -1.
repeated int32 head = 1;
// Part-of-speech tag for token.
repeated string tag = 2;
// Coarse-grained word category for token.
repeated string category = 3;
// Label for dependency relation between this token and its head.
repeated string label = 4;
// The score of this analysis, where bigger values typically indicate better
// quality, but there are no guarantees and there is also no pre-defined
// range.
optional double score = 5;
}
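// A C++ sketch of reading these analyses off a Sentence, using the standard
// proto2 extension accessors:
//
//   const KBestSyntaxAnalyses &kbest =
//       sentence.GetExtension(KBestSyntaxAnalyses::extension);
//   if (kbest.sentence_size() > 0 &&
//       kbest.sentence(0).token_analysis_size() > 0) {
//     const AlternativeTokenAnalysis &best =
//         kbest.sentence(0).token_analysis(0);
//   }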
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loads parser_ops shared library."""
import os.path
import tensorflow as tf
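# Loading the shared library registers the SyntaxNet ops and kernels with the
# TensorFlow runtime as a side effect; the module object returned by
# tf.load_op_library is not needed here, so it is discarded.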
tf.load_op_library(
os.path.join(tf.resource_loader.get_data_files_path(),
'parser_ops.so'))