Commit 32ab5a58 authored by calberti, committed by Martin Wicke

Adding SyntaxNet to tensorflow/models (#63)

parent 148a15fb
46
punct 243160
prep 194627
pobj 186958
det 170592
nsubj 144821
nn 144800
amod 117242
ROOT 90592
dobj 88551
aux 76523
advmod 72893
conj 59384
cc 57532
num 36350
poss 35117
dep 34986
ccomp 29470
cop 25991
mark 25141
xcomp 25111
rcmod 16234
auxpass 15740
advcl 14996
possessive 14866
nsubjpass 14133
pcomp 12488
appos 11112
partmod 11106
neg 11090
number 10658
prt 7123
quantmod 6653
tmod 5418
infmod 5134
npadvmod 3213
parataxis 3012
mwe 2793
expl 2712
iobj 1642
acomp 1632
discourse 1381
csubj 1225
predet 1160
preconj 749
goeswith 146
csubjpass 41
49
NN 285194
IN 228165
DT 179147
NNP 175147
JJ 125667
NNS 115732
, 97481
. 85938
RB 78513
VB 63952
CC 57554
VBD 56635
CD 55674
PRP 55244
VBZ 48126
VBN 44458
VBG 34524
VBP 33669
TO 28772
MD 22364
PRP$ 20706
HYPH 18526
POS 14905
`` 12193
'' 12154
WDT 10267
: 8713
$ 7993
WP 7336
RP 7335
WRB 6634
JJR 6295
NNPS 5917
-RRB- 3904
-LRB- 3840
JJS 3596
RBR 3186
EX 2733
UH 1521
RBS 1467
PDT 1271
FW 928
NFP 844
SYM 652
ADD 476
LS 392
WP$ 332
GW 184
AFX 42
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op.h"
namespace syntaxnet {
// -----------------------------------------------------------------------------
REGISTER_OP("GoldParseReader")
.Output("features: feature_size * string")
.Output("num_epochs: int32")
.Output("gold_actions: int32")
.Attr("task_context: string")
.Attr("feature_size: int")
.Attr("batch_size: int")
.Attr("corpus_name: string='documents'")
.Attr("arg_prefix: string='brain_parser'")
.SetIsStateful()
.Doc(R"doc(
Reads sentences, parses them, and returns (gold action, feature) pairs.
features: features firing at the current parser state, encoded as
dist_belief.SparseFeatures protocol buffers.
num_epochs: number of times this reader went over the training corpus.
gold_actions: action to perform at the current parser state.
task_context: file path at which to read the task context.
feature_size: number of feature outputs emitted by this reader.
batch_size: number of sentences to parse at a time.
corpus_name: name of task input in the task context to read parses from.
arg_prefix: prefix for context parameters.
)doc");
REGISTER_OP("DecodedParseReader")
.Input("transition_scores: float")
.Output("features: feature_size * string")
.Output("num_epochs: int32")
.Output("eval_metrics: int32")
.Output("documents: string")
.Attr("task_context: string")
.Attr("feature_size: int")
.Attr("batch_size: int")
.Attr("corpus_name: string='documents'")
.Attr("arg_prefix: string='brain_parser'")
.SetIsStateful()
.Doc(R"doc(
Reads sentences and parses them taking parsing transitions based on the
input transition scores.
transition_scores: scores for every transition from the current parser state.
features: features firing at the current parser state encoded as
dist_belief.SparseFeatures protocol buffers.
num_epochs: number of times this reader went over the training corpus.
eval_metrics: token counts used to compute evaluation metrics.
task_context: file path at which to read the task context.
feature_size: number of feature outputs emitted by this reader.
batch_size: number of sentences to parse at a time.
corpus_name: name of task input in the task context to read parses from.
arg_prefix: prefix for context parameters.
)doc");
REGISTER_OP("BeamParseReader")
.Output("features: feature_size * string")
.Output("beam_state: int64")
.Output("num_epochs: int32")
.Attr("task_context: string")
.Attr("feature_size: int")
.Attr("beam_size: int")
.Attr("batch_size: int=1")
.Attr("corpus_name: string='documents'")
.Attr("allow_feature_weights: bool=true")
.Attr("arg_prefix: string='brain_parser'")
.Attr("continue_until_all_final: bool=false")
.Attr("always_start_new_sentences: bool=false")
.SetIsStateful()
.Doc(R"doc(
Reads sentences and creates a beam parser.
features: features firing at the initial parser state encoded as
dist_belief.SparseFeatures protocol buffers.
beam_state: beam state handle.
task_context: file path at which to read the task context.
feature_size: number of feature outputs emitted by this reader.
beam_size: limit on the beam size.
corpus_name: name of task input in the task context to read parses from.
allow_feature_weights: whether the op is expected to output weighted features.
If false, it will check that no weights are specified.
arg_prefix: prefix for context parameters.
continue_until_all_final: whether to continue parsing after the gold path falls
off the beam.
always_start_new_sentences: whether to skip to the beginning of a new sentence
after each training step.
)doc");
REGISTER_OP("BeamParser")
.Input("beam_state: int64")
.Input("transition_scores: float")
.Output("features: feature_size * string")
.Output("next_beam_state: int64")
.Output("alive: bool")
.Attr("feature_size: int")
.SetIsStateful()
.Doc(R"doc(
Updates the beam parser based on scores in the input transition scores.
beam_state: beam state.
transition_scores: scores for every transition from the current parser state.
features: features firing at the current parser state encoded as
dist_belief.SparseFeatures protocol buffers.
next_beam_state: beam state handle.
alive: whether the gold state is still in the beam.
feature_size: number of feature outputs emitted by this reader.
)doc");
REGISTER_OP("BeamParserOutput")
.Input("beam_state: int64")
.Output("indices_and_paths: int32")
.Output("batches_and_slots: int32")
.Output("gold_slot: int32")
.Output("path_scores: float")
.SetIsStateful()
.Doc(R"doc(
Converts the current state of the beam parser into a set of indices into
the scoring matrices that lead there.
beam_state: beam state handle.
indices_and_paths: matrix whose first row is a vector to look up beam paths and
decisions with, and whose second row contains the corresponding path ids.
batches_and_slots: matrix whose first row is a vector identifying the batch to
which the paths correspond, and whose second row contains the slots.
gold_slot: location in final beam of the gold path [batch_size].
path_scores: cumulative sum of scores along each path in each beam. Within each
beam, scores are sorted from low to high.
)doc");
REGISTER_OP("BeamEvalOutput")
.Input("beam_state: int64")
.Output("eval_metrics: int32")
.Output("documents: string")
.SetIsStateful()
.Doc(R"doc(
Computes eval metrics for the best paths in the input beams.
beam_state: beam state handle.
eval_metrics: token counts used to compute evaluation metrics.
documents: parsed documents.
)doc");
REGISTER_OP("LexiconBuilder")
.Attr("task_context: string")
.Attr("corpus_name: string='documents'")
.Attr("lexicon_max_prefix_length: int = 3")
.Attr("lexicon_max_suffix_length: int = 3")
.Doc(R"doc(
An op that collects term statistics over a corpus and saves a set of term maps.
task_context: file path at which to read the task context.
corpus_name: name of the context input to compute lexicons.
lexicon_max_prefix_length: maximum prefix length for lexicon words.
lexicon_max_suffix_length: maximum suffix length for lexicon words.
)doc");
REGISTER_OP("FeatureSize")
.Attr("task_context: string")
.Output("feature_sizes: int32")
.Output("domain_sizes: int32")
.Output("embedding_dims: int32")
.Output("num_actions: int32")
.Attr("arg_prefix: string='brain_parser'")
.Doc(R"doc(
An op that returns the number and domain sizes of parser features.
task_context: file path at which to read the task context.
feature_sizes: number of feature locators in each group of parser features.
domain_sizes: domain size for each feature group of parser features.
embedding_dims: embedding dimension for each feature group of parser features.
num_actions: number of actions a parser can perform.
arg_prefix: prefix for context parameters.
)doc");
REGISTER_OP("UnpackSparseFeatures")
.Input("sf: string")
.Output("indices: int32")
.Output("ids: int64")
.Output("weights: float")
.Doc(R"doc(
Converts a vector of strings with SparseFeatures to tensors.
Note that indices, ids and weights are vectors of the same size and have
one-to-one correspondence between their elements. ids and weights are each
obtained by sequentially concatenating sf[i].id and sf[i].weight, for i in
1...size(sf). Note that if sf[i].weight is not set, the default value for the
weight is assumed to be 1.0. Also for any j, if ids[j] and weights[j] were
extracted from sf[i], then index[j] is set to i.
sf: vector of strings, where each element is the string encoding of a
SparseFeatures proto.
indices: vector of indices inside sf.
ids: vector of ids extracted from the SparseFeatures protos.
weights: vector of weights extracted from the SparseFeatures protos.
)doc");
REGISTER_OP("WordEmbeddingInitializer")
.Output("word_embeddings: float")
.Attr("vectors: string")
.Attr("task_context: string")
.Attr("embedding_init: float = 1.0")
.Doc(R"doc(
Reads word embeddings from an sstable of dist_belief.TokenEmbedding protos for
every word specified in a text vocabulary file.
word_embeddings: a tensor containing word embeddings from the specified sstable.
vectors: path to recordio of word embedding vectors.
task_context: file path at which to read the task context.
)doc");
REGISTER_OP("DocumentSource")
.Output("documents: string")
.Output("last: bool")
.Attr("task_context: string")
.Attr("corpus_name: string='documents'")
.Attr("batch_size: int")
.SetIsStateful()
.Doc(R"doc(
Reads documents from documents_path and outputs them.
documents: a vector of documents as serialized protos.
last: whether this is the last batch of documents from this document path.
batch_size: how many documents to read at once.
)doc");
REGISTER_OP("DocumentSink")
.Input("documents: string")
.Attr("task_context: string")
.Attr("corpus_name: string='documents'")
.Doc(R"doc(
Write documents to documents_path.
documents: documents to write.
)doc");
REGISTER_OP("WellFormedFilter")
.Input("documents: string")
.Output("filtered: string")
.Attr("task_context: string")
.Attr("corpus_name: string='documents'")
.Attr("keep_malformed_documents: bool = False")
.Doc(R"doc(
)doc");
REGISTER_OP("ProjectivizeFilter")
.Input("documents: string")
.Output("filtered: string")
.Attr("task_context: string")
.Attr("corpus_name: string='documents'")
.Attr("discard_non_projective: bool = False")
.Doc(R"doc(
)doc");
} // namespace syntaxnet
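The UnpackSparseFeatures doc above describes three parallel output vectors with a default weight of 1.0. Below is a minimal pure-Python sketch of that flattening contract; the dict-based records are illustrative stand-ins for the serialized dist_belief.SparseFeatures protos, not the real proto API.

def unpack_sparse_features(sf):
  """Flattens feature records into parallel indices/ids/weights vectors."""
  indices, ids, weights = [], [], []
  for i, record in enumerate(sf):
    feature_ids = record['id']
    # A missing weight defaults to 1.0 for every id, as the doc above states.
    feature_weights = record.get('weight', [1.0] * len(feature_ids))
    for fid, w in zip(feature_ids, feature_weights):
      indices.append(i)  # index[j] records which input element j came from.
      ids.append(fid)
      weights.append(w)
  return indices, ids, weights

print(unpack_sparse_features([{'id': [3, 7], 'weight': [0.5, 2.0]},
                              {'id': [1]}]))
# -> ([0, 0, 1], [3, 7, 1], [0.5, 2.0, 1.0])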
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A program to annotate a conll file with a tensorflow neural net parser."""
import os
import os.path
import time
import tensorflow as tf
from tensorflow.python.platform import gfile
from tensorflow.python.platform import logging
from syntaxnet import sentence_pb2
from syntaxnet import graph_builder
from syntaxnet import structured_graph_builder
from syntaxnet.ops import gen_parser_ops
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('task_context', '',
                    'Path to a task context with inputs and parameters for '
                    'feature extractors.')
flags.DEFINE_string('model_path', '', 'Path to model parameters.')
flags.DEFINE_string('arg_prefix', None, 'Prefix for context parameters.')
flags.DEFINE_string('graph_builder', 'greedy',
                    'Which graph builder to use, either greedy or structured.')
flags.DEFINE_string('input', 'stdin',
                    'Name of the context input to read data from.')
flags.DEFINE_string('output', 'stdout',
                    'Name of the context input to write data to.')
flags.DEFINE_string('hidden_layer_sizes', '200,200',
                    'Comma separated list of hidden layer sizes.')
flags.DEFINE_integer('batch_size', 32,
                     'Number of sentences to process in parallel.')
flags.DEFINE_integer('beam_size', 8, 'Number of slots for beam parsing.')
flags.DEFINE_integer('max_steps', 1000, 'Max number of steps to take.')
flags.DEFINE_bool('slim_model', False,
                  'Whether to expect only averaged variables.')
def Eval(sess, num_actions, feature_sizes, domain_sizes, embedding_dims):
  """Builds and evaluates a network.

  Args:
    sess: tensorflow session to use
    num_actions: number of possible golden actions
    feature_sizes: size of each feature vector
    domain_sizes: number of possible feature ids in each feature vector
    embedding_dims: embedding dimension for each feature group
  """
  t = time.time()
  hidden_layer_sizes = map(int, FLAGS.hidden_layer_sizes.split(','))
  logging.info('Building training network with parameters: feature_sizes: %s '
               'domain_sizes: %s', feature_sizes, domain_sizes)
  if FLAGS.graph_builder == 'greedy':
    parser = graph_builder.GreedyParser(num_actions,
                                        feature_sizes,
                                        domain_sizes,
                                        embedding_dims,
                                        hidden_layer_sizes,
                                        gate_gradients=True,
                                        arg_prefix=FLAGS.arg_prefix)
  else:
    parser = structured_graph_builder.StructuredGraphBuilder(
        num_actions,
        feature_sizes,
        domain_sizes,
        embedding_dims,
        hidden_layer_sizes,
        gate_gradients=True,
        arg_prefix=FLAGS.arg_prefix,
        beam_size=FLAGS.beam_size,
        max_steps=FLAGS.max_steps)
  task_context = FLAGS.task_context
  parser.AddEvaluation(task_context,
                       FLAGS.batch_size,
                       corpus_name=FLAGS.input,
                       evaluation_max_steps=FLAGS.max_steps)
  parser.AddSaver(FLAGS.slim_model)
  sess.run(parser.inits.values())
  parser.saver.restore(sess, FLAGS.model_path)
  sink_documents = tf.placeholder(tf.string)
  sink = gen_parser_ops.document_sink(sink_documents,
                                      task_context=FLAGS.task_context,
                                      corpus_name=FLAGS.output)
  t = time.time()
  num_epochs = None
  num_tokens = 0
  num_correct = 0
  num_documents = 0
  while True:
    tf_eval_epochs, tf_eval_metrics, tf_documents = sess.run([
        parser.evaluation['epochs'],
        parser.evaluation['eval_metrics'],
        parser.evaluation['documents'],
    ])
    if len(tf_documents):
      logging.info('Processed %d documents', len(tf_documents))
      num_documents += len(tf_documents)
      sess.run(sink, feed_dict={sink_documents: tf_documents})
    num_tokens += tf_eval_metrics[0]
    num_correct += tf_eval_metrics[1]
    if num_epochs is None:
      num_epochs = tf_eval_epochs
    elif num_epochs < tf_eval_epochs:
      break
  logging.info('Total processed documents: %d', num_documents)
  if num_tokens > 0:
    eval_metric = 100.0 * num_correct / num_tokens
    logging.info('num correct tokens: %d', num_correct)
    logging.info('total tokens: %d', num_tokens)
    logging.info('Seconds elapsed in evaluation: %.2f, '
                 'eval metric: %.2f%%', time.time() - t, eval_metric)
def main(unused_argv):
  logging.set_verbosity(logging.INFO)
  with tf.Session() as sess:
    feature_sizes, domain_sizes, embedding_dims, num_actions = sess.run(
        gen_parser_ops.feature_size(task_context=FLAGS.task_context,
                                    arg_prefix=FLAGS.arg_prefix))
  with tf.Session() as sess:
    Eval(sess, num_actions, feature_sizes, domain_sizes, embedding_dims)


if __name__ == '__main__':
  tf.app.run()
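The evaluation loop in Eval above stops on the first batch whose epoch counter exceeds the first one observed, accumulating token counts along the way. A stripped-down sketch of that termination pattern, with a hypothetical next_batch() standing in for the sess.run call:

def drain_one_epoch(next_batch):
  """Accumulates (epoch, tokens, correct) batches for exactly one epoch."""
  num_tokens = num_correct = 0
  start_epoch = None
  while True:
    epoch, tokens, correct = next_batch()
    num_tokens += tokens
    num_correct += correct
    if start_epoch is None:
      start_epoch = epoch
    elif epoch > start_epoch:
      break  # The reader wrapped around to a new epoch; stop.
  return 100.0 * num_correct / max(num_tokens, 1)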
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "syntaxnet/parser_features.h"
#include <string>
#include "syntaxnet/registry.h"
#include "syntaxnet/sentence_features.h"
#include "syntaxnet/workspace.h"
namespace syntaxnet {
// Registry for the parser feature functions.
REGISTER_CLASS_REGISTRY("parser feature function", ParserFeatureFunction);
// Registry for the parser state + token index feature functions.
REGISTER_CLASS_REGISTRY("parser+index feature function",
ParserIndexFeatureFunction);
RootFeatureType::RootFeatureType(const string &name,
const FeatureType &wrapped_type,
int root_value)
: FeatureType(name), wrapped_type_(wrapped_type), root_value_(root_value) {}
string RootFeatureType::GetFeatureValueName(FeatureValue value) const {
if (value == root_value_) return "<ROOT>";
return wrapped_type_.GetFeatureValueName(value);
}
FeatureValue RootFeatureType::GetDomainSize() const {
return wrapped_type_.GetDomainSize() + 1;
}
// Parser feature locator for accessing the remaining input tokens in the parser
// state. It takes the offset relative to the current input token as argument.
// Negative values represent tokens to the left, positive values to the right
// and 0 (the default argument value) represents the current input token.
class InputParserLocator : public ParserLocator<InputParserLocator> {
public:
// Gets the new focus.
int GetFocus(const WorkspaceSet &workspaces, const ParserState &state) const {
const int offset = argument();
return state.Input(offset);
}
};
REGISTER_PARSER_FEATURE_FUNCTION("input", InputParserLocator);
// Parser feature locator for accessing the stack in the parser state. The
// argument represents the position on the stack, 0 being the top of the stack.
class StackParserLocator : public ParserLocator<StackParserLocator> {
public:
// Gets the new focus.
int GetFocus(const WorkspaceSet &workspaces, const ParserState &state) const {
const int position = argument();
return state.Stack(position);
}
};
REGISTER_PARSER_FEATURE_FUNCTION("stack", StackParserLocator);
// Parser feature locator for locating the head of the focus token. The argument
// specifies the number of times the head function is applied. Please note that
// this operates on partially built dependency trees.
class HeadFeatureLocator : public ParserIndexLocator<HeadFeatureLocator> {
public:
// Updates the current focus to a new location. If the initial focus is
// outside the range of the sentence, returns -2.
void UpdateArgs(const WorkspaceSet &workspaces, const ParserState &state,
int *focus) const {
if (*focus < -1 || *focus >= state.sentence().token_size()) {
*focus = -2;
return;
}
const int levels = argument();
*focus = state.Parent(*focus, levels);
}
};
REGISTER_PARSER_IDX_FEATURE_FUNCTION("head", HeadFeatureLocator);
// Parser feature locator for locating children of the focus token. The argument
// specifies the number of times the leftmost (when the argument is < 0) or
// rightmost (when the argument > 0) child function is applied. Please note that
// this operates on partially built dependency trees.
class ChildFeatureLocator : public ParserIndexLocator<ChildFeatureLocator> {
public:
// Updates the current focus to a new location. If the initial focus is
// outside the range of the sentence, returns -2.
void UpdateArgs(const WorkspaceSet &workspaces, const ParserState &state,
int *focus) const {
if (*focus < -1 || *focus >= state.sentence().token_size()) {
*focus = -2;
return;
}
const int levels = argument();
if (levels < 0) {
*focus = state.LeftmostChild(*focus, -levels);
} else {
*focus = state.RightmostChild(*focus, levels);
}
}
};
REGISTER_PARSER_IDX_FEATURE_FUNCTION("child", ChildFeatureLocator);
// Parser feature locator for locating siblings of the focus token. The argument
// specifies the sibling position relative to the focus token: a negative value
// triggers a search to the left, while a positive value one to the right.
// Please note that this operates on partially built dependency trees.
class SiblingFeatureLocator
: public ParserIndexLocator<SiblingFeatureLocator> {
public:
// Updates the current focus to a new location. If the initial focus is
// outside the range of the sentence, returns -2.
void UpdateArgs(const WorkspaceSet &workspaces, const ParserState &state,
int *focus) const {
if (*focus < -1 || *focus >= state.sentence().token_size()) {
*focus = -2;
return;
}
const int position = argument();
if (position < 0) {
*focus = state.LeftSibling(*focus, -position);
} else {
*focus = state.RightSibling(*focus, position);
}
}
};
REGISTER_PARSER_IDX_FEATURE_FUNCTION("sibling", SiblingFeatureLocator);
// Feature function for computing the label of the focus token. Note that this
// does not use the precomputed values, since we get the labels from the stack;
// the reason it utilizes sentence_features::Label is to obtain the label map.
class LabelFeatureFunction : public BasicParserSentenceFeatureFunction<Label> {
public:
// Computes the label of the relation between the focus token and its parent.
// Valid focus values range from -1 to sentence->size() - 1, inclusively.
FeatureValue Compute(const WorkspaceSet &workspaces, const ParserState &state,
int focus, const FeatureVector *result) const override {
if (focus == -1) return RootValue();
if (focus < -1 || focus >= state.sentence().token_size()) {
return feature_.NumValues();
}
const int label = state.Label(focus);
return label == -1 ? RootValue() : label;
}
};
REGISTER_PARSER_IDX_FEATURE_FUNCTION("label", LabelFeatureFunction);
typedef BasicParserSentenceFeatureFunction<Word> WordFeatureFunction;
REGISTER_PARSER_IDX_FEATURE_FUNCTION("word", WordFeatureFunction);
typedef BasicParserSentenceFeatureFunction<Tag> TagFeatureFunction;
REGISTER_PARSER_IDX_FEATURE_FUNCTION("tag", TagFeatureFunction);
typedef BasicParserSentenceFeatureFunction<Digit> DigitFeatureFunction;
REGISTER_PARSER_IDX_FEATURE_FUNCTION("digit", DigitFeatureFunction);
typedef BasicParserSentenceFeatureFunction<Hyphen> HyphenFeatureFunction;
REGISTER_PARSER_IDX_FEATURE_FUNCTION("hyphen", HyphenFeatureFunction);
typedef BasicParserSentenceFeatureFunction<PrefixFeature> PrefixFeatureFunction;
REGISTER_PARSER_IDX_FEATURE_FUNCTION("prefix", PrefixFeatureFunction);
typedef BasicParserSentenceFeatureFunction<SuffixFeature> SuffixFeatureFunction;
REGISTER_PARSER_IDX_FEATURE_FUNCTION("suffix", SuffixFeatureFunction);
// Parser feature function that can use nested sentence feature functions for
// feature extraction.
class ParserTokenFeatureFunction : public NestedFeatureFunction<
FeatureFunction<Sentence, int>, ParserState, int> {
public:
void Preprocess(WorkspaceSet *workspaces, ParserState *state) const override {
for (auto *function : nested_) {
function->Preprocess(workspaces, state->mutable_sentence());
}
}
void Evaluate(const WorkspaceSet &workspaces, const ParserState &state,
int focus, FeatureVector *result) const override {
for (auto *function : nested_) {
function->Evaluate(workspaces, state.sentence(), focus, result);
}
}
// Returns the first nested feature's computed value.
FeatureValue Compute(const WorkspaceSet &workspaces, const ParserState &state,
int focus, const FeatureVector *result) const override {
if (nested_.empty()) return -1;
return nested_[0]->Compute(workspaces, state.sentence(), focus, result);
}
};
REGISTER_PARSER_IDX_FEATURE_FUNCTION("token",
ParserTokenFeatureFunction);
} // namespace syntaxnet
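HeadFeatureLocator above applies the head function `levels` times over the partially built tree. A small Python mirror of that walk over a plain list of head indices (a hypothetical stand-in for ParserState, with -1 for the artificial root and -2 for out-of-range):

def head_focus(heads, focus, levels):
  """Applies the head function `levels` times, as HeadFeatureLocator does."""
  if focus < -1 or focus >= len(heads):
    return -2  # Out-of-range focus, same sentinel as the C++ code.
  for _ in range(levels):
    focus = -1 if focus == -1 else heads[focus]
  return focus

# heads[i] is the head of token i; -1 means the artificial root.
print(head_focus([1, -1, 3, 1], 0, 2))  # token 0 -> head 1 -> root (-1)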
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Sentence-based features for the transition parser.
#ifndef SYNTAXNET_PARSER_FEATURES_H_
#define SYNTAXNET_PARSER_FEATURES_H_
#include <string>
#include "syntaxnet/feature_extractor.h"
#include "syntaxnet/feature_types.h"
#include "syntaxnet/parser_state.h"
#include "syntaxnet/task_context.h"
#include "syntaxnet/workspace.h"
namespace syntaxnet {
// A union used to represent discrete and continuous feature values.
union FloatFeatureValue {
public:
explicit FloatFeatureValue(FeatureValue v) : discrete_value(v) {}
FloatFeatureValue(uint32 i, float w) : id(i), weight(w) {}
FeatureValue discrete_value;
struct {
uint32 id;
float weight;
};
};
typedef FeatureFunction<ParserState> ParserFeatureFunction;
// Feature function for the transition parser based on a parser state object and
// a token index. This typically extracts information from a given token.
typedef FeatureFunction<ParserState, int> ParserIndexFeatureFunction;
// Utilities to register the two types of parser features.
#define REGISTER_PARSER_FEATURE_FUNCTION(name, component) \
REGISTER_FEATURE_FUNCTION(ParserFeatureFunction, name, component)
#define REGISTER_PARSER_IDX_FEATURE_FUNCTION(name, component) \
REGISTER_FEATURE_FUNCTION(ParserIndexFeatureFunction, name, component)
// Alias for locator type that takes a parser state, and produces a focus
// integer that can be used on nested ParserIndexFeature objects.
template<class DER>
using ParserLocator = FeatureAddFocusLocator<DER, ParserState, int>;
// Alias for Locator type features that take (ParserState, int) signatures and
// call other ParserIndexFeatures.
template<class DER>
using ParserIndexLocator = FeatureLocator<DER, ParserState, int>;
// Feature extractor for the transition parser based on a parser state object.
typedef FeatureExtractor<ParserState> ParserFeatureExtractor;
// A simple wrapper FeatureType that adds a special "<ROOT>" type.
class RootFeatureType : public FeatureType {
public:
// Creates a RootFeatureType that wraps a given type and adds the special
// "<ROOT>" value in root_value.
RootFeatureType(const string &name, const FeatureType &wrapped_type,
int root_value);
// Returns the feature value name, but with the special "<ROOT>" value.
string GetFeatureValueName(FeatureValue value) const override;
// Returns the original number of features plus one for the "<ROOT>" value.
FeatureValue GetDomainSize() const override;
private:
// A wrapped type that handles everything else besides "<ROOT>".
const FeatureType &wrapped_type_;
// The reserved root value.
int root_value_;
};
// Simple feature function that wraps a Sentence based feature
// function. It adds a "<ROOT>" feature value that is triggered whenever the
// focus is the special root token. This class is sub-classed based on the
// extracted arguments of the nested function.
template<class F>
class ParserSentenceFeatureFunction : public ParserIndexFeatureFunction {
public:
// Instantiates and sets up the nested feature.
void Setup(TaskContext *context) override {
this->feature_.set_descriptor(this->descriptor());
this->feature_.set_prefix(this->prefix());
this->feature_.set_extractor(this->extractor());
feature_.Setup(context);
}
// Initializes the nested feature and sets feature type.
void Init(TaskContext *context) override {
feature_.Init(context);
num_base_values_ = feature_.GetFeatureType()->GetDomainSize();
set_feature_type(new RootFeatureType(
name(), *feature_.GetFeatureType(), RootValue()));
}
// Passes workspace requests and preprocessing to the nested feature.
void RequestWorkspaces(WorkspaceRegistry *registry) override {
feature_.RequestWorkspaces(registry);
}
void Preprocess(WorkspaceSet *workspaces, ParserState *state) const override {
feature_.Preprocess(workspaces, state->mutable_sentence());
}
protected:
// Returns the special value to represent a root token.
FeatureValue RootValue() const { return num_base_values_; }
// Stores the number of base values from the wrapped function, so we can
// compute the root value.
int num_base_values_;
// The wrapped feature.
F feature_;
};
// Specialization of ParserSentenceFeatureFunction that calls the nested feature
// with (Sentence, int) arguments based on the current integer focus.
template<class F>
class BasicParserSentenceFeatureFunction :
public ParserSentenceFeatureFunction<F> {
public:
FeatureValue Compute(const WorkspaceSet &workspaces, const ParserState &state,
int focus, const FeatureVector *result) const override {
if (focus == -1) return this->RootValue();
return this->feature_.Compute(workspaces, state.sentence(), focus, result);
}
};
} // namespace syntaxnet
#endif  // SYNTAXNET_PARSER_FEATURES_H_
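FloatFeatureValue above overlays a discrete 64-bit feature value with an (id, weight) pair. A rough Python equivalent of that type punning via struct, assuming a little-endian layout and a 64-bit FeatureValue; illustrative only:

import struct

def pack(feature_id, weight):
  """Packs a uint32 id and a float32 weight into one 64-bit value."""
  return struct.unpack('<Q', struct.pack('<If', feature_id, weight))[0]

def unpack(discrete_value):
  """Recovers the (id, weight) pair from a packed 64-bit value."""
  return struct.unpack('<If', struct.pack('<Q', discrete_value))

print(unpack(pack(42, 0.5)))  # -> (42, 0.5)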
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "syntaxnet/parser_features.h"
#include <string>
#include "syntaxnet/utils.h"
#include "syntaxnet/feature_extractor.h"
#include "syntaxnet/parser_state.h"
#include "syntaxnet/populate_test_inputs.h"
#include "syntaxnet/sentence.pb.h"
#include "syntaxnet/task_context.h"
#include "syntaxnet/task_spec.pb.h"
#include "syntaxnet/term_frequency_map.h"
#include "syntaxnet/workspace.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
// Feature extractor for the transition parser based on a parser state object.
typedef FeatureExtractor<ParserState, int> ParserIndexFeatureExtractor;
// Test fixture for parser features.
class ParserFeatureFunctionTest : public ::testing::Test {
protected:
// Sets up a parser state.
void SetUp() override {
// Prepare a document.
const char *kTaggedDocument =
"text: 'I saw a man with a telescope.' "
"token { word: 'I' start: 0 end: 0 tag: 'PRP' category: 'PRON'"
" break_level: NO_BREAK } "
"token { word: 'saw' start: 2 end: 4 tag: 'VBD' category: 'VERB'"
" break_level: SPACE_BREAK } "
"token { word: 'a' start: 6 end: 6 tag: 'DT' category: 'DET'"
" break_level: SPACE_BREAK } "
"token { word: 'man' start: 8 end: 10 tag: 'NN' category: 'NOUN'"
" break_level: SPACE_BREAK } "
"token { word: 'with' start: 12 end: 15 tag: 'IN' category: 'ADP'"
" break_level: SPACE_BREAK } "
"token { word: 'a' start: 17 end: 17 tag: 'DT' category: 'DET'"
" break_level: SPACE_BREAK } "
"token { word: 'telescope' start: 19 end: 27 tag: 'NN' category: 'NOUN'"
" break_level: SPACE_BREAK } "
"token { word: '.' start: 28 end: 28 tag: '.' category: '.'"
" break_level: NO_BREAK }";
CHECK(TextFormat::ParseFromString(kTaggedDocument, &sentence_));
creators_ = PopulateTestInputs::Defaults(sentence_);
// Prepare a label map. By adding labels in lexicographic order we make sure
// the term indices stay the same after sorting (which happens when the
// label map is saved to disk).
label_map_.Increment("NULL");
label_map_.Increment("ROOT");
label_map_.Increment("det");
label_map_.Increment("dobj");
label_map_.Increment("nsubj");
label_map_.Increment("p");
label_map_.Increment("pobj");
label_map_.Increment("prep");
creators_.Add("label-map", "text", "", [this](const string &filename) {
label_map_.Save(filename);
});
// Prepare a parser state.
state_.reset(new ParserState(&sentence_, nullptr /* no transition state */,
&label_map_));
}
// Prepares a feature for computations.
string ExtractFeature(const string &feature_name) {
context_.mutable_spec()->mutable_input()->Clear();
context_.mutable_spec()->mutable_output()->Clear();
feature_extractor_.reset(new ParserFeatureExtractor());
feature_extractor_->Parse(feature_name);
feature_extractor_->Setup(&context_);
creators_.Populate(&context_);
feature_extractor_->Init(&context_);
feature_extractor_->RequestWorkspaces(&registry_);
workspaces_.Reset(registry_);
feature_extractor_->Preprocess(&workspaces_, state_.get());
FeatureVector result;
feature_extractor_->ExtractFeatures(workspaces_, *state_, &result);
return result.type(0)->GetFeatureValueName(result.value(0));
}
std::unique_ptr<ParserState> state_;
Sentence sentence_;
WorkspaceSet workspaces_;
TermFrequencyMap label_map_;
PopulateTestInputs::CreatorMap creators_;
TaskContext context_;
WorkspaceRegistry registry_;
std::unique_ptr<ParserFeatureExtractor> feature_extractor_;
};
TEST_F(ParserFeatureFunctionTest, TagFeatureFunction) {
state_->Push(-1);
state_->Push(0);
EXPECT_EQ("PRP", ExtractFeature("input.tag"));
EXPECT_EQ("VBD", ExtractFeature("input(1).tag"));
EXPECT_EQ("<OUTSIDE>", ExtractFeature("input(10).tag"));
EXPECT_EQ("PRP", ExtractFeature("stack(0).tag"));
EXPECT_EQ("<ROOT>", ExtractFeature("stack(1).tag"));
}
TEST_F(ParserFeatureFunctionTest, LabelFeatureFunction) {
// Construct a partial dependency tree.
state_->AddArc(0, 1, 4);
state_->AddArc(1, -1, 1);
state_->AddArc(2, 3, 2);
state_->AddArc(3, 1, 3);
state_->AddArc(5, 6, 2);
state_->AddArc(6, 4, 6);
state_->AddArc(7, 1, 5);
// Test the feature function.
EXPECT_EQ(label_map_.GetTerm(4), ExtractFeature("input.label"));
EXPECT_EQ("ROOT", ExtractFeature("input(1).label"));
EXPECT_EQ(label_map_.GetTerm(2), ExtractFeature("input(2).label"));
// Push artificial root token onto the stack. This triggers the wrapped <ROOT>
// value, rather than indicating a token with the label "ROOT" (which may or
// may not be the artificial root token).
state_->Push(-1);
EXPECT_EQ("<ROOT>", ExtractFeature("stack.label"));
}
} // namespace syntaxnet
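The tests above exercise feature descriptors such as input(1).tag and stack.label: a locator with an optional argument, followed by a nested feature. A toy parse of that surface syntax (a hypothetical regex sketch, not SyntaxNet's real feature-descriptor parser):

import re

def parse_descriptor(descriptor):
  """Splits e.g. 'input(1).tag' into (locator, argument, nested feature)."""
  locator, arg, feature = re.match(
      r'(\w+)(?:\((-?\d+)\))?(?:\.(\w+))?$', descriptor).groups()
  return locator, int(arg) if arg else 0, feature

print(parse_descriptor('input(1).tag'))  # -> ('input', 1, 'tag')
print(parse_descriptor('stack.label'))   # -> ('stack', 0, 'label')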
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "syntaxnet/parser_state.h"
#include "syntaxnet/utils.h"
#include "syntaxnet/kbest_syntax.pb.h"
#include "syntaxnet/sentence.pb.h"
#include "syntaxnet/term_frequency_map.h"
namespace syntaxnet {
const char ParserState::kRootLabel[] = "ROOT";
ParserState::ParserState(Sentence *sentence,
ParserTransitionState *transition_state,
const TermFrequencyMap *label_map)
: sentence_(sentence),
num_tokens_(sentence->token_size()),
transition_state_(transition_state),
label_map_(label_map),
root_label_(kDefaultRootLabel),
next_(0) {
// Initialize the stack. Some transition systems could also push the
// artificial root on the stack, so we make room for that as well.
stack_.reserve(num_tokens_ + 1);
// Allocate space for head indices and labels. Initialize the head for all
// tokens to be the artificial root node, i.e. token -1.
head_.resize(num_tokens_, -1);
label_.resize(num_tokens_, RootLabel());
// Transition system-specific preprocessing.
if (transition_state_ != nullptr) transition_state_->Init(this);
}
ParserState::~ParserState() { delete transition_state_; }
ParserState *ParserState::Clone() const {
ParserState *new_state = new ParserState();
new_state->sentence_ = sentence_;
new_state->num_tokens_ = num_tokens_;
new_state->alternative_ = alternative_;
new_state->transition_state_ =
(transition_state_ == nullptr ? nullptr : transition_state_->Clone());
new_state->label_map_ = label_map_;
new_state->root_label_ = root_label_;
new_state->next_ = next_;
new_state->stack_.assign(stack_.begin(), stack_.end());
new_state->head_.assign(head_.begin(), head_.end());
new_state->label_.assign(label_.begin(), label_.end());
new_state->score_ = score_;
new_state->is_gold_ = is_gold_;
return new_state;
}
int ParserState::RootLabel() const { return root_label_; }
int ParserState::Next() const {
DCHECK_GE(next_, -1);
DCHECK_LE(next_, num_tokens_);
return next_;
}
int ParserState::Input(int offset) const {
int index = next_ + offset;
return index >= -1 && index < num_tokens_ ? index : -2;
}
void ParserState::Advance() {
DCHECK_LT(next_, num_tokens_);
++next_;
}
bool ParserState::EndOfInput() const { return next_ == num_tokens_; }
void ParserState::Push(int index) {
DCHECK_LE(stack_.size(), num_tokens_);
stack_.push_back(index);
}
int ParserState::Pop() {
DCHECK(!StackEmpty());
const int result = stack_.back();
stack_.pop_back();
return result;
}
int ParserState::Top() const {
DCHECK(!StackEmpty());
return stack_.back();
}
int ParserState::Stack(int position) const {
if (position < 0) return -2;
const int index = stack_.size() - 1 - position;
return (index < 0) ? -2 : stack_[index];
}
int ParserState::StackSize() const { return stack_.size(); }
bool ParserState::StackEmpty() const { return stack_.empty(); }
int ParserState::Head(int index) const {
DCHECK_GE(index, -1);
DCHECK_LT(index, num_tokens_);
return index == -1 ? -1 : head_[index];
}
int ParserState::Label(int index) const {
DCHECK_GE(index, -1);
DCHECK_LT(index, num_tokens_);
return index == -1 ? RootLabel() : label_[index];
}
int ParserState::Parent(int index, int n) const {
// Find the n-th parent by applying the head function n times.
DCHECK_GE(index, -1);
DCHECK_LT(index, num_tokens_);
while (n-- > 0) index = Head(index);
return index;
}
int ParserState::LeftmostChild(int index, int n) const {
DCHECK_GE(index, -1);
DCHECK_LT(index, num_tokens_);
while (n-- > 0) {
// Find the leftmost child by scanning from start until a child is
// encountered.
int i;
for (i = -1; i < index; ++i) {
if (Head(i) == index) break;
}
if (i == index) return -2;
index = i;
}
return index;
}
int ParserState::RightmostChild(int index, int n) const {
DCHECK_GE(index, -1);
DCHECK_LT(index, num_tokens_);
while (n-- > 0) {
// Find the rightmost child by scanning backward from end until a child
// is encountered.
int i;
for (i = num_tokens_ - 1; i > index; --i) {
if (Head(i) == index) break;
}
if (i == index) return -2;
index = i;
}
return index;
}
int ParserState::LeftSibling(int index, int n) const {
// Find the n-th left sibling by scanning left until the n-th child of the
// parent is encountered.
DCHECK_GE(index, -1);
DCHECK_LT(index, num_tokens_);
if (index == -1 && n > 0) return -2;
int i = index;
while (n > 0) {
--i;
if (i == -1) return -2;
if (Head(i) == Head(index)) --n;
}
return i;
}
int ParserState::RightSibling(int index, int n) const {
// Find the n-th right sibling by scanning right until the n-th child of the
// parent is encountered.
DCHECK_GE(index, -1);
DCHECK_LT(index, num_tokens_);
if (index == -1 && n > 0) return -2;
int i = index;
while (n > 0) {
++i;
if (i == num_tokens_) return -2;
if (Head(i) == Head(index)) --n;
}
return i;
}
void ParserState::AddArc(int index, int head, int label) {
DCHECK_GE(index, 0);
DCHECK_LT(index, num_tokens_);
head_[index] = head;
label_[index] = label;
}
int ParserState::GoldHead(int index) const {
// A valid ParserState index is transformed to a valid Sentence index,
// then the gold head is extracted.
DCHECK_GE(index, -1);
DCHECK_LT(index, num_tokens_);
if (index == -1) return -1;
const int offset = 0;
const int gold_head = GetToken(index).head();
return gold_head == -1 ? -1 : gold_head - offset;
}
int ParserState::GoldLabel(int index) const {
// A valid ParserState index is transformed to a valid Sentence index,
// then the gold label is extracted.
DCHECK_GE(index, -1);
DCHECK_LT(index, num_tokens_);
if (index == -1) return RootLabel();
string gold_label;
gold_label = GetToken(index).label();
return label_map_->LookupIndex(gold_label, RootLabel() /* unknown */);
}
void ParserState::AddParseToDocument(Sentence *sentence,
bool rewrite_root_labels) const {
transition_state_->AddParseToDocument(*this, rewrite_root_labels, sentence);
}
bool ParserState::IsTokenCorrect(int index) const {
return transition_state_->IsTokenCorrect(*this, index);
}
string ParserState::LabelAsString(int label) const {
if (label == root_label_) return "ROOT";
if (label >= 0 && label < label_map_->Size()) {
return label_map_->GetTerm(label);
}
return "";
}
string ParserState::ToString() const {
return transition_state_->ToString(*this);
}
} // namespace syntaxnet
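LeftmostChild above repeatedly scans from the left for the first token whose head is the current focus, returning -2 when no child exists. A compact Python mirror over a plain head list (hypothetical representation, matching the C++ sentinels):

def leftmost_child(heads, index, n):
  """Mirrors ParserState::LeftmostChild over a plain list of head indices."""
  def head(i):
    return -1 if i == -1 else heads[i]
  for _ in range(n):
    # Scan from the left for the first token whose head is `index`.
    child = next((i for i in range(-1, index) if head(i) == index), index)
    if child == index:
      return -2  # No child to the left; same sentinel as the C++ code.
    index = child
  return index

# Token 1 heads tokens 0 and 3, so its leftmost child is token 0.
print(leftmost_child([1, -1, 3, 1], 1, 1))  # -> 0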
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Parser state for the transition-based dependency parser.
#ifndef SYNTAXNET_PARSER_STATE_H_
#define SYNTAXNET_PARSER_STATE_H_
#include <string>
#include <vector>
#include "syntaxnet/utils.h"
#include "syntaxnet/kbest_syntax.pb.h"
#include "syntaxnet/parser_transitions.h"
#include "syntaxnet/sentence.pb.h"
namespace syntaxnet {
class TermFrequencyMap;
// A ParserState object represents the state of the parser during the parsing of
// a sentence. The state consists of a pointer to the next input token and a
// stack of partially processed tokens. The parser state can be changed by
// applying a sequence of transitions. Some transitions also add relations
// to the dependency tree of the sentence. The parser state also records the
// (partial) parse tree for the sentence by recording the head of each token and
// the label of this relation. The state is used for both training and parsing.
class ParserState {
public:
// String representation of the root label.
static const char kRootLabel[];
// Default value for the root label in case it's not in the label map.
static const int kDefaultRootLabel = -1;
// Initializes the parser state for a sentence, using an additional transition
// state for preprocessing and/or additional information specific to the
// transition system. The transition state is allowed to be null, in which
// case no additional work is performed. The parser state takes ownership of
// the transition state. A label map is used for transforming between integer
// and string representations of the labels.
ParserState(Sentence *sentence,
ParserTransitionState *transition_state,
const TermFrequencyMap *label_map);
// Deletes the parser state.
~ParserState();
// Clones the parser state.
ParserState *Clone() const;
// Returns the root label.
int RootLabel() const;
// Returns the index of the next input token.
int Next() const;
// Returns the number of tokens in the sentence.
int NumTokens() const { return num_tokens_; }
// Returns the token index relative to the next input token. If no such token
// exists, returns -2.
int Input(int offset) const;
// Advances to the next input token.
void Advance();
// Returns true if all tokens have been processed.
bool EndOfInput() const;
// Pushes an element to the stack.
void Push(int index);
// Pops the top element from stack and returns it.
int Pop();
// Returns the element from the top of the stack.
int Top() const;
// Returns the element at a certain position in the stack. Stack(0) is the top
// stack element. If no such position exists, returns -2.
int Stack(int position) const;
// Returns the number of elements on the stack.
int StackSize() const;
// Returns true if the stack is empty.
bool StackEmpty() const;
// Returns the head index for a given token.
int Head(int index) const;
// Returns the label of the relation to head for a given token.
int Label(int index) const;
// Returns the parent of a given token 'n' levels up in the tree.
int Parent(int index, int n) const;
// Returns the leftmost child of a given token 'n' levels down in the tree. If
// no such child exists, returns -2.
int LeftmostChild(int index, int n) const;
// Returns the rightmost child of a given token 'n' levels down in the tree.
// If no such child exists, returns -2.
int RightmostChild(int index, int n) const;
// Returns the n-th left sibling of a given token. If no such sibling exists,
// returns -2.
int LeftSibling(int index, int n) const;
// Returns the n-th right sibling of a given token. If no such sibling exists,
// returns -2.
int RightSibling(int index, int n) const;
// Adds an arc to the partial dependency tree of the state.
void AddArc(int index, int head, int label);
// Returns the gold head index for a given token, based on the underlying
// annotated sentence.
int GoldHead(int index) const;
// Returns the gold label for a given token, based on the underlying annotated
// sentence.
int GoldLabel(int index) const;
// Get a reference to the underlying token at index. Returns an empty default
// Token if accessing the root.
const Token &GetToken(int index) const {
if (index == -1) return kRootToken;
return sentence().token(index);
}
// Annotates a document with the dependency relations built during parsing for
// one of its sentences. If rewrite_root_labels is true, then all tokens with
// no heads will be assigned the default root label "ROOT".
void AddParseToDocument(Sentence *document, bool rewrite_root_labels) const;
// As above, but uses the default of rewrite_root_labels = true.
void AddParseToDocument(Sentence *document) const {
AddParseToDocument(document, true);
}
// Whether a parsed token should be considered correct for evaluation.
bool IsTokenCorrect(int index) const;
// Returns the string representation of a dependency label, or an empty string
// if the label is invalid.
string LabelAsString(int label) const;
// Returns a string representation of the parser state.
string ToString() const;
// Returns the underlying sentence instance.
const Sentence &sentence() const { return *sentence_; }
Sentence *mutable_sentence() const { return sentence_; }
// Returns the transition system-specific state.
const ParserTransitionState *transition_state() const {
return transition_state_;
}
ParserTransitionState *mutable_transition_state() {
return transition_state_;
}
// Gets/sets the flag which says that the state was obtained through gold
// transitions only.
bool is_gold() const { return is_gold_; }
void set_is_gold(bool is_gold) { is_gold_ = is_gold; }
private:
// Empty constructor used for the cloning operation.
ParserState() {}
// Default value for the root token.
const Token kRootToken;
// Sentence to parse. Not owned.
Sentence *sentence_ = nullptr;
// Number of tokens in the sentence to parse.
int num_tokens_;
// Which alternative token analysis is used for tag/category/head/label
// information. -1 means use default.
int alternative_ = -1;
// Transition system-specific state. Owned.
ParserTransitionState *transition_state_ = nullptr;
// Label map used for conversions between integer and string representations
// of the dependency labels. Not owned.
const TermFrequencyMap *label_map_ = nullptr;
// Root label.
int root_label_;
// Index of the next input token.
int next_;
// Parse stack of partially processed tokens.
vector<int> stack_;
// List of head positions for the (partial) dependency tree.
vector<int> head_;
// List of dependency relation labels describing the (partial) dependency
// tree.
vector<int> label_;
// Score of the parser state.
double score_ = 0.0;
// True if this is the gold standard sequence (used for structured learning).
bool is_gold_ = false;
TF_DISALLOW_COPY_AND_ASSIGN(ParserState);
};
} // namespace syntaxnet
#endif  // SYNTAXNET_PARSER_STATE_H_
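A minimal Python sketch of the state machine this header describes: a stack of token indices, a next-input pointer, and head/label arrays for the partial tree. Illustrative only; transition states, scoring, and the label map are omitted.

class MiniParserState(object):
  """Stack, next-input pointer, and head/label arrays, as described above."""

  def __init__(self, num_tokens):
    self.next = 0                    # Index of the next input token.
    self.stack = []                  # Partially processed token indices.
    self.head = [-1] * num_tokens    # -1: attached to the artificial root.
    self.label = [-1] * num_tokens

  def advance(self):
    self.next += 1  # Consume the next input token.

  def add_arc(self, index, head, label):
    self.head[index], self.label[index] = head, label

state = MiniParserState(3)
state.stack.append(-1)  # Some transition systems push the root first.
state.advance()
state.add_arc(0, 1, 4)  # Token 0 depends on token 1 with label id 4.
print(state.head)  # -> [1, -1, -1]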
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A program to train a tensorflow neural net parser from a a conll file."""
import os
import os.path
import time
import tensorflow as tf
from tensorflow.python.platform import gfile
from tensorflow.python.platform import logging
from google.protobuf import text_format
from syntaxnet import graph_builder
from syntaxnet import structured_graph_builder
from syntaxnet.ops import gen_parser_ops
from syntaxnet import task_spec_pb2
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('tf_master', '',
                    'TensorFlow execution engine to connect to.')
flags.DEFINE_string('output_path', '', 'Top level for output.')
flags.DEFINE_string('task_context', '',
                    'Path to a task context with resource locations and '
                    'parameters.')
flags.DEFINE_string('arg_prefix', None, 'Prefix for context parameters.')
flags.DEFINE_string('params', '0', 'Unique identifier of parameter grid point.')
flags.DEFINE_string('training_corpus', 'training-corpus',
                    'Name of the context input to read training data from.')
flags.DEFINE_string('tuning_corpus', 'tuning-corpus',
                    'Name of the context input to read tuning data from.')
flags.DEFINE_string('word_embeddings', None,
                    'Recordio containing pretrained word embeddings, will be '
                    'loaded as the first embedding matrix.')
flags.DEFINE_bool('compute_lexicon', False, '')
flags.DEFINE_bool('projectivize_training_set', False, '')
flags.DEFINE_string('hidden_layer_sizes', '200,200',
                    'Comma separated list of hidden layer sizes.')
flags.DEFINE_string('graph_builder', 'greedy',
                    'Graph builder to use, either "greedy" or "structured".')
flags.DEFINE_integer('batch_size', 32,
                     'Number of sentences to process in parallel.')
flags.DEFINE_integer('beam_size', 10, 'Number of slots for beam parsing.')
flags.DEFINE_integer('num_epochs', 10, 'Number of epochs to train for.')
flags.DEFINE_integer('max_steps', 50,
                     'Max number of parser steps during a training step.')
flags.DEFINE_integer('report_every', 100,
                     'Report cost and training accuracy every this many steps.')
flags.DEFINE_integer('checkpoint_every', 5000,
                     'Measure tuning UAS and checkpoint every this many steps.')
flags.DEFINE_bool('slim_model', False,
                  'Whether to remove non-averaged variables, for compactness.')
flags.DEFINE_float('learning_rate', 0.1, 'Initial learning rate parameter.')
flags.DEFINE_integer('decay_steps', 4000,
                     'Decay learning rate by 0.96 every this many steps.')
flags.DEFINE_float('momentum', 0.9,
                   'Momentum parameter for momentum optimizer.')
flags.DEFINE_string('seed', '0', 'Initialization seed for TF variables.')
flags.DEFINE_string('pretrained_params', None,
                    'Path to model from which to load params.')
flags.DEFINE_string('pretrained_params_names', None,
                    'List of names of tensors to load from pretrained model.')
flags.DEFINE_float('averaging_decay', 0.9999,
                   'Decay for exponential moving average when computing '
                   'averaged parameters, set to 1 to do vanilla averaging.')
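The --averaging_decay flag above drives an exponential moving average over the model parameters. A one-step sketch of that update, assuming the graph builder uses tf.train.ExponentialMovingAverage with num_updates, whose capped effective decay is what makes a decay of 1 behave like vanilla averaging:

def ema_step(average, value, step, decay=0.9999):
  """One moving-average update with TF's num_updates-style decay cap."""
  effective = min(decay, (1.0 + step) / (10.0 + step))
  return effective * average + (1.0 - effective) * value

print(ema_step(0.0, 1.0, step=0))  # Early steps weight new values heavily.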
def StageName():
  return os.path.join(FLAGS.arg_prefix, FLAGS.graph_builder)


def OutputPath(path):
  return os.path.join(FLAGS.output_path, StageName(), FLAGS.params, path)


def RewriteContext():
  context = task_spec_pb2.TaskSpec()
  with gfile.FastGFile(FLAGS.task_context) as fin:
    text_format.Merge(fin.read(), context)
  for resource in context.input:
    if resource.creator == StageName():
      del resource.part[:]
      part = resource.part.add()
      part.file_pattern = os.path.join(OutputPath(resource.name))
  with gfile.FastGFile(OutputPath('context'), 'w') as fout:
    fout.write(str(context))
def WriteStatus(num_steps, eval_metric, best_eval_metric):
  status = os.path.join(os.getenv('GOOGLE_STATUS_DIR') or '/tmp', 'STATUS')
  message = ('Parameters: %s | Steps: %d | Tuning score: %.2f%% | '
             'Best tuning score: %.2f%%' % (FLAGS.params, num_steps,
                                            eval_metric, best_eval_metric))
  with gfile.FastGFile(status, 'w') as fout:
    fout.write(message)
  with gfile.FastGFile(OutputPath('status'), 'a') as fout:
    fout.write(message + '\n')
def Eval(sess, parser, num_steps, best_eval_metric):
  """Evaluates a network and checkpoints it to disk.

  Args:
    sess: tensorflow session to use
    parser: graph builder containing all ops references
    num_steps: number of training steps taken, for logging
    best_eval_metric: current best eval metric, to decide whether this model is
        the best so far

  Returns:
    new best eval metric
  """
  logging.info('Evaluating training network.')
  t = time.time()
  num_epochs = None
  num_tokens = 0
  num_correct = 0
  while True:
    tf_eval_epochs, tf_eval_metrics = sess.run([
        parser.evaluation['epochs'], parser.evaluation['eval_metrics']
    ])
    num_tokens += tf_eval_metrics[0]
    num_correct += tf_eval_metrics[1]
    if num_epochs is None:
      num_epochs = tf_eval_epochs
    elif num_epochs < tf_eval_epochs:
      break
  eval_metric = 0 if num_tokens == 0 else (100.0 * num_correct / num_tokens)
  logging.info('Seconds elapsed in evaluation: %.2f, '
               'eval metric: %.2f%%', time.time() - t, eval_metric)
  WriteStatus(num_steps, eval_metric, max(eval_metric, best_eval_metric))
  # Save parameters.
  if FLAGS.output_path:
    logging.info('Writing out trained parameters.')
    parser.saver.save(sess, OutputPath('latest-model'))
    if eval_metric > best_eval_metric:
      parser.saver.save(sess, OutputPath('model'))
  return max(eval_metric, best_eval_metric)
def Train(sess, num_actions, feature_sizes, domain_sizes, embedding_dims):
"""Builds and trains the network.
Args:
sess: tensorflow session to use.
num_actions: number of possible golden actions.
feature_sizes: size of each feature vector.
domain_sizes: number of possible feature ids in each feature vector.
embedding_dims: embedding dimension to use for each feature group.
"""
t = time.time()
hidden_layer_sizes = map(int, FLAGS.hidden_layer_sizes.split(','))
logging.info('Building training network with parameters: feature_sizes: %s '
'domain_sizes: %s', feature_sizes, domain_sizes)
if FLAGS.graph_builder == 'greedy':
parser = graph_builder.GreedyParser(num_actions,
feature_sizes,
domain_sizes,
embedding_dims,
hidden_layer_sizes,
seed=int(FLAGS.seed),
gate_gradients=True,
averaging_decay=FLAGS.averaging_decay,
arg_prefix=FLAGS.arg_prefix)
else:
parser = structured_graph_builder.StructuredGraphBuilder(
num_actions,
feature_sizes,
domain_sizes,
embedding_dims,
hidden_layer_sizes,
seed=int(FLAGS.seed),
gate_gradients=True,
averaging_decay=FLAGS.averaging_decay,
arg_prefix=FLAGS.arg_prefix,
beam_size=FLAGS.beam_size,
max_steps=FLAGS.max_steps)
task_context = OutputPath('context')
if FLAGS.word_embeddings is not None:
parser.AddPretrainedEmbeddings(0, FLAGS.word_embeddings, task_context)
corpus_name = ('projectivized-training-corpus' if
FLAGS.projectivize_training_set else FLAGS.training_corpus)
parser.AddTraining(task_context,
FLAGS.batch_size,
learning_rate=FLAGS.learning_rate,
momentum=FLAGS.momentum,
decay_steps=FLAGS.decay_steps,
corpus_name=corpus_name)
parser.AddEvaluation(task_context,
FLAGS.batch_size,
corpus_name=FLAGS.tuning_corpus)
parser.AddSaver(FLAGS.slim_model)
# Save graph.
if FLAGS.output_path:
with gfile.FastGFile(OutputPath('graph'), 'w') as f:
f.write(sess.graph_def.SerializeToString())
logging.info('Initializing...')
num_epochs = 0
cost_sum = 0.0
num_steps = 0
best_eval_metric = 0.0
sess.run(parser.inits.values())
if FLAGS.pretrained_params is not None:
logging.info('Loading pretrained params from %s', FLAGS.pretrained_params)
feed_dict = {'save/Const:0': FLAGS.pretrained_params}
targets = []
for node in sess.graph_def.node:
if (node.name.startswith('save/Assign') and
node.input[0] in FLAGS.pretrained_params_names.split(',')):
logging.info('Loading %s with op %s', node.input[0], node.name)
targets.append(node.name)
sess.run(targets, feed_dict=feed_dict)
logging.info('Training...')
while num_epochs < FLAGS.num_epochs:
tf_epochs, tf_cost, _ = sess.run([parser.training[
'epochs'], parser.training['cost'], parser.training['train_op']])
num_epochs = tf_epochs
num_steps += 1
cost_sum += tf_cost
if num_steps % FLAGS.report_every == 0:
logging.info('Epochs: %d, num steps: %d, '
'seconds elapsed: %.2f, avg cost: %.2f', num_epochs,
num_steps, time.time() - t, cost_sum / FLAGS.report_every)
cost_sum = 0.0
if num_steps % FLAGS.checkpoint_every == 0:
best_eval_metric = Eval(sess, parser, num_steps, best_eval_metric)
def main(unused_argv):
logging.set_verbosity(logging.INFO)
if not gfile.IsDirectory(OutputPath('')):
gfile.MakeDirs(OutputPath(''))
# Rewrite context.
RewriteContext()
# Creates necessary term maps.
if FLAGS.compute_lexicon:
logging.info('Computing lexicon...')
with tf.Session(FLAGS.tf_master) as sess:
gen_parser_ops.lexicon_builder(task_context=OutputPath('context'),
corpus_name=FLAGS.training_corpus).run()
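# Query the feature metadata (feature sizes, domain sizes, embedding dims)
# and the number of actions declared by the task context.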
with tf.Session(FLAGS.tf_master) as sess:
feature_sizes, domain_sizes, embedding_dims, num_actions = sess.run(
gen_parser_ops.feature_size(task_context=OutputPath('context'),
arg_prefix=FLAGS.arg_prefix))
# Keep only well-formed sentences and projectivize the training set.
if FLAGS.projectivize_training_set:
logging.info('Preprocessing...')
with tf.Session(FLAGS.tf_master) as sess:
source, last = gen_parser_ops.document_source(
task_context=OutputPath('context'),
batch_size=FLAGS.batch_size,
corpus_name=FLAGS.training_corpus)
sink = gen_parser_ops.document_sink(
task_context=OutputPath('context'),
corpus_name='projectivized-training-corpus',
documents=gen_parser_ops.projectivize_filter(
gen_parser_ops.well_formed_filter(source,
task_context=OutputPath(
'context')),
task_context=OutputPath('context')))
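# Run the pipeline until the source signals that the last batch has been
# consumed.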
while True:
tf_last, _ = sess.run([last, sink])
if tf_last:
break
logging.info('Training...')
with tf.Session(FLAGS.tf_master) as sess:
Train(sess, num_actions, feature_sizes, domain_sizes, embedding_dims)
if __name__ == '__main__':
tf.app.run()
#!/bin/bash
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# This test trains a parser on a small dataset, then runs it in greedy mode and
# in structured mode with beam 1, and checks that the result is identical.
set -eux
BINDIR=$TEST_SRCDIR/syntaxnet
CONTEXT=$BINDIR/testdata/context.pbtxt
TMP_DIR=/tmp/syntaxnet-output
mkdir -p $TMP_DIR
sed "s=SRCDIR=$TEST_SRCDIR=" "$CONTEXT" | \
sed "s=OUTPATH=$TMP_DIR=" > $TMP_DIR/context
PARAMS=128-0.08-3600-0.9-0
"$BINDIR/parser_trainer" \
--arg_prefix=brain_parser \
--batch_size=32 \
--compute_lexicon \
--decay_steps=3600 \
--graph_builder=greedy \
--hidden_layer_sizes=128 \
--learning_rate=0.08 \
--momentum=0.9 \
--output_path=$TMP_DIR \
--task_context=$TMP_DIR/context \
--training_corpus=training-corpus \
--tuning_corpus=tuning-corpus \
--params=$PARAMS \
--num_epochs=12 \
--report_every=100 \
--checkpoint_every=1000 \
--logtostderr
"$BINDIR/parser_eval" \
--task_context=$TMP_DIR/brain_parser/greedy/$PARAMS/context \
--hidden_layer_sizes=128 \
--input=tuning-corpus \
--output=stdout \
--arg_prefix=brain_parser \
--graph_builder=greedy \
--model_path=$TMP_DIR/brain_parser/greedy/$PARAMS/model \
--logtostderr \
> $TMP_DIR/greedy-out
"$BINDIR/parser_eval" \
--task_context=$TMP_DIR/context \
--hidden_layer_sizes=128 \
--beam_size=1 \
--input=tuning-corpus \
--output=stdout \
--arg_prefix=brain_parser \
--graph_builder=structured \
--model_path=$TMP_DIR/brain_parser/greedy/$PARAMS/model \
--logtostderr \
> $TMP_DIR/struct-beam1-out
diff $TMP_DIR/greedy-out $TMP_DIR/struct-beam1-out
STRUCT_PARAMS=128-0.001-3600-0.9-0
"$BINDIR/parser_trainer" \
--arg_prefix=brain_parser \
--batch_size=8 \
--compute_lexicon \
--decay_steps=3600 \
--graph_builder=structured \
--hidden_layer_sizes=128 \
--learning_rate=0.001 \
--momentum=0.9 \
--pretrained_params=$TMP_DIR/brain_parser/greedy/$PARAMS/model \
--pretrained_params_names=\
embedding_matrix_0,embedding_matrix_1,embedding_matrix_2,bias_0,weights_0 \
--output_path=$TMP_DIR \
--task_context=$TMP_DIR/context \
--training_corpus=training-corpus \
--tuning_corpus=tuning-corpus \
--params=$STRUCT_PARAMS \
--num_epochs=20 \
--report_every=25 \
--checkpoint_every=200 \
--logtostderr
"$BINDIR/parser_eval" \
--task_context=$TMP_DIR/context \
--hidden_layer_sizes=128 \
--beam_size=8 \
--input=tuning-corpus \
--output=stdout \
--arg_prefix=brain_parser \
--graph_builder=structured \
--model_path=$TMP_DIR/brain_parser/structured/$STRUCT_PARAMS/model \
--logtostderr \
> $TMP_DIR/struct-beam8-out
echo "PASS"
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "syntaxnet/parser_transitions.h"
#include "syntaxnet/parser_state.h"
namespace syntaxnet {
// Transition system registry.
REGISTER_CLASS_REGISTRY("transition system", ParserTransitionSystem);
void ParserTransitionSystem::PerformAction(ParserAction action,
ParserState *state) const {
PerformActionWithoutHistory(action, state);
}
} // namespace syntaxnet
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Transition system for the transition-based dependency parser.
#ifndef SYNTAXNET_PARSER_TRANSITIONS_H_
#define SYNTAXNET_PARSER_TRANSITIONS_H_
#include <string>
#include <vector>
#include "syntaxnet/utils.h"
#include "syntaxnet/registry.h"
namespace tensorflow {
namespace io {
class RecordReader;
class RecordWriter;
}
}
namespace syntaxnet {
class Sentence;
class ParserState;
class TaskContext;
// Parser actions for the transition system are encoded as integers.
typedef int ParserAction;
// Label type for the parser action.
enum class LabelType {
NO_LABEL = 0,
LEFT_LABEL = 1,
RIGHT_LABEL = 2,
};
// Transition system-specific state. Transition systems can subclass this to
// preprocess the parser state and/or to keep additional information during
// parsing.
class ParserTransitionState {
public:
virtual ~ParserTransitionState() {}
// Clones the transition state.
virtual ParserTransitionState *Clone() const = 0;
// Initializes a parser state for the transition system.
virtual void Init(ParserState *state) = 0;
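// Copies the parse from the state into the sentence, optionally rewriting
// the root labels.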
virtual void AddParseToDocument(const ParserState &state,
bool rewrite_root_labels,
Sentence *sentence) const {}
// Whether a parsed token should be considered correct for evaluation.
virtual bool IsTokenCorrect(const ParserState &state, int index) const = 0;
// Returns a human-readable string representation of this state.
virtual string ToString(const ParserState &state) const = 0;
};
// A transition system is used for handling the parser state transitions. During
// training the transition system is used for extracting a canonical sequence of
// transitions for an annotated sentence. During parsing the transition system
// is used for applying the predicted transitions to the parse state, thereby
// building the parse tree for the sentence. Transition systems can be
// implemented by subclassing this abstract class and registered using the
// REGISTER_TRANSITION_SYSTEM macro.
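//
// A greedy decoding loop built on this interface might look like the
// following sketch (assuming a ParserState *state and externally computed
// transition scores):
//
//   while (!system->IsFinalState(*state)) {
//     ParserAction best = ...;  // highest scoring prediction
//     if (!system->IsAllowedAction(best, *state)) {
//       best = system->GetDefaultAction(*state);
//     }
//     system->PerformAction(best, state);
//   }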
class ParserTransitionSystem
: public RegisterableClass<ParserTransitionSystem> {
public:
// Construction and cleanup.
ParserTransitionSystem() {}
virtual ~ParserTransitionSystem() {}
// Sets up the transition system. If inputs are needed, this is the place to
// specify them.
virtual void Setup(TaskContext *context) {}
// Initializes the transition system.
virtual void Init(TaskContext *context) {}
// Reads the transition system from disk.
virtual void Read(tensorflow::io::RecordReader *reader) {}
// Writes the transition system to disk.
virtual void Write(tensorflow::io::RecordWriter *writer) const {}
// Returns the number of action types.
virtual int NumActionTypes() const = 0;
// Returns the number of actions.
virtual int NumActions(int num_labels) const = 0;
// Internally creates the set of outcomes (when transition systems support a
// variable number of actions).
virtual void CreateOutcomeSet(int num_labels) {}
// Returns the default action for a given state.
virtual ParserAction GetDefaultAction(const ParserState &state) const = 0;
// Returns the next gold action for the parser during training using the
// dependency relations found in the underlying annotated sentence.
virtual ParserAction GetNextGoldAction(const ParserState &state) const = 0;
// Returns all next gold actions for the parser during training using the
// dependency relations found in the underlying annotated sentence.
virtual void GetAllNextGoldActions(const ParserState &state,
vector<ParserAction> *actions) const {
ParserAction action = GetNextGoldAction(state);
*actions = {action};
}
// Internally counts all next gold actions from the current parser state.
virtual void CountAllNextGoldActions(const ParserState &state) {}
// Returns the number of atomic actions within the specified ParserAction.
virtual int ActionLength(ParserAction action) const { return 1; }
// Returns true if the action is allowed in the given parser state.
virtual bool IsAllowedAction(ParserAction action,
const ParserState &state) const = 0;
// Performs the specified action on a given parser state. The action is not
// saved in the state's history.
virtual void PerformActionWithoutHistory(ParserAction action,
ParserState *state) const = 0;
// Performs the specified action on a given parser state. The action is saved
// in the state's history.
void PerformAction(ParserAction action, ParserState *state) const;
// Returns true if a given state is deterministic.
virtual bool IsDeterministicState(const ParserState &state) const = 0;
// Returns true if no more actions can be applied to a given parser state.
virtual bool IsFinalState(const ParserState &state) const = 0;
// Returns a string representation of a parser action.
virtual string ActionAsString(ParserAction action,
const ParserState &state) const = 0;
// Returns a new transition state that can be used to put additional
// information in a parser state. By specifying if we are in training_mode
// (true) or not (false), we can construct a different transition state
// depending on whether we are training a model or parsing new documents. A
// null return value means we don't need to add anything to the parser state.
virtual ParserTransitionState *NewTransitionState(bool training_mode) const {
return nullptr;
}
// Whether to back off to the best allowable transition rather than the
// default action when the highest scoring action is not allowed. Some
// transition systems do not degrade gracefully to the default action and so
// should return true for this function.
virtual bool BackOffToBestAllowableTransition() const { return false; }
// Whether the system returns multiple gold transitions from a single
// configuration.
virtual bool ReturnsMultipleGoldTransitions() const { return false; }
// Whether the system allows non-projective trees.
virtual bool AllowsNonProjective() const { return false; }
// Action metadata: get pointers to token indices based on meta-information
// about (state, action) pairs. NOTE: the following interface is somewhat
// experimental and may be subject to change. Use with caution and ask
// djweiss@ for details.
// Whether or not the system supports computing metadata about actions.
virtual bool SupportsActionMetaData() const { return false; }
// Get the index of the child that would be created by this action. -1 for
// no child created.
virtual int ChildIndex(const ParserState &state,
const ParserAction &action) const {
return -1;
}
// Get the index of the parent that would gain a new child by this action. -1
// for no parent modified.
virtual int ParentIndex(const ParserState &state,
const ParserAction &action) const {
return -1;
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(ParserTransitionSystem);
};
#define REGISTER_TRANSITION_SYSTEM(type, component) \
REGISTER_CLASS_COMPONENT(ParserTransitionSystem, type, component)
} // namespace syntaxnet
#endif  // SYNTAXNET_PARSER_TRANSITIONS_H_
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "syntaxnet/populate_test_inputs.h"
#include <map>
#include <utility>
#include "gtest/gtest.h"
#include "syntaxnet/utils.h"
#include "syntaxnet/dictionary.pb.h"
#include "syntaxnet/sentence.pb.h"
#include "syntaxnet/task_context.h"
#include "syntaxnet/task_spec.pb.h"
#include "syntaxnet/term_frequency_map.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/record_writer.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
void PopulateTestInputs::CreatorMap::Add(
const string &name, const string &file_format, const string &record_format,
PopulateTestInputs::CreateFile makefile) {
(*this)[name] = [name, file_format, record_format,
makefile](TaskInput *input) {
makefile(AddPart(input, file_format, record_format));
};
}
bool PopulateTestInputs::CreatorMap::Populate(TaskContext *context) const {
return PopulateTestInputs::Populate(*this, context);
}
PopulateTestInputs::CreatorMap PopulateTestInputs::Defaults(
const Sentence &document) {
CreatorMap creators;
creators["category-map"] =
CreateTFMapFromDocumentTokens(document, TokenCategory);
creators["label-map"] = CreateTFMapFromDocumentTokens(document, TokenLabel);
creators["tag-map"] = CreateTFMapFromDocumentTokens(document, TokenTag);
creators["tag-to-category"] = CreateTagToCategoryFromTokens(document);
creators["word-map"] = CreateTFMapFromDocumentTokens(document, TokenWord);
return creators;
}
bool PopulateTestInputs::Populate(
const std::unordered_map<string, Create> &creator_map,
TaskContext *context) {
TaskSpec *spec = context->mutable_spec();
bool found_all_inputs = true;
// Fail if a mandatory input is not found.
auto name_not_found = [&found_all_inputs](TaskInput *input) {
found_all_inputs = false;
};
for (TaskInput &input : *spec->mutable_input()) {
auto it = creator_map.find(input.name());
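// Run the registered creator for this input, or the fallback that records
// the input as missing.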
(it == creator_map.end() ? name_not_found : it->second)(&input);
// Check for compatibility with declared supported formats.
for (const auto &part : input.part()) {
if (!TaskContext::Supports(input, part.file_format(),
part.record_format())) {
LOG(FATAL) << "Input " << input.name()
<< " does not support file of type " << part.file_format()
<< "/" << part.record_format();
}
}
}
return found_all_inputs;
}
PopulateTestInputs::Create PopulateTestInputs::CreateTFMapFromDocumentTokens(
const Sentence &document,
std::function<vector<string>(const Token &)> token2str) {
return [document, token2str](TaskInput *input) {
TermFrequencyMap map;
// Build and write the dummy term frequency map.
for (const Token &token : document.token()) {
vector<string> strings_for_token = token2str(token);
for (const string &s : strings_for_token) map.Increment(s);
}
string file_name = AddPart(input, "text", "");
map.Save(file_name);
};
}
PopulateTestInputs::Create PopulateTestInputs::CreateTagToCategoryFromTokens(
const Sentence &document) {
return [document](TaskInput *input) {
TagToCategoryMap tag_to_category;
for (auto &token : document.token()) {
if (token.has_tag()) {
tag_to_category.SetCategory(token.tag(), token.category());
}
}
const string file_name = AddPart(input, "text", "");
tag_to_category.Save(file_name);
};
}
vector<string> PopulateTestInputs::TokenCategory(const Token &token) {
if (token.has_category()) return {token.category()};
return {};
}
vector<string> PopulateTestInputs::TokenLabel(const Token &token) {
if (token.has_label()) return {token.label()};
return {};
}
vector<string> PopulateTestInputs::TokenTag(const Token &token) {
if (token.has_tag()) return {token.tag()};
return {};
}
vector<string> PopulateTestInputs::TokenWord(const Token &token) {
if (token.has_word()) return {token.word()};
return {};
}
string PopulateTestInputs::AddPart(TaskInput *input, const string &file_format,
const string &record_format) {
string file_name = tensorflow::strings::StrCat(
tensorflow::testing::TmpDir(), "/", CHECK_NOTNULL(input)->name());
auto *part = input->add_part();
part->set_file_pattern(file_name);
part->set_file_format(file_format);
part->set_record_format(record_format);
return file_name;
}
} // namespace syntaxnet
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// A utility for populating the inputs of a task. This knows how to create the
// tag-map, category-map, and label-map inputs, and has hooks to populate other
// kinds of inputs. The expected set of operations is:
//
// Sentence document_for_init = ...;
// TaskContext context;
// context->SetParameter("my_parameter", "true");
// MyDocumentProcessor processor;
// processor.Setup(&context);
// PopulateTestInputs::Defaults(document_for_init).Populate(&context);
// processor.Init(&context);
//
// This will check the inputs requested by the processor's Setup(TaskContext *)
// function, and create files corresponding to them. For example, if the
// processor asked for a "tag-map" input, it will create a TermFrequencyMap,
// populate it with the POS tags found in the Sentence document_for_init, save
// it to disk, and update the TaskContext with the location of the file. By
// convention, the location is the name of the input. Conceptually, the logic
// is very simple:
//
// for (TaskInput &input : context->mutable_spec()->mutable_input()) {
// creators[input.name()](&input);
// // check for missing inputs, incompatible formats, etc...
// }
//
// The Populate() routine will also check compatibility between requested and
// supplied formats. The Default mapping knows how to populate the following
// inputs:
//
// - category-map: TermFrequencyMap containing POS categories.
//
// - label-map: TermFrequencyMap containing parser labels.
//
// - tag-map: TermFrequencyMap containing POS tags.
//
// - tag-to-category: StringToStringMap mapping POS tags to categories.
//
// - word-map: TermFrequencyMap containing words.
//
// Clients can add creation routines by defining a std::function:
//
// auto creators = PopulateTestInputs::Defaults(document_for_init);
// creators["my-input"] = [](TaskInput *input) { ...; }
//
// See also creators.Add() for more convenience functions.
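//
// For example, a hypothetical single-part input named "my-input", stored as
// a plain text file, could be registered with:
//
//   creators.Add("my-input", "text", "",
//                [](const string &path) { /* write test data to path */ });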
#ifndef SYNTAXNET_POPULATE_TEST_INPUTS_H_
#define SYNTAXNET_POPULATE_TEST_INPUTS_H_
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>
#include "syntaxnet/utils.h"
namespace syntaxnet {
class Sentence;
class TaskContext;
class TaskInput;
class TaskOutput;
class Token;
class PopulateTestInputs {
public:
// When called, Create() should populate an input by creating a file and
// adding one or more parts to the TaskInput.
typedef std::function<void(TaskInput *)> Create;
// When called, CreateFile() should create a file resource at the given
// path. These are typically more convenient to write.
typedef std::function<void(const string &)> CreateFile;
// A set of creators, one for each input in a TaskContext.
class CreatorMap : public std::unordered_map<string, Create> {
public:
// A simplified way to add a single-file creator. The name of the file
// location will be file::JoinPath(FLAGS_test_tmpdir, name).
void Add(const string &name, const string &file_format,
const string &record_format, CreateFile makefile);
// Convenience method to populate the inputs in context. Returns true if it
// was possible to populate each input, and false otherwise. If a mandatory
// input does not have a creator, then we LOG(FATAL).
bool Populate(TaskContext *context) const;
};
// Default creator set. This knows how to generate from a given Document
// - category-map
// - label-map
// - tag-map
// - tag-to-category
// - word-map
//
// Note: the default creators capture the document input by value: this means
// that subsequent modifications to the document will NOT be
// reflected in the inputs. However, the following is perfectly valid:
//
// CreatorMap creators;
// {
// Sentence document;
// creators = PopulateTestInputs::Defaults(document);
// }
// creators.Populate(context);
static CreatorMap Defaults(const Sentence &document);
// Populates the TaskContext object from a map of creator functions. Note that
// this static version is compatible with any hash map of the correct type.
static bool Populate(const std::unordered_map<string, Create> &creator_map,
TaskContext *context);
// Helper function for creating a term frequency map from a document. This
// iterates over all the tokens in the document, calls token2str on each
// token, and adds each returned string to the term frequency map. The map is
// then saved to FLAGS_test_tmpdir/name.
static Create CreateTFMapFromDocumentTokens(
const Sentence &document,
std::function<vector<string>(const Token &)> token2str);
// Creates a StringToStringMap protocol buffer input that maps tags to
// categories. Uses whatever mapping is present in the document.
static Create CreateTagToCategoryFromTokens(const Sentence &document);
// Default implementations for "token2str" above.
static vector<string> TokenCategory(const Token &token);
static vector<string> TokenLabel(const Token &token);
static vector<string> TokenTag(const Token &token);
static vector<string> TokenWord(const Token &token);
// Utility function. Sets the TaskInput->part() fields for a new input part.
// Returns the file name.
static string AddPart(TaskInput *input, const string &file_format,
const string &record_format);
};
} // namespace syntaxnet
#endif  // SYNTAXNET_POPULATE_TEST_INPUTS_H_