Unverified Commit 80178fc6 authored by Mark Omernick's avatar Mark Omernick Committed by GitHub
Browse files

Merge pull request #4153 from terryykoo/master

Export @195097388.
parents a84e1ef9 edea2b67
...@@ -250,6 +250,7 @@ class GenericFeatureFunction { ...@@ -250,6 +250,7 @@ class GenericFeatureFunction {
string GetParameter(const string &name) const; string GetParameter(const string &name) const;
int GetIntParameter(const string &name, int default_value) const; int GetIntParameter(const string &name, int default_value) const;
bool GetBoolParameter(const string &name, bool default_value) const; bool GetBoolParameter(const string &name, bool default_value) const;
double GetFloatParameter(const string &name, double default_value) const;
// Returns the FML function description for the feature function, i.e. the // Returns the FML function description for the feature function, i.e. the
// name and parameters without the nested features. // name and parameters without the nested features.
......
...@@ -108,6 +108,10 @@ class FMLParser { ...@@ -108,6 +108,10 @@ class FMLParser {
string item_text_; string item_text_;
}; };
// Returns the |function| or |extractor| descriptor as an FML string.
string AsFML(const FeatureFunctionDescriptor &function);
string AsFML(const FeatureExtractorDescriptor &extractor);
} // namespace syntaxnet } // namespace syntaxnet
#endif // SYNTAXNET_FML_PARSER_H_ #endif // SYNTAXNET_FML_PARSER_H_
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "syntaxnet/fml_parser.h"
#include <string>
#include <vector>
#include "syntaxnet/base.h"
#include "syntaxnet/feature_extractor.pb.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace {
// Returns the list of lines in the |text|. Also strips trailing whitespace
// from each line, since the FML generator sometimes appends trailing spaces.
std::vector<string> LinesOf(const string &text) {
std::vector<string> lines = tensorflow::str_util::Split(
text, "\n", tensorflow::str_util::SkipEmpty());
for (string &line : lines) {
tensorflow::str_util::StripTrailingWhitespace(&line);
}
return lines;
}
// Verifies FML -> descriptor -> FML round-tripping for a single feature
// function, checking the full extractor and each nested sub-function.
TEST(FMLParserTest, RoundTripSingleFunction) {
  const string kExpected = "offset(1).input.token.word(min-freq=\"10\")";
  FeatureExtractorDescriptor extractor;
  FMLParser().Parse("offset(1).input.token.word(min-freq=10)", &extractor);
  EXPECT_EQ(LinesOf(AsFML(extractor)), LinesOf(kExpected));

  // Walk down the chain of nested feature functions; each level of nesting
  // drops one dotted prefix from the FML string.
  const FeatureFunctionDescriptor *function = &extractor.feature(0);
  EXPECT_EQ(AsFML(*function), kExpected);
  function = &function->feature(0);
  EXPECT_EQ(AsFML(*function), "input.token.word(min-freq=\"10\")");
  function = &function->feature(0);
  EXPECT_EQ(AsFML(*function), "token.word(min-freq=\"10\")");
  function = &function->feature(0);
  EXPECT_EQ(AsFML(*function), "word(min-freq=\"10\")");
}
// Verifies FML -> descriptor -> FML round-tripping when the FML contains
// several top-level feature functions, including grouped ({...}) functions.
TEST(FMLParserTest, RoundTripMultipleFunctions) {
  FeatureExtractorDescriptor extractor;
  FMLParser().Parse(R"(offset(1).word(max-num-terms=987)
input { tag(outside=false) label }
pairs { stack.tag input.tag input.child(-1).label })",
                    &extractor);

  // AsFML() quotes every feature option value, so the expected output differs
  // from the parsed input only by the added quotes.
  const string expected =
      "offset(1).word(max-num-terms=\"987\")\n"
      "input { tag(outside=\"false\") label }\n"
      "pairs { stack.tag input.tag input.child(-1).label }";
  EXPECT_EQ(LinesOf(AsFML(extractor)), LinesOf(expected));
}
} // namespace
} // namespace syntaxnet
...@@ -22,6 +22,7 @@ import syntaxnet.load_parser_ops ...@@ -22,6 +22,7 @@ import syntaxnet.load_parser_ops
from tensorflow.python.ops import control_flow_ops as cf from tensorflow.python.ops import control_flow_ops as cf
from tensorflow.python.ops import state_ops from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver as tf_saver
from syntaxnet.ops import gen_parser_ops from syntaxnet.ops import gen_parser_ops
...@@ -572,5 +573,6 @@ class GreedyParser(object): ...@@ -572,5 +573,6 @@ class GreedyParser(object):
for key in variables_to_save.keys(): for key in variables_to_save.keys():
if not key.endswith('avg_var'): if not key.endswith('avg_var'):
del variables_to_save[key] del variables_to_save[key]
self.saver = tf.train.Saver(variables_to_save) self.saver = tf.train.Saver(
variables_to_save, builder=tf_saver.BaseSaverBuilder())
return self.saver return self.saver
...@@ -20,33 +20,26 @@ ...@@ -20,33 +20,26 @@
import os.path import os.path
import tensorflow as tf import tensorflow as tf
from tensorflow.python.framework import test_util
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from syntaxnet import graph_builder from syntaxnet import graph_builder
from syntaxnet import sparse_pb2 from syntaxnet import sparse_pb2
from syntaxnet import test_flags
from syntaxnet.ops import gen_parser_ops from syntaxnet.ops import gen_parser_ops
FLAGS = tf.app.flags.FLAGS
if not hasattr(FLAGS, 'test_srcdir'):
FLAGS.test_srcdir = ''
if not hasattr(FLAGS, 'test_tmpdir'):
FLAGS.test_tmpdir = tf.test.get_temp_dir()
class GraphBuilderTest(tf.test.TestCase):
class GraphBuilderTest(test_util.TensorFlowTestCase):
def setUp(self): def setUp(self):
# Creates a task context with the correct testing paths. # Creates a task context with the correct testing paths.
initial_task_context = os.path.join(FLAGS.test_srcdir, initial_task_context = os.path.join(test_flags.source_root(),
'syntaxnet/' 'syntaxnet/'
'testdata/context.pbtxt') 'testdata/context.pbtxt')
self._task_context = os.path.join(FLAGS.test_tmpdir, 'context.pbtxt') self._task_context = os.path.join(test_flags.temp_dir(), 'context.pbtxt')
with open(initial_task_context, 'r') as fin: with open(initial_task_context, 'r') as fin:
with open(self._task_context, 'w') as fout: with open(self._task_context, 'w') as fout:
fout.write(fin.read().replace('SRCDIR', FLAGS.test_srcdir) fout.write(fin.read().replace('SRCDIR', test_flags.source_root())
.replace('OUTPATH', FLAGS.test_tmpdir)) .replace('OUTPATH', test_flags.temp_dir()))
# Creates necessary term maps. # Creates necessary term maps.
with self.test_session() as sess: with self.test_session() as sess:
...@@ -320,4 +313,4 @@ class GraphBuilderTest(test_util.TensorFlowTestCase): ...@@ -320,4 +313,4 @@ class GraphBuilderTest(test_util.TensorFlowTestCase):
if __name__ == '__main__': if __name__ == '__main__':
googletest.main() tf.test.main()
...@@ -23,16 +23,13 @@ import tensorflow as tf ...@@ -23,16 +23,13 @@ import tensorflow as tf
import syntaxnet.load_parser_ops import syntaxnet.load_parser_ops
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from syntaxnet import sentence_pb2 from syntaxnet import sentence_pb2
from syntaxnet import task_spec_pb2 from syntaxnet import task_spec_pb2
from syntaxnet import test_flags
from syntaxnet.ops import gen_parser_ops from syntaxnet.ops import gen_parser_ops
FLAGS = tf.app.flags.FLAGS
CONLL_DOC1 = u'''1 बात _ n NN _ _ _ _ _ CONLL_DOC1 = u'''1 बात _ n NN _ _ _ _ _
2 गलत _ adj JJ _ _ _ _ _ 2 गलत _ adj JJ _ _ _ _ _
3 हो _ v VM _ _ _ _ _ 3 हो _ v VM _ _ _ _ _
...@@ -75,15 +72,11 @@ CHAR_NGRAMS = u'''^ अ ^ अभ ^ आ ^ आन ^ इ ^ इस $ ^ क ^ ...@@ -75,15 +72,11 @@ CHAR_NGRAMS = u'''^ अ ^ अभ ^ आ ^ आन ^ इ ^ इस $ ^ क ^
COMMENTS = u'# Line with fake comments.' COMMENTS = u'# Line with fake comments.'
class LexiconBuilderTest(test_util.TensorFlowTestCase): class LexiconBuilderTest(tf.test.TestCase):
def setUp(self): def setUp(self):
if not hasattr(FLAGS, 'test_srcdir'): self.corpus_file = os.path.join(test_flags.temp_dir(), 'documents.conll')
FLAGS.test_srcdir = '' self.context_file = os.path.join(test_flags.temp_dir(), 'context.pbtxt')
if not hasattr(FLAGS, 'test_tmpdir'):
FLAGS.test_tmpdir = tf.test.get_temp_dir()
self.corpus_file = os.path.join(FLAGS.test_tmpdir, 'documents.conll')
self.context_file = os.path.join(FLAGS.test_tmpdir, 'context.pbtxt')
def AddInput(self, name, file_pattern, record_format, context): def AddInput(self, name, file_pattern, record_format, context):
inp = context.input.add() inp = context.input.add()
...@@ -106,7 +99,8 @@ class LexiconBuilderTest(test_util.TensorFlowTestCase): ...@@ -106,7 +99,8 @@ class LexiconBuilderTest(test_util.TensorFlowTestCase):
'category-map', 'label-map', 'prefix-table', 'category-map', 'label-map', 'prefix-table',
'suffix-table', 'tag-to-category', 'char-map', 'suffix-table', 'tag-to-category', 'char-map',
'char-ngram-map'): 'char-ngram-map'):
self.AddInput(name, os.path.join(FLAGS.test_tmpdir, name), '', context) self.AddInput(name, os.path.join(test_flags.temp_dir(), name), '',
context)
logging.info('Writing context to: %s', self.context_file) logging.info('Writing context to: %s', self.context_file)
with open(self.context_file, 'w') as f: with open(self.context_file, 'w') as f:
f.write(str(context)) f.write(str(context))
...@@ -140,7 +134,7 @@ class LexiconBuilderTest(test_util.TensorFlowTestCase): ...@@ -140,7 +134,7 @@ class LexiconBuilderTest(test_util.TensorFlowTestCase):
self.assertTrue(last) self.assertTrue(last)
def ValidateTagToCategoryMap(self): def ValidateTagToCategoryMap(self):
with open(os.path.join(FLAGS.test_tmpdir, 'tag-to-category'), 'r') as f: with open(os.path.join(test_flags.temp_dir(), 'tag-to-category'), 'r') as f:
entries = [line.strip().split('\t') for line in f.readlines()] entries = [line.strip().split('\t') for line in f.readlines()]
for tag, category in entries: for tag, category in entries:
self.assertIn(tag, TAGS) self.assertIn(tag, TAGS)
...@@ -148,7 +142,7 @@ class LexiconBuilderTest(test_util.TensorFlowTestCase): ...@@ -148,7 +142,7 @@ class LexiconBuilderTest(test_util.TensorFlowTestCase):
def LoadMap(self, map_name): def LoadMap(self, map_name):
loaded_map = {} loaded_map = {}
with open(os.path.join(FLAGS.test_tmpdir, map_name), 'r') as f: with open(os.path.join(test_flags.temp_dir(), map_name), 'r') as f:
for line in f: for line in f:
entries = line.strip().split(' ') entries = line.strip().split(' ')
if len(entries) >= 2: if len(entries) >= 2:
...@@ -237,4 +231,4 @@ class LexiconBuilderTest(test_util.TensorFlowTestCase): ...@@ -237,4 +231,4 @@ class LexiconBuilderTest(test_util.TensorFlowTestCase):
if __name__ == '__main__': if __name__ == '__main__':
googletest.main() tf.test.main()
...@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "syntaxnet/ops/shape_helpers.h"
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace syntaxnet { namespace syntaxnet {
...@@ -29,6 +31,14 @@ REGISTER_OP("GoldParseReader") ...@@ -29,6 +31,14 @@ REGISTER_OP("GoldParseReader")
.Attr("corpus_name: string='documents'") .Attr("corpus_name: string='documents'")
.Attr("arg_prefix: string='brain_parser'") .Attr("arg_prefix: string='brain_parser'")
.SetIsStateful() .SetIsStateful()
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
int feature_size;
TF_RETURN_IF_ERROR(context->GetAttr("feature_size", &feature_size));
for (int i = 0; i < feature_size; ++i) MatrixOutputShape(i, context);
ScalarOutputShape(feature_size, context);
VectorOutputShape(feature_size + 1, context);
return tensorflow::Status::OK();
})
.Doc(R"doc( .Doc(R"doc(
Reads sentences, parses them, and returns (gold action, feature) pairs. Reads sentences, parses them, and returns (gold action, feature) pairs.
...@@ -55,6 +65,15 @@ REGISTER_OP("DecodedParseReader") ...@@ -55,6 +65,15 @@ REGISTER_OP("DecodedParseReader")
.Attr("corpus_name: string='documents'") .Attr("corpus_name: string='documents'")
.Attr("arg_prefix: string='brain_parser'") .Attr("arg_prefix: string='brain_parser'")
.SetIsStateful() .SetIsStateful()
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
int feature_size;
TF_RETURN_IF_ERROR(context->GetAttr("feature_size", &feature_size));
for (int i = 0; i < feature_size; ++i) MatrixOutputShape(i, context);
ScalarOutputShape(feature_size, context);
context->set_output(feature_size + 1, context->Vector(2));
VectorOutputShape(feature_size + 2, context);
return MatrixInputShape(0, context);
})
.Doc(R"doc( .Doc(R"doc(
Reads sentences and parses them taking parsing transitions based on the Reads sentences and parses them taking parsing transitions based on the
input transition scores. input transition scores.
...@@ -85,6 +104,14 @@ REGISTER_OP("BeamParseReader") ...@@ -85,6 +104,14 @@ REGISTER_OP("BeamParseReader")
.Attr("continue_until_all_final: bool=false") .Attr("continue_until_all_final: bool=false")
.Attr("always_start_new_sentences: bool=false") .Attr("always_start_new_sentences: bool=false")
.SetIsStateful() .SetIsStateful()
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
int feature_size;
TF_RETURN_IF_ERROR(context->GetAttr("feature_size", &feature_size));
for (int i = 0; i < feature_size; ++i) MatrixOutputShape(i, context);
ScalarOutputShape(feature_size, context);
ScalarOutputShape(feature_size + 1, context);
return tensorflow::Status::OK();
})
.Doc(R"doc( .Doc(R"doc(
Reads sentences and creates a beam parser. Reads sentences and creates a beam parser.
...@@ -112,6 +139,15 @@ REGISTER_OP("BeamParser") ...@@ -112,6 +139,15 @@ REGISTER_OP("BeamParser")
.Output("alive: bool") .Output("alive: bool")
.Attr("feature_size: int") .Attr("feature_size: int")
.SetIsStateful() .SetIsStateful()
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
int feature_size;
TF_RETURN_IF_ERROR(context->GetAttr("feature_size", &feature_size));
for (int i = 0; i < feature_size; ++i) MatrixOutputShape(i, context);
ScalarOutputShape(feature_size, context);
VectorOutputShape(feature_size + 1, context);
TF_RETURN_IF_ERROR(ScalarInputShape(0, context));
return MatrixInputShape(1, context);
})
.Doc(R"doc( .Doc(R"doc(
Updates the beam parser based on scores in the input transition scores. Updates the beam parser based on scores in the input transition scores.
...@@ -131,6 +167,13 @@ REGISTER_OP("BeamParserOutput") ...@@ -131,6 +167,13 @@ REGISTER_OP("BeamParserOutput")
.Output("gold_slot: int32") .Output("gold_slot: int32")
.Output("path_scores: float") .Output("path_scores: float")
.SetIsStateful() .SetIsStateful()
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
context->set_output(0, context->Matrix(2, context->UnknownDim()));
context->set_output(1, context->Matrix(2, context->UnknownDim()));
VectorOutputShape(2, context);
VectorOutputShape(3, context);
return ScalarInputShape(0, context);
})
.Doc(R"doc( .Doc(R"doc(
Converts the current state of the beam parser into a set of indices into Converts the current state of the beam parser into a set of indices into
the scoring matrices that lead there. the scoring matrices that lead there.
...@@ -152,6 +195,11 @@ REGISTER_OP("BeamEvalOutput") ...@@ -152,6 +195,11 @@ REGISTER_OP("BeamEvalOutput")
.Output("eval_metrics: int32") .Output("eval_metrics: int32")
.Output("documents: string") .Output("documents: string")
.SetIsStateful() .SetIsStateful()
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
context->set_output(0, context->Vector(2));
VectorOutputShape(1, context);
return ScalarInputShape(0, context);
})
.Doc(R"doc( .Doc(R"doc(
Computes eval metrics for the best paths in the input beams. Computes eval metrics for the best paths in the input beams.
...@@ -192,6 +240,13 @@ REGISTER_OP("FeatureSize") ...@@ -192,6 +240,13 @@ REGISTER_OP("FeatureSize")
.Output("embedding_dims: int32") .Output("embedding_dims: int32")
.Output("num_actions: int32") .Output("num_actions: int32")
.Attr("arg_prefix: string='brain_parser'") .Attr("arg_prefix: string='brain_parser'")
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
VectorOutputShape(0, context);
VectorOutputShape(1, context);
VectorOutputShape(2, context);
ScalarOutputShape(3, context);
return tensorflow::Status::OK();
})
.Doc(R"doc( .Doc(R"doc(
An op that returns the number and domain sizes of parser features. An op that returns the number and domain sizes of parser features.
...@@ -210,6 +265,10 @@ REGISTER_OP("FeatureVocab") ...@@ -210,6 +265,10 @@ REGISTER_OP("FeatureVocab")
.Attr("arg_prefix: string='brain_parser'") .Attr("arg_prefix: string='brain_parser'")
.Attr("embedding_name: string='words'") .Attr("embedding_name: string='words'")
.Output("vocab: string") .Output("vocab: string")
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
VectorOutputShape(0, context);
return tensorflow::Status::OK();
})
.Doc(R"doc( .Doc(R"doc(
Returns the vocabulary of the parser features for a particular named channel. Returns the vocabulary of the parser features for a particular named channel.
For "words" this would would be the entire vocabulary, plus any special tokens For "words" this would would be the entire vocabulary, plus any special tokens
...@@ -227,6 +286,12 @@ REGISTER_OP("UnpackSyntaxNetSparseFeatures") ...@@ -227,6 +286,12 @@ REGISTER_OP("UnpackSyntaxNetSparseFeatures")
.Output("indices: int32") .Output("indices: int32")
.Output("ids: int64") .Output("ids: int64")
.Output("weights: float") .Output("weights: float")
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
VectorOutputShape(0, context);
VectorOutputShape(1, context);
VectorOutputShape(2, context);
return VectorInputShape(0, context);
})
.Doc(R"doc( .Doc(R"doc(
Converts a vector of strings with SparseFeatures to tensors. Converts a vector of strings with SparseFeatures to tensors.
...@@ -249,11 +314,16 @@ REGISTER_OP("WordEmbeddingInitializer") ...@@ -249,11 +314,16 @@ REGISTER_OP("WordEmbeddingInitializer")
.Attr("vectors: string") .Attr("vectors: string")
.Attr("task_context: string = ''") .Attr("task_context: string = ''")
.Attr("vocabulary: string = ''") .Attr("vocabulary: string = ''")
.Attr("override_num_embeddings: int = -1")
.Attr("cache_vectors_locally: bool = true") .Attr("cache_vectors_locally: bool = true")
.Attr("num_special_embeddings: int = 3") .Attr("num_special_embeddings: int = 3")
.Attr("embedding_init: float = 1.0") .Attr("embedding_init: float = 1.0")
.Attr("seed: int = 0") .Attr("seed: int = 0")
.Attr("seed2: int = 0") .Attr("seed2: int = 0")
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
MatrixOutputShape(0, context);
return tensorflow::Status::OK();
})
.Doc(R"doc( .Doc(R"doc(
Reads word embeddings from an sstable of dist_belief.TokenEmbedding protos for Reads word embeddings from an sstable of dist_belief.TokenEmbedding protos for
every word specified in a text vocabulary file. every word specified in a text vocabulary file.
...@@ -264,6 +334,10 @@ task_context: file path at which to read the task context, for its "word-map" ...@@ -264,6 +334,10 @@ task_context: file path at which to read the task context, for its "word-map"
input. Exactly one of `task_context` or `vocabulary` must be specified. input. Exactly one of `task_context` or `vocabulary` must be specified.
vocabulary: path to vocabulary file, which contains one unique word per line, in vocabulary: path to vocabulary file, which contains one unique word per line, in
order. Exactly one of `task_context` or `vocabulary` must be specified. order. Exactly one of `task_context` or `vocabulary` must be specified.
override_num_embeddings: Number of rows in the returned embedding matrix. If
override_num_embeddings is larger than 0, then the returned embedding matrix
has override_num_embeddings_ rows. Otherwise, the number of rows of the
returned embedding matrix is |vocabulary| + num_special_embeddings.
cache_vectors_locally: Whether to cache the vectors file to a local temp file cache_vectors_locally: Whether to cache the vectors file to a local temp file
before parsing it. This greatly reduces initialization time when the vectors before parsing it. This greatly reduces initialization time when the vectors
are stored remotely, but requires that "/tmp" has sufficient space. are stored remotely, but requires that "/tmp" has sufficient space.
...@@ -286,6 +360,11 @@ REGISTER_OP("DocumentSource") ...@@ -286,6 +360,11 @@ REGISTER_OP("DocumentSource")
.Attr("corpus_name: string='documents'") .Attr("corpus_name: string='documents'")
.Attr("batch_size: int") .Attr("batch_size: int")
.SetIsStateful() .SetIsStateful()
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
VectorOutputShape(0, context);
ScalarOutputShape(1, context);
return tensorflow::Status::OK();
})
.Doc(R"doc( .Doc(R"doc(
Reads documents from documents_path and outputs them. Reads documents from documents_path and outputs them.
...@@ -301,6 +380,9 @@ REGISTER_OP("DocumentSink") ...@@ -301,6 +380,9 @@ REGISTER_OP("DocumentSink")
.Attr("task_context: string=''") .Attr("task_context: string=''")
.Attr("task_context_str: string=''") .Attr("task_context_str: string=''")
.Attr("corpus_name: string='documents'") .Attr("corpus_name: string='documents'")
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
return VectorInputShape(0, context);
})
.Doc(R"doc( .Doc(R"doc(
Write documents to documents_path. Write documents to documents_path.
...@@ -312,6 +394,10 @@ task_context_str: a task context in text format, used if task_context is empty. ...@@ -312,6 +394,10 @@ task_context_str: a task context in text format, used if task_context is empty.
REGISTER_OP("SegmenterTrainingDataConstructor") REGISTER_OP("SegmenterTrainingDataConstructor")
.Input("documents: string") .Input("documents: string")
.Output("char_doc: string") .Output("char_doc: string")
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
VectorOutputShape(0, context);
return VectorInputShape(0, context);
})
.Doc(R"doc( .Doc(R"doc(
Constructs segmentation training data from documents with gold segmentation. Constructs segmentation training data from documents with gold segmentation.
...@@ -322,6 +408,10 @@ char_doc: a vector of documents as serialized protos. ...@@ -322,6 +408,10 @@ char_doc: a vector of documents as serialized protos.
REGISTER_OP("CharTokenGenerator") REGISTER_OP("CharTokenGenerator")
.Input("documents: string") .Input("documents: string")
.Output("char_doc: string") .Output("char_doc: string")
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
VectorOutputShape(0, context);
return VectorInputShape(0, context);
})
.Doc(R"doc( .Doc(R"doc(
Converts token field of the input documents such that each token in the Converts token field of the input documents such that each token in the
output doc is a utf-8 character from that doc's text. output doc is a utf-8 character from that doc's text.
...@@ -337,6 +427,10 @@ REGISTER_OP("WellFormedFilter") ...@@ -337,6 +427,10 @@ REGISTER_OP("WellFormedFilter")
.Attr("task_context_str: string=''") .Attr("task_context_str: string=''")
.Attr("corpus_name: string='documents'") .Attr("corpus_name: string='documents'")
.Attr("keep_malformed_documents: bool = False") .Attr("keep_malformed_documents: bool = False")
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
VectorOutputShape(0, context);
return VectorInputShape(0, context);
})
.Doc(R"doc( .Doc(R"doc(
Removes sentences with malformed parse trees, i.e. they contain cycles. Removes sentences with malformed parse trees, i.e. they contain cycles.
...@@ -353,6 +447,10 @@ REGISTER_OP("ProjectivizeFilter") ...@@ -353,6 +447,10 @@ REGISTER_OP("ProjectivizeFilter")
.Attr("task_context_str: string=''") .Attr("task_context_str: string=''")
.Attr("corpus_name: string='documents'") .Attr("corpus_name: string='documents'")
.Attr("discard_non_projective: bool = False") .Attr("discard_non_projective: bool = False")
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
VectorOutputShape(0, context);
return VectorInputShape(0, context);
})
.Doc(R"doc( .Doc(R"doc(
Modifies input parse trees to make them projective. Modifies input parse trees to make them projective.
......
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Shape inference functions for SyntaxNet ops.
#ifndef SYNTAXNET_OPS_SHAPE_HELPERS_H_
#define SYNTAXNET_OPS_SHAPE_HELPERS_H_
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
// Checks that the |input_index|'th input tensor has exactly the given |rank|;
// the individual dimensions may be unknown.  Returns non-OK on rank mismatch.
inline tensorflow::Status TensorInputShape(
    int input_index, int rank,
    tensorflow::shape_inference::InferenceContext *context) {
  const tensorflow::shape_inference::ShapeHandle input_shape =
      context->input(input_index);
  tensorflow::shape_inference::ShapeHandle ignored;
  return context->WithRank(input_shape, rank, &ignored);
}
// Checks that the |input_index|'th input is a scalar (rank-0 tensor).
inline tensorflow::Status ScalarInputShape(
    int input_index, tensorflow::shape_inference::InferenceContext *context) {
  return TensorInputShape(input_index, /*rank=*/0, context);
}
// Checks that the |input_index|'th input is a vector (rank-1 tensor) whose
// length may be unknown.
inline tensorflow::Status VectorInputShape(
    int input_index, tensorflow::shape_inference::InferenceContext *context) {
  return TensorInputShape(input_index, /*rank=*/1, context);
}
// Checks that the |input_index|'th input is a matrix (rank-2 tensor) whose
// dimensions may be unknown.
inline tensorflow::Status MatrixInputShape(
    int input_index, tensorflow::shape_inference::InferenceContext *context) {
  return TensorInputShape(input_index, /*rank=*/2, context);
}
// Declares the |output_index|'th output to be a scalar.
inline void ScalarOutputShape(
    int output_index, tensorflow::shape_inference::InferenceContext *context) {
  const auto shape = context->Scalar();
  context->set_output(output_index, shape);
}
// Declares the |output_index|'th output to be a vector of unknown length.
inline void VectorOutputShape(
    int output_index, tensorflow::shape_inference::InferenceContext *context) {
  const auto shape = context->UnknownShapeOfRank(1);
  context->set_output(output_index, shape);
}
// Declares the |output_index|'th output to be a matrix with both dimensions
// unknown.
inline void MatrixOutputShape(
    int output_index, tensorflow::shape_inference::InferenceContext *context) {
  const auto shape = context->UnknownShapeOfRank(2);
  context->set_output(output_index, shape);
}
} // namespace syntaxnet
#endif // SYNTAXNET_OPS_SHAPE_HELPERS_H_
...@@ -19,6 +19,8 @@ ...@@ -19,6 +19,8 @@
import os import os
import os.path import os.path
import time import time
from absl import app
from absl import flags
import tempfile import tempfile
import tensorflow as tf import tensorflow as tf
...@@ -33,7 +35,6 @@ from syntaxnet import structured_graph_builder ...@@ -33,7 +35,6 @@ from syntaxnet import structured_graph_builder
from syntaxnet.ops import gen_parser_ops from syntaxnet.ops import gen_parser_ops
from syntaxnet import task_spec_pb2 from syntaxnet import task_spec_pb2
flags = tf.app.flags
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
...@@ -158,4 +159,4 @@ def main(unused_argv): ...@@ -158,4 +159,4 @@ def main(unused_argv):
if __name__ == '__main__': if __name__ == '__main__':
tf.app.run() app.run(main)
...@@ -331,24 +331,6 @@ class LastActionFeatureFunction : public ParserFeatureFunction { ...@@ -331,24 +331,6 @@ class LastActionFeatureFunction : public ParserFeatureFunction {
REGISTER_PARSER_FEATURE_FUNCTION("last-action", LastActionFeatureFunction); REGISTER_PARSER_FEATURE_FUNCTION("last-action", LastActionFeatureFunction);
// Parser feature that always emits a single fixed value, configured via the
// "value" option (default 0).
class Constant : public ParserFeatureFunction {
public:
// Reads the "value" option and registers a numeric feature type whose
// cardinality is value + 1, so the emitted value is always in range.
void Init(TaskContext *context) override {
value_ = this->GetIntParameter("value", 0);
this->set_feature_type(new NumericFeatureType(this->name(), value_ + 1));
}
// Returns the constant's value, ignoring the parser state and workspaces.
FeatureValue Compute(const WorkspaceSet &workspaces, const ParserState &state,
const FeatureVector *result) const override {
return value_;
}
private:
// The constant value to emit; set from the "value" option in Init().
int value_ = 0;
};
REGISTER_PARSER_FEATURE_FUNCTION("constant", Constant);
// Register the generic parser features. // Register the generic parser features.
typedef GenericFeatures<ParserState> GenericParserFeature; typedef GenericFeatures<ParserState> GenericParserFeature;
REGISTER_SYNTAXNET_GENERIC_FEATURES(GenericParserFeature); REGISTER_SYNTAXNET_GENERIC_FEATURES(GenericParserFeature);
......
...@@ -20,6 +20,8 @@ ...@@ -20,6 +20,8 @@
import os import os
import os.path import os.path
import time import time
from absl import app
from absl import flags
import tensorflow as tf import tensorflow as tf
from tensorflow.python.platform import gfile from tensorflow.python.platform import gfile
...@@ -32,7 +34,6 @@ from syntaxnet import structured_graph_builder ...@@ -32,7 +34,6 @@ from syntaxnet import structured_graph_builder
from syntaxnet.ops import gen_parser_ops from syntaxnet.ops import gen_parser_ops
from syntaxnet import task_spec_pb2 from syntaxnet import task_spec_pb2
flags = tf.app.flags
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
flags.DEFINE_string('tf_master', '', flags.DEFINE_string('tf_master', '',
...@@ -299,4 +300,4 @@ def main(unused_argv): ...@@ -299,4 +300,4 @@ def main(unused_argv):
if __name__ == '__main__': if __name__ == '__main__':
tf.app.run() app.run(main)
...@@ -453,6 +453,8 @@ class WordEmbeddingInitializer : public OpKernel { ...@@ -453,6 +453,8 @@ class WordEmbeddingInitializer : public OpKernel {
&cache_vectors_locally_)); &cache_vectors_locally_));
OP_REQUIRES_OK(context, context->GetAttr("num_special_embeddings", OP_REQUIRES_OK(context, context->GetAttr("num_special_embeddings",
&num_special_embeddings_)); &num_special_embeddings_));
OP_REQUIRES_OK(context, context->GetAttr("override_num_embeddings",
&override_num_embeddings_));
OP_REQUIRES_OK(context, OP_REQUIRES_OK(context,
context->GetAttr("embedding_init", &embedding_init_)); context->GetAttr("embedding_init", &embedding_init_));
...@@ -569,7 +571,13 @@ class WordEmbeddingInitializer : public OpKernel { ...@@ -569,7 +571,13 @@ class WordEmbeddingInitializer : public OpKernel {
const std::unordered_map<string, int64> &vocabulary, const std::unordered_map<string, int64> &vocabulary,
const TokenEmbedding &embedding, OpKernelContext *context, const TokenEmbedding &embedding, OpKernelContext *context,
Tensor **embedding_matrix) const { Tensor **embedding_matrix) const {
const int rows = vocabulary.size() + num_special_embeddings_; const int rows = override_num_embeddings_ > 0 ? override_num_embeddings_ :
(vocabulary.size() + num_special_embeddings_);
if (rows < vocabulary.size()) {
return InvalidArgument(
"Embedding matrix row number ", rows,
" is less than vocabulary size ", vocabulary.size());
}
const int columns = embedding.vector().values_size(); const int columns = embedding.vector().values_size();
TF_RETURN_IF_ERROR(context->allocate_output(0, TensorShape({rows, columns}), TF_RETURN_IF_ERROR(context->allocate_output(0, TensorShape({rows, columns}),
embedding_matrix)); embedding_matrix));
...@@ -637,6 +645,11 @@ class WordEmbeddingInitializer : public OpKernel { ...@@ -637,6 +645,11 @@ class WordEmbeddingInitializer : public OpKernel {
// Number of special embeddings to allocate. // Number of special embeddings to allocate.
int num_special_embeddings_ = 3; int num_special_embeddings_ = 3;
// If override_num_embeddings_ is larger than zero, then the returned
// embedding matrix has override_num_embeddings_ of rows. Otherwise, the
// number of rows equals to |vocabulary| + num_special_embeddigs_.
int override_num_embeddings_ = -1;
// Seed for random initialization. // Seed for random initialization.
uint64 seed_ = 0; uint64 seed_ = 0;
......
...@@ -20,35 +20,27 @@ import os.path ...@@ -20,35 +20,27 @@ import os.path
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from syntaxnet import dictionary_pb2 from syntaxnet import dictionary_pb2
from syntaxnet import graph_builder from syntaxnet import graph_builder
from syntaxnet import sparse_pb2 from syntaxnet import sparse_pb2
from syntaxnet import test_flags
from syntaxnet.ops import gen_parser_ops from syntaxnet.ops import gen_parser_ops
FLAGS = tf.app.flags.FLAGS class ParsingReaderOpsTest(tf.test.TestCase):
if not hasattr(FLAGS, 'test_srcdir'):
FLAGS.test_srcdir = ''
if not hasattr(FLAGS, 'test_tmpdir'):
FLAGS.test_tmpdir = tf.test.get_temp_dir()
class ParsingReaderOpsTest(test_util.TensorFlowTestCase):
def setUp(self): def setUp(self):
# Creates a task context with the correct testing paths. # Creates a task context with the correct testing paths.
initial_task_context = os.path.join(FLAGS.test_srcdir, initial_task_context = os.path.join(test_flags.source_root(),
'syntaxnet/' 'syntaxnet/'
'testdata/context.pbtxt') 'testdata/context.pbtxt')
self._task_context = os.path.join(FLAGS.test_tmpdir, 'context.pbtxt') self._task_context = os.path.join(test_flags.temp_dir(), 'context.pbtxt')
with open(initial_task_context, 'r') as fin: with open(initial_task_context, 'r') as fin:
with open(self._task_context, 'w') as fout: with open(self._task_context, 'w') as fout:
fout.write(fin.read().replace('SRCDIR', FLAGS.test_srcdir) fout.write(fin.read().replace('SRCDIR', test_flags.source_root())
.replace('OUTPATH', FLAGS.test_tmpdir)) .replace('OUTPATH', test_flags.temp_dir()))
# Creates necessary term maps. # Creates necessary term maps.
with self.test_session() as sess: with self.test_session() as sess:
...@@ -175,7 +167,7 @@ class ParsingReaderOpsTest(test_util.TensorFlowTestCase): ...@@ -175,7 +167,7 @@ class ParsingReaderOpsTest(test_util.TensorFlowTestCase):
def testWordEmbeddingInitializer(self): def testWordEmbeddingInitializer(self):
# Provide embeddings for the first three words in the word map. # Provide embeddings for the first three words in the word map.
records_path = os.path.join(FLAGS.test_tmpdir, 'records1') records_path = os.path.join(test_flags.temp_dir(), 'records1')
writer = tf.python_io.TFRecordWriter(records_path) writer = tf.python_io.TFRecordWriter(records_path)
writer.write(self._token_embedding('.', [1, 2])) writer.write(self._token_embedding('.', [1, 2]))
writer.write(self._token_embedding(',', [3, 4])) writer.write(self._token_embedding(',', [3, 4]))
...@@ -193,7 +185,7 @@ class ParsingReaderOpsTest(test_util.TensorFlowTestCase): ...@@ -193,7 +185,7 @@ class ParsingReaderOpsTest(test_util.TensorFlowTestCase):
embeddings[:3,]) embeddings[:3,])
def testWordEmbeddingInitializerRepeatability(self): def testWordEmbeddingInitializerRepeatability(self):
records_path = os.path.join(FLAGS.test_tmpdir, 'records2') records_path = os.path.join(test_flags.temp_dir(), 'records2')
writer = tf.python_io.TFRecordWriter(records_path) writer = tf.python_io.TFRecordWriter(records_path)
writer.write(self._token_embedding('.', [1, 2, 3])) # 3 dims writer.write(self._token_embedding('.', [1, 2, 3])) # 3 dims
del writer del writer
...@@ -234,7 +226,7 @@ class ParsingReaderOpsTest(test_util.TensorFlowTestCase): ...@@ -234,7 +226,7 @@ class ParsingReaderOpsTest(test_util.TensorFlowTestCase):
vocabulary='/dev/null').eval() vocabulary='/dev/null').eval()
def testWordEmbeddingInitializerVocabularyFile(self): def testWordEmbeddingInitializerVocabularyFile(self):
records_path = os.path.join(FLAGS.test_tmpdir, 'records3') records_path = os.path.join(test_flags.temp_dir(), 'records3')
writer = tf.python_io.TFRecordWriter(records_path) writer = tf.python_io.TFRecordWriter(records_path)
writer.write(self._token_embedding('a', [1, 2, 3])) writer.write(self._token_embedding('a', [1, 2, 3]))
writer.write(self._token_embedding('b', [2, 3, 4])) writer.write(self._token_embedding('b', [2, 3, 4]))
...@@ -243,7 +235,7 @@ class ParsingReaderOpsTest(test_util.TensorFlowTestCase): ...@@ -243,7 +235,7 @@ class ParsingReaderOpsTest(test_util.TensorFlowTestCase):
writer.write(self._token_embedding('e', [5, 6, 7])) writer.write(self._token_embedding('e', [5, 6, 7]))
del writer del writer
vocabulary_path = os.path.join(FLAGS.test_tmpdir, 'vocabulary3') vocabulary_path = os.path.join(test_flags.temp_dir(), 'vocabulary3')
with open(vocabulary_path, 'w') as vocabulary_file: with open(vocabulary_path, 'w') as vocabulary_file:
vocabulary_file.write('a\nc\ne\nx\n') # 'x' not in pretrained embeddings vocabulary_file.write('a\nc\ne\nx\n') # 'x' not in pretrained embeddings
...@@ -271,8 +263,50 @@ class ParsingReaderOpsTest(test_util.TensorFlowTestCase): ...@@ -271,8 +263,50 @@ class ParsingReaderOpsTest(test_util.TensorFlowTestCase):
[5.0 / norm_e, 6.0 / norm_e, 7.0 / norm_e]], [5.0 / norm_e, 6.0 / norm_e, 7.0 / norm_e]],
embeddings[:3].eval()) embeddings[:3].eval())
def testWordEmbeddingInitializerPresetRowNumber(self):
records_path = os.path.join(test_flags.temp_dir(), 'records3')
writer = tf.python_io.TFRecordWriter(records_path)
writer.write(self._token_embedding('a', [1, 2, 3]))
writer.write(self._token_embedding('b', [2, 3, 4]))
writer.write(self._token_embedding('c', [3, 4, 5]))
writer.write(self._token_embedding('d', [4, 5, 6]))
writer.write(self._token_embedding('e', [5, 6, 7]))
del writer
vocabulary_path = os.path.join(test_flags.temp_dir(), 'vocabulary3')
with open(vocabulary_path, 'w') as vocabulary_file:
vocabulary_file.write('a\nc\ne\nx\n') # 'x' not in pretrained embeddings
# Enumerate a variety of configurations.
for cache_vectors_locally in [False, True]:
for num_special_embeddings in [None, 1, 2, 5]: # None = use default of 3
for override_num_embeddings in [-1, 8, 10]:
with self.test_session():
embeddings = gen_parser_ops.word_embedding_initializer(
vectors=records_path,
vocabulary=vocabulary_path,
override_num_embeddings=override_num_embeddings,
cache_vectors_locally=cache_vectors_locally,
num_special_embeddings=num_special_embeddings)
# Expect 4 embeddings from the vocabulary plus special embeddings.
expected_num_embeddings = 4 + (num_special_embeddings or 3)
if override_num_embeddings > 0:
expected_num_embeddings = override_num_embeddings
self.assertAllEqual([expected_num_embeddings, 3],
tf.shape(embeddings).eval())
# The first 3 embeddings should be pretrained.
norm_a = (1.0 + 4.0 + 9.0)**0.5
norm_c = (9.0 + 16.0 + 25.0)**0.5
norm_e = (25.0 + 36.0 + 49.0)**0.5
self.assertAllClose([[1.0 / norm_a, 2.0 / norm_a, 3.0 / norm_a], [
3.0 / norm_c, 4.0 / norm_c, 5.0 / norm_c
], [5.0 / norm_e, 6.0 / norm_e, 7.0 / norm_e]],
embeddings[:3].eval())
def testWordEmbeddingInitializerVocabularyFileWithDuplicates(self): def testWordEmbeddingInitializerVocabularyFileWithDuplicates(self):
records_path = os.path.join(FLAGS.test_tmpdir, 'records4') records_path = os.path.join(test_flags.temp_dir(), 'records4')
writer = tf.python_io.TFRecordWriter(records_path) writer = tf.python_io.TFRecordWriter(records_path)
writer.write(self._token_embedding('a', [1, 2, 3])) writer.write(self._token_embedding('a', [1, 2, 3]))
writer.write(self._token_embedding('b', [2, 3, 4])) writer.write(self._token_embedding('b', [2, 3, 4]))
...@@ -281,7 +315,7 @@ class ParsingReaderOpsTest(test_util.TensorFlowTestCase): ...@@ -281,7 +315,7 @@ class ParsingReaderOpsTest(test_util.TensorFlowTestCase):
writer.write(self._token_embedding('e', [5, 6, 7])) writer.write(self._token_embedding('e', [5, 6, 7]))
del writer del writer
vocabulary_path = os.path.join(FLAGS.test_tmpdir, 'vocabulary4') vocabulary_path = os.path.join(test_flags.temp_dir(), 'vocabulary4')
with open(vocabulary_path, 'w') as vocabulary_file: with open(vocabulary_path, 'w') as vocabulary_file:
vocabulary_file.write('a\nc\ne\nx\ny\nx') # 'x' duplicated vocabulary_file.write('a\nc\ne\nx\ny\nx') # 'x' duplicated
...@@ -292,4 +326,4 @@ class ParsingReaderOpsTest(test_util.TensorFlowTestCase): ...@@ -292,4 +326,4 @@ class ParsingReaderOpsTest(test_util.TensorFlowTestCase):
if __name__ == '__main__': if __name__ == '__main__':
googletest.main() tf.test.main()
...@@ -15,6 +15,12 @@ limitations under the License. ...@@ -15,6 +15,12 @@ limitations under the License.
#include "syntaxnet/registry.h" #include "syntaxnet/registry.h"
#include <set>
#include <string>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace syntaxnet { namespace syntaxnet {
// Global list of all component registries. // Global list of all component registries.
...@@ -25,4 +31,35 @@ void RegistryMetadata::Register(RegistryMetadata *registry) { ...@@ -25,4 +31,35 @@ void RegistryMetadata::Register(RegistryMetadata *registry) {
global_registry_list = registry; global_registry_list = registry;
} }
string ComponentMetadata::DebugString() const {
return tensorflow::strings::StrCat("Registered '", name_, "' as class ",
class_name_, " at ", file_, ":", line_);
}
tensorflow::Status RegistryMetadata::Validate() {
static const tensorflow::Status *const status =
new tensorflow::Status(ValidateImpl());
return *status;
}
tensorflow::Status RegistryMetadata::ValidateImpl() {
// Iterates over the registries for each type.
for (RegistryMetadata *registry = global_registry_list; registry != nullptr;
registry = static_cast<RegistryMetadata *>(registry->link())) {
std::set<string> names;
// Searches for duplicate names within each component registry.
for (ComponentMetadata *component = *(registry->components_);
component != nullptr; component = component->link()) {
if (!names.insert(component->name()).second) {
return tensorflow::errors::InvalidArgument(
"Multiple classes named '", component->name(),
"' have been registered as ", registry->name(), ": ",
component->DebugString());
}
}
}
return tensorflow::Status::OK();
}
} // namespace syntaxnet } // namespace syntaxnet
...@@ -54,10 +54,13 @@ limitations under the License. ...@@ -54,10 +54,13 @@ limitations under the License.
#define SYNTAXNET_REGISTRY_H_ #define SYNTAXNET_REGISTRY_H_
#include <string.h> #include <string.h>
#include <memory>
#include <string> #include <string>
#include <vector>
#include "syntaxnet/utils.h" #include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
namespace syntaxnet { namespace syntaxnet {
...@@ -75,6 +78,9 @@ class ComponentMetadata { ...@@ -75,6 +78,9 @@ class ComponentMetadata {
// Returns component name. // Returns component name.
const char *name() const { return name_; } const char *name() const { return name_; }
// Returns a human-readable description of this.
string DebugString() const;
// Metadata objects can be linked in a list. // Metadata objects can be linked in a list.
ComponentMetadata *link() const { return link_; } ComponentMetadata *link() const { return link_; }
void set_link(ComponentMetadata *link) { link_ = link; } void set_link(ComponentMetadata *link) { link_ = link; }
...@@ -107,7 +113,16 @@ class RegistryMetadata : public ComponentMetadata { ...@@ -107,7 +113,16 @@ class RegistryMetadata : public ComponentMetadata {
// Registers a component registry in the master registry. // Registers a component registry in the master registry.
static void Register(RegistryMetadata *registry); static void Register(RegistryMetadata *registry);
// Validates the registry; returns non-OK if there are duplicate component
// names of the same type. Situations where this can happen include accidental
// class name collisions, and linking in two different multiarch versions
// of the same component. Repeated calls uses the original result.
static tensorflow::Status Validate();
private: private:
// Implementation for validating the registry.
static tensorflow::Status ValidateImpl();
// Location of list of components in registry. // Location of list of components in registry.
ComponentMetadata **components_; ComponentMetadata **components_;
}; };
...@@ -157,14 +172,21 @@ struct ComponentRegistry { ...@@ -157,14 +172,21 @@ struct ComponentRegistry {
T *object_; T *object_;
}; };
// Finds registrar for named component in registry. // Finds registrar for named component in registry, returning null if not
const Registrar *GetComponent(const char *type) const { // found.
const Registrar *GetComponentOrNull(const char *type) const {
Registrar *r = components; Registrar *r = components;
while (r != nullptr && strcmp(type, r->type()) != 0) r = r->next(); while (r != nullptr && strcmp(type, r->type()) != 0) r = r->next();
if (r == nullptr) { return r;
}
// Finds registrar for named component in registry, raising errors on failure.
const Registrar *GetComponent(const char *type) const {
const Registrar *result = GetComponentOrNull(type);
if (result == nullptr) {
LOG(FATAL) << "Unknown " << name << " component: '" << type << "'."; LOG(FATAL) << "Unknown " << name << " component: '" << type << "'.";
} }
return r; return result;
} }
// Finds a named component in the registry. // Finds a named component in the registry.
...@@ -196,7 +218,24 @@ class RegisterableClass { ...@@ -196,7 +218,24 @@ class RegisterableClass {
typedef ComponentRegistry<Factory> Registry; typedef ComponentRegistry<Factory> Registry;
// Creates a new component instance. // Creates a new component instance.
static T *Create(const string &type) { return registry()->Lookup(type)(); } static T *Create(const string &type) {
TF_CHECK_OK(syntaxnet::RegistryMetadata::Validate());
return registry()->Lookup(type)();
}
static tensorflow::Status CreateOrError(const string &type,
std::unique_ptr<T> *result) {
TF_RETURN_IF_ERROR(syntaxnet::RegistryMetadata::Validate());
const typename Registry::Registrar *component =
registry()->GetComponentOrNull(type.c_str());
if (component == nullptr) {
return tensorflow::errors::NotFound("Unknown ", registry()->name, ": ",
type);
} else {
result->reset(component->object()());
return tensorflow::Status::OK();
}
}
// Returns registry for class. // Returns registry for class.
static Registry *registry() { return &registry_; } static Registry *registry() { return &registry_; }
......
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "syntaxnet/registry.h"
#include <memory>
#include "dragnn/core/test/generic.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
class ThingDoer : public RegisterableClass<ThingDoer> {};
DECLARE_SYNTAXNET_CLASS_REGISTRY("Thing doer", ThingDoer);
REGISTER_SYNTAXNET_CLASS_REGISTRY("Thing doer", ThingDoer);
class Foo : public ThingDoer {};
class Bar : public ThingDoer {};
class Bar2 : public ThingDoer {};
REGISTER_SYNTAXNET_CLASS_COMPONENT(ThingDoer, "foo", Foo);
REGISTER_SYNTAXNET_CLASS_COMPONENT(ThingDoer, "bar", Bar);
#if DRAGNN_REGISTRY_TEST_WITH_DUPLICATE
REGISTER_SYNTAXNET_CLASS_COMPONENT(ThingDoer, "bar", Bar2); // bad
constexpr char kDuplicateError[] =
"Multiple classes named 'bar' have been registered as Thing doer";
#endif
namespace {
#if !DRAGNN_REGISTRY_TEST_WITH_DUPLICATE
// Tests that CreateOrError() is successful for a properly registered component.
TEST(RegistryTest, CreateOrErrorSuccess) {
std::unique_ptr<ThingDoer> object;
TF_ASSERT_OK(ThingDoer::CreateOrError("foo", &object));
ASSERT_NE(object, nullptr);
}
#else
// Tests that CreateOrError() fails if the registry is misconfigured.
TEST(RegistryTest, CreateOrErrorFailure) {
std::unique_ptr<ThingDoer> object;
EXPECT_THAT(ThingDoer::CreateOrError("bar", &object),
test::IsErrorWithSubstr(kDuplicateError));
ASSERT_EQ(object, nullptr);
// Any call to Create has the same error.
EXPECT_THAT(ThingDoer::CreateOrError("foo", &object),
test::IsErrorWithSubstr(kDuplicateError));
}
// Tests that Create() dies if the registry is misconfigured.
TEST(RegistryTest, CreateFailure) {
EXPECT_DEATH(ThingDoer::Create("bar"), kDuplicateError);
}
#endif
// Tests that CreateOrError() returns error if the component is unknown.
TEST(RegistryTest, CreateOrErrorUnknown) {
std::unique_ptr<ThingDoer> object;
EXPECT_FALSE(ThingDoer::CreateOrError("unknown", &object).ok());
}
// Tests that Validate() returns OK only when the registry is fine.
TEST(RegistryTest, Validate) {
#if DRAGNN_REGISTRY_TEST_WITH_DUPLICATE
EXPECT_THAT(RegistryMetadata::Validate(),
test::IsErrorWithSubstr(kDuplicateError));
#else
TF_EXPECT_OK(RegistryMetadata::Validate());
#endif
}
} // namespace
} // namespace syntaxnet
...@@ -39,7 +39,7 @@ class SharedStore { ...@@ -39,7 +39,7 @@ class SharedStore {
static const T *Get(const string &name, static const T *Get(const string &name,
Args &&...args); // NOLINT(build/c++11) Args &&...args); // NOLINT(build/c++11)
// Like Get(), but creates the object with "closure->Run()". If the closure // Like Get(), but creates the object with "(*closure)()". If the closure
// returns null, we store a null in the SharedStore, but note that Release() // returns null, we store a null in the SharedStore, but note that Release()
// cannot be used to remove it. This is because Release() finds the object // cannot be used to remove it. This is because Release() finds the object
// by associative lookup, and there may be more than one null value, so we // by associative lookup, and there may be more than one null value, so we
......
...@@ -115,9 +115,8 @@ class StructuredGraphBuilder(graph_builder.GreedyParser): ...@@ -115,9 +115,8 @@ class StructuredGraphBuilder(graph_builder.GreedyParser):
return tf.logical_and(args[1] < max_steps, tf.reduce_any(args[3])) return tf.logical_and(args[1] < max_steps, tf.reduce_any(args[3]))
step = tf.constant(0, tf.int32, []) step = tf.constant(0, tf.int32, [])
scores_array = tensor_array_ops.TensorArray(dtype=tf.float32, scores_array = tensor_array_ops.TensorArray(
size=0, dtype=tf.float32, size=0, infer_shape=False, dynamic_size=True)
dynamic_size=True)
alive = tf.constant(True, tf.bool, [batch_size]) alive = tf.constant(True, tf.bool, [batch_size])
alive_steps = tf.constant(0, tf.int32, [batch_size]) alive_steps = tf.constant(0, tf.int32, [batch_size])
t = tf.while_loop( t = tf.while_loop(
......
...@@ -12,99 +12,24 @@ ...@@ -12,99 +12,24 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Build rules for Syntaxnet."""
load("@protobuf_archive//:protobuf.bzl", "cc_proto_library")
load("@protobuf_archive//:protobuf.bzl", "py_proto_library") load(
"@org_tensorflow//tensorflow/core:platform/default/build_config.bzl",
orig_tf_proto_library_cc = "tf_proto_library_cc",
def if_cuda(if_true, if_false = []): )
"""Shorthand for select()'ing on whether we're building with CUDA.""" load(
return select({ "@org_tensorflow//tensorflow/core:platform/default/build_config.bzl",
"@local_config_cuda//cuda:using_nvcc": if_true, orig_tf_proto_library_py = "tf_proto_library_py",
"@local_config_cuda//cuda:using_clang": if_true, )
"//conditions:default": if_false
}) # For some reason, tf_proto_library_cc() isn't obeying the default_visibility
# directive at the top of the build file. So just set it to public (which it is
def tf_copts(): # anyway).
return (["-fno-exceptions", "-DEIGEN_AVOID_STL_ARRAY",] + def tf_proto_library_cc(name, visibility=[], **kwargs):
if_cuda(["-DGOOGLE_CUDA=1"]) + visibility = visibility if visibility else ["//visibility:public"]
select({"@org_tensorflow//tensorflow:darwin": [], return orig_tf_proto_library_cc(name, visibility=visibility, **kwargs)
"//conditions:default": ["-pthread"]}))
def tf_proto_library_py(name, visibility=[], **kwargs):
def tf_proto_library(name, srcs=[], has_services=False, visibility = visibility if visibility else ["//visibility:public"]
deps=[], visibility=None, testonly=0, return orig_tf_proto_library_py(name, visibility=visibility, **kwargs)
cc_api_version=2, go_api_version=2,
java_api_version=2,
py_api_version=2):
native.filegroup(name=name + "_proto_srcs",
srcs=srcs,
testonly=testonly,)
cc_proto_library(name=name,
srcs=srcs,
deps=deps,
cc_libs = ["@protobuf_archive//:protobuf"],
protoc="@protobuf_archive//:protoc",
default_runtime="@protobuf_archive//:protobuf",
testonly=testonly,
visibility=visibility,)
def tf_proto_library_py(name, srcs=[], deps=[], visibility=None, testonly=0):
py_proto_library(name=name,
srcs=srcs,
srcs_version = "PY2AND3",
deps=deps,
default_runtime="@protobuf_archive//:protobuf_python",
protoc="@protobuf_archive//:protoc",
visibility=visibility,
testonly=testonly,)
# Given a list of "op_lib_names" (a list of files in the ops directory
# without their .cc extensions), generate a library for that file.
def tf_gen_op_libs(op_lib_names):
# Make library out of each op so it can also be used to generate wrappers
# for various languages.
for n in op_lib_names:
native.cc_library(name=n + "_op_lib",
copts=tf_copts(),
srcs=["ops/" + n + ".cc"],
deps=(["@org_tensorflow//tensorflow/core:framework"]),
visibility=["//visibility:public"],
alwayslink=1,
linkstatic=1,)
# Invoke this rule in .../tensorflow/python to build the wrapper library.
def tf_gen_op_wrapper_py(name, out=None, hidden=[], visibility=None, deps=[],
require_shape_functions=False):
# Construct a cc_binary containing the specified ops.
tool_name = "gen_" + name + "_py_wrappers_cc"
if not deps:
deps = ["//tensorflow/core:" + name + "_op_lib"]
native.cc_binary(
name = tool_name,
linkopts = ["-lm"],
copts = tf_copts(),
linkstatic = 1, # Faster to link this one-time-use binary dynamically
deps = (["@org_tensorflow//tensorflow/core:framework",
"@org_tensorflow//tensorflow/python:python_op_gen_main"] + deps),
)
# Invoke the previous cc_binary to generate a python file.
if not out:
out = "ops/gen_" + name + ".py"
native.genrule(
name=name + "_pygenrule",
outs=[out],
tools=[tool_name],
cmd=("$(location " + tool_name + ") " + ",".join(hidden)
+ " " + ("1" if require_shape_functions else "0") + " > $@"))
# Make a py_library out of the generated python file.
native.py_library(name=name,
srcs=[out],
srcs_version="PY2AND3",
visibility=visibility,
deps=[
"@org_tensorflow//tensorflow/python:framework_for_generated_wrappers",
],)
...@@ -19,6 +19,7 @@ limitations under the License. ...@@ -19,6 +19,7 @@ limitations under the License.
#include <algorithm> #include <algorithm>
#include <limits> #include <limits>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/buffered_inputstream.h" #include "tensorflow/core/lib/io/buffered_inputstream.h"
#include "tensorflow/core/lib/io/random_inputstream.h" #include "tensorflow/core/lib/io/random_inputstream.h"
...@@ -52,8 +53,9 @@ void TermFrequencyMap::Clear() { ...@@ -52,8 +53,9 @@ void TermFrequencyMap::Clear() {
term_data_.clear(); term_data_.clear();
} }
void TermFrequencyMap::Load(const string &filename, int min_frequency, tensorflow::Status TermFrequencyMap::TryLoad(const string &filename,
int max_num_terms) { int min_frequency,
int max_num_terms) {
Clear(); Clear();
// If max_num_terms is non-positive, replace it with INT_MAX. // If max_num_terms is non-positive, replace it with INT_MAX.
...@@ -61,46 +63,83 @@ void TermFrequencyMap::Load(const string &filename, int min_frequency, ...@@ -61,46 +63,83 @@ void TermFrequencyMap::Load(const string &filename, int min_frequency,
// Read the first line (total # of terms in the mapping). // Read the first line (total # of terms in the mapping).
std::unique_ptr<tensorflow::RandomAccessFile> file; std::unique_ptr<tensorflow::RandomAccessFile> file;
TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(filename, &file)); TF_RETURN_IF_ERROR(
tensorflow::Env::Default()->NewRandomAccessFile(filename, &file));
static const int kInputBufferSize = 1 * 1024 * 1024; /* bytes */ static const int kInputBufferSize = 1 * 1024 * 1024; /* bytes */
tensorflow::io::RandomAccessInputStream stream(file.get()); tensorflow::io::RandomAccessInputStream stream(file.get());
tensorflow::io::BufferedInputStream buffer(&stream, kInputBufferSize); tensorflow::io::BufferedInputStream buffer(&stream, kInputBufferSize);
string line; string line;
TF_CHECK_OK(buffer.ReadLine(&line)); TF_RETURN_IF_ERROR(buffer.ReadLine(&line));
int32 total = -1; int32 total = -1;
CHECK(utils::ParseInt32(line.c_str(), &total)) if (!utils::ParseInt32(line.c_str(), &total)) {
<< "Unable to parse from " << filename; return tensorflow::errors::InvalidArgument(
CHECK_GE(total, 0); filename, ":0: Unable to parse term map size");
}
if (total < 0) {
return tensorflow::errors::InvalidArgument(
filename, ":0: Invalid term map size: ", total);
}
// Read the mapping. // Read the mapping.
int64 last_frequency = -1; int64 last_frequency = -1;
for (int i = 0; i < total && i < max_num_terms; ++i) { for (int i = 0; i < total && i < max_num_terms; ++i) {
TF_CHECK_OK(buffer.ReadLine(&line)); TF_RETURN_IF_ERROR(buffer.ReadLine(&line));
static LazyRE2 re = {"(.*) (\\d*)"};
string term; string term;
int64 frequency = 0; int64 frequency = 0;
CHECK(RE2::FullMatch(line, "(.*) (\\d*)", &term, &frequency)); if (!RE2::FullMatch(line, *re, &term, &frequency)) {
CHECK(!term.empty()); return tensorflow::errors::InvalidArgument(
CHECK_GT(frequency, 0); filename, ":", i + 1,
": Couldn't split term and frequency in line: ", line);
}
if (term.empty()) {
return tensorflow::errors::InvalidArgument(filename, ":", i + 1,
": Invalid empty term");
}
if (frequency <= 0) {
return tensorflow::errors::InvalidArgument(
filename, ":", i + 1, ": Invalid frequency: term=", term,
" frequency=", frequency);
}
// Check frequency sorting (descending order). // Check frequency sorting (descending order).
if (i > 0) CHECK_GE(last_frequency, frequency); if (i > 0 && last_frequency < frequency) {
return tensorflow::errors::InvalidArgument(
filename, ":", i + 1,
": Non-descending frequencies: current=", frequency,
" previous=", last_frequency);
}
last_frequency = frequency; last_frequency = frequency;
// Ignore low-frequency items. // Ignore low-frequency items.
if (frequency < min_frequency) continue; if (frequency < min_frequency) continue;
// Check uniqueness of the mapped terms. // Check uniqueness of the mapped terms.
CHECK(term_index_.find(term) == term_index_.end()) if (term_index_.find(term) != term_index_.end()) {
<< "File " << filename << " has duplicate term: " << term; return tensorflow::errors::InvalidArgument(filename, ":", i + 1,
": Duplicate term: ", term);
}
// Assign the next available index. // Assign the next available index.
const int index = term_index_.size(); const int index = term_index_.size();
term_index_[term] = index; term_index_[term] = index;
term_data_.push_back(std::pair<string, int64>(term, frequency)); term_data_.push_back(std::pair<string, int64>(term, frequency));
} }
CHECK_EQ(term_index_.size(), term_data_.size());
if (term_index_.size() != term_data_.size()) {
return tensorflow::errors::Internal(
"Unexpected size mismatch between term index (", term_index_.size(),
") and term data (", term_data_.size(), ")");
}
LOG(INFO) << "Loaded " << term_index_.size() << " terms from " << filename LOG(INFO) << "Loaded " << term_index_.size() << " terms from " << filename
<< "."; << ".";
return tensorflow::Status::OK();
}
void TermFrequencyMap::Load(const string &filename, int min_frequency,
int max_num_terms) {
TF_CHECK_OK(TryLoad(filename, min_frequency, max_num_terms));
} }
struct TermFrequencyMap::SortByFrequencyThenTerm { struct TermFrequencyMap::SortByFrequencyThenTerm {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment