Unverified Commit 80178fc6 authored by Mark Omernick, committed by GitHub

Merge pull request #4153 from terryykoo/master

Export @195097388.
parents a84e1ef9 edea2b67
......@@ -250,6 +250,7 @@ class GenericFeatureFunction {
string GetParameter(const string &name) const;
int GetIntParameter(const string &name, int default_value) const;
bool GetBoolParameter(const string &name, bool default_value) const;
double GetFloatParameter(const string &name, double default_value) const;
// Returns the FML function description for the feature function, i.e. the
// name and parameters without the nested features.
......
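For orientation, a hedged sketch of how these parameter accessors are typically used inside a feature function's Init(); the ScaledConstant class, the "scale" parameter, and its registration are hypothetical, modeled on the Constant feature removed later in this diff.

// Hypothetical feature function exercising the parameter accessors,
// including the newly added GetFloatParameter().
class ScaledConstant : public ParserFeatureFunction {
 public:
  void Init(TaskContext *context) override {
    value_ = this->GetIntParameter("value", 0);
    scale_ = this->GetFloatParameter("scale", 1.0);
    scaled_value_ = static_cast<int>(value_ * scale_);
    this->set_feature_type(
        new NumericFeatureType(this->name(), scaled_value_ + 1));
  }

  // Returns the scaled constant's value.
  FeatureValue Compute(const WorkspaceSet &workspaces, const ParserState &state,
                       const FeatureVector *result) const override {
    return scaled_value_;
  }

 private:
  int value_ = 0;
  double scale_ = 1.0;
  int scaled_value_ = 0;
};

REGISTER_PARSER_FEATURE_FUNCTION("scaled-constant", ScaledConstant);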
......@@ -108,6 +108,10 @@ class FMLParser {
string item_text_;
};
// Returns the |function| or |extractor| descriptor as an FML string.
string AsFML(const FeatureFunctionDescriptor &function);
string AsFML(const FeatureExtractorDescriptor &extractor);
} // namespace syntaxnet
#endif // SYNTAXNET_FML_PARSER_H_
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "syntaxnet/fml_parser.h"
#include <string>
#include <vector>
#include "syntaxnet/base.h"
#include "syntaxnet/feature_extractor.pb.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace {
// Returns the list of lines in the |text|. Also strips trailing whitespace
// from each line, since the FML generator sometimes appends trailing spaces.
std::vector<string> LinesOf(const string &text) {
std::vector<string> lines = tensorflow::str_util::Split(
text, "\n", tensorflow::str_util::SkipEmpty());
for (string &line : lines) {
tensorflow::str_util::StripTrailingWhitespace(&line);
}
return lines;
}
// Tests that a single function can be round-trip converted from FML to
// descriptor protos and back to FML.
TEST(FMLParserTest, RoundTripSingleFunction) {
FeatureExtractorDescriptor extractor;
FMLParser().Parse("offset(1).input.token.word(min-freq=10)", &extractor);
EXPECT_EQ(LinesOf(AsFML(extractor)),
LinesOf("offset(1).input.token.word(min-freq=\"10\")"));
// Also check each individual feature function.
EXPECT_EQ(AsFML(extractor.feature(0)),
"offset(1).input.token.word(min-freq=\"10\")");
EXPECT_EQ(AsFML(extractor.feature(0).feature(0)),
"input.token.word(min-freq=\"10\")");
EXPECT_EQ(AsFML(extractor.feature(0).feature(0).feature(0)),
"token.word(min-freq=\"10\")");
EXPECT_EQ(AsFML(extractor.feature(0).feature(0).feature(0).feature(0)),
"word(min-freq=\"10\")");
}
// Tests that a set of functions can be round-trip converted from FML to
// descriptor protos and back to FML.
TEST(FMLParserTest, RoundTripMultipleFunctions) {
FeatureExtractorDescriptor extractor;
FMLParser().Parse(R"(offset(1).word(max-num-terms=987)
input { tag(outside=false) label }
pairs { stack.tag input.tag input.child(-1).label })",
&extractor);
// Note that AsFML() adds quotes to all feature option values.
EXPECT_EQ(LinesOf(AsFML(extractor)),
LinesOf("offset(1).word(max-num-terms=\"987\")\n"
"input { tag(outside=\"false\") label }\n"
"pairs { stack.tag input.tag input.child(-1).label }"));
}
} // namespace
} // namespace syntaxnet
......@@ -22,6 +22,7 @@ import syntaxnet.load_parser_ops
from tensorflow.python.ops import control_flow_ops as cf
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver as tf_saver
from syntaxnet.ops import gen_parser_ops
......@@ -572,5 +573,6 @@ class GreedyParser(object):
for key in variables_to_save.keys():
if not key.endswith('avg_var'):
del variables_to_save[key]
self.saver = tf.train.Saver(variables_to_save)
self.saver = tf.train.Saver(
variables_to_save, builder=tf_saver.BaseSaverBuilder())
return self.saver
......@@ -20,33 +20,26 @@
import os.path
import tensorflow as tf
from tensorflow.python.framework import test_util
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from syntaxnet import graph_builder
from syntaxnet import sparse_pb2
from syntaxnet import test_flags
from syntaxnet.ops import gen_parser_ops
FLAGS = tf.app.flags.FLAGS
if not hasattr(FLAGS, 'test_srcdir'):
FLAGS.test_srcdir = ''
if not hasattr(FLAGS, 'test_tmpdir'):
FLAGS.test_tmpdir = tf.test.get_temp_dir()
class GraphBuilderTest(test_util.TensorFlowTestCase):
class GraphBuilderTest(tf.test.TestCase):
def setUp(self):
# Creates a task context with the correct testing paths.
initial_task_context = os.path.join(FLAGS.test_srcdir,
initial_task_context = os.path.join(test_flags.source_root(),
'syntaxnet/'
'testdata/context.pbtxt')
self._task_context = os.path.join(FLAGS.test_tmpdir, 'context.pbtxt')
self._task_context = os.path.join(test_flags.temp_dir(), 'context.pbtxt')
with open(initial_task_context, 'r') as fin:
with open(self._task_context, 'w') as fout:
fout.write(fin.read().replace('SRCDIR', FLAGS.test_srcdir)
.replace('OUTPATH', FLAGS.test_tmpdir))
fout.write(fin.read().replace('SRCDIR', test_flags.source_root())
.replace('OUTPATH', test_flags.temp_dir()))
# Creates necessary term maps.
with self.test_session() as sess:
......@@ -320,4 +313,4 @@ class GraphBuilderTest(test_util.TensorFlowTestCase):
if __name__ == '__main__':
googletest.main()
tf.test.main()
......@@ -23,16 +23,13 @@ import tensorflow as tf
import syntaxnet.load_parser_ops
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
from tensorflow.python.platform import tf_logging as logging
from syntaxnet import sentence_pb2
from syntaxnet import task_spec_pb2
from syntaxnet import test_flags
from syntaxnet.ops import gen_parser_ops
FLAGS = tf.app.flags.FLAGS
CONLL_DOC1 = u'''1 बात _ n NN _ _ _ _ _
2 गलत _ adj JJ _ _ _ _ _
3 हो _ v VM _ _ _ _ _
......@@ -75,15 +72,11 @@ CHAR_NGRAMS = u'''^ अ ^ अभ ^ आ ^ आन ^ इ ^ इस $ ^ क ^
COMMENTS = u'# Line with fake comments.'
class LexiconBuilderTest(test_util.TensorFlowTestCase):
class LexiconBuilderTest(tf.test.TestCase):
def setUp(self):
if not hasattr(FLAGS, 'test_srcdir'):
FLAGS.test_srcdir = ''
if not hasattr(FLAGS, 'test_tmpdir'):
FLAGS.test_tmpdir = tf.test.get_temp_dir()
self.corpus_file = os.path.join(FLAGS.test_tmpdir, 'documents.conll')
self.context_file = os.path.join(FLAGS.test_tmpdir, 'context.pbtxt')
self.corpus_file = os.path.join(test_flags.temp_dir(), 'documents.conll')
self.context_file = os.path.join(test_flags.temp_dir(), 'context.pbtxt')
def AddInput(self, name, file_pattern, record_format, context):
inp = context.input.add()
......@@ -106,7 +99,8 @@ class LexiconBuilderTest(test_util.TensorFlowTestCase):
'category-map', 'label-map', 'prefix-table',
'suffix-table', 'tag-to-category', 'char-map',
'char-ngram-map'):
self.AddInput(name, os.path.join(FLAGS.test_tmpdir, name), '', context)
self.AddInput(name, os.path.join(test_flags.temp_dir(), name), '',
context)
logging.info('Writing context to: %s', self.context_file)
with open(self.context_file, 'w') as f:
f.write(str(context))
......@@ -140,7 +134,7 @@ class LexiconBuilderTest(test_util.TensorFlowTestCase):
self.assertTrue(last)
def ValidateTagToCategoryMap(self):
with open(os.path.join(FLAGS.test_tmpdir, 'tag-to-category'), 'r') as f:
with open(os.path.join(test_flags.temp_dir(), 'tag-to-category'), 'r') as f:
entries = [line.strip().split('\t') for line in f.readlines()]
for tag, category in entries:
self.assertIn(tag, TAGS)
......@@ -148,7 +142,7 @@ class LexiconBuilderTest(test_util.TensorFlowTestCase):
def LoadMap(self, map_name):
loaded_map = {}
with open(os.path.join(FLAGS.test_tmpdir, map_name), 'r') as f:
with open(os.path.join(test_flags.temp_dir(), map_name), 'r') as f:
for line in f:
entries = line.strip().split(' ')
if len(entries) >= 2:
......@@ -237,4 +231,4 @@ class LexiconBuilderTest(test_util.TensorFlowTestCase):
if __name__ == '__main__':
googletest.main()
tf.test.main()
......@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "syntaxnet/ops/shape_helpers.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace syntaxnet {
......@@ -29,6 +31,14 @@ REGISTER_OP("GoldParseReader")
.Attr("corpus_name: string='documents'")
.Attr("arg_prefix: string='brain_parser'")
.SetIsStateful()
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
int feature_size;
TF_RETURN_IF_ERROR(context->GetAttr("feature_size", &feature_size));
for (int i = 0; i < feature_size; ++i) MatrixOutputShape(i, context);
ScalarOutputShape(feature_size, context);
VectorOutputShape(feature_size + 1, context);
return tensorflow::Status::OK();
})
.Doc(R"doc(
Reads sentences, parses them, and returns (gold action, feature) pairs.
......@@ -55,6 +65,15 @@ REGISTER_OP("DecodedParseReader")
.Attr("corpus_name: string='documents'")
.Attr("arg_prefix: string='brain_parser'")
.SetIsStateful()
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
int feature_size;
TF_RETURN_IF_ERROR(context->GetAttr("feature_size", &feature_size));
for (int i = 0; i < feature_size; ++i) MatrixOutputShape(i, context);
ScalarOutputShape(feature_size, context);
context->set_output(feature_size + 1, context->Vector(2));
VectorOutputShape(feature_size + 2, context);
return MatrixInputShape(0, context);
})
.Doc(R"doc(
Reads sentences and parses them, taking parsing transitions based on the
input transition scores.
......@@ -85,6 +104,14 @@ REGISTER_OP("BeamParseReader")
.Attr("continue_until_all_final: bool=false")
.Attr("always_start_new_sentences: bool=false")
.SetIsStateful()
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
int feature_size;
TF_RETURN_IF_ERROR(context->GetAttr("feature_size", &feature_size));
for (int i = 0; i < feature_size; ++i) MatrixOutputShape(i, context);
ScalarOutputShape(feature_size, context);
ScalarOutputShape(feature_size + 1, context);
return tensorflow::Status::OK();
})
.Doc(R"doc(
Reads sentences and creates a beam parser.
......@@ -112,6 +139,15 @@ REGISTER_OP("BeamParser")
.Output("alive: bool")
.Attr("feature_size: int")
.SetIsStateful()
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
int feature_size;
TF_RETURN_IF_ERROR(context->GetAttr("feature_size", &feature_size));
for (int i = 0; i < feature_size; ++i) MatrixOutputShape(i, context);
ScalarOutputShape(feature_size, context);
VectorOutputShape(feature_size + 1, context);
TF_RETURN_IF_ERROR(ScalarInputShape(0, context));
return MatrixInputShape(1, context);
})
.Doc(R"doc(
Updates the beam parser based on scores in the input transition scores.
......@@ -131,6 +167,13 @@ REGISTER_OP("BeamParserOutput")
.Output("gold_slot: int32")
.Output("path_scores: float")
.SetIsStateful()
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
context->set_output(0, context->Matrix(2, context->UnknownDim()));
context->set_output(1, context->Matrix(2, context->UnknownDim()));
VectorOutputShape(2, context);
VectorOutputShape(3, context);
return ScalarInputShape(0, context);
})
.Doc(R"doc(
Converts the current state of the beam parser into a set of indices into
the scoring matrices that lead there.
......@@ -152,6 +195,11 @@ REGISTER_OP("BeamEvalOutput")
.Output("eval_metrics: int32")
.Output("documents: string")
.SetIsStateful()
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
context->set_output(0, context->Vector(2));
VectorOutputShape(1, context);
return ScalarInputShape(0, context);
})
.Doc(R"doc(
Computes eval metrics for the best paths in the input beams.
......@@ -192,6 +240,13 @@ REGISTER_OP("FeatureSize")
.Output("embedding_dims: int32")
.Output("num_actions: int32")
.Attr("arg_prefix: string='brain_parser'")
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
VectorOutputShape(0, context);
VectorOutputShape(1, context);
VectorOutputShape(2, context);
ScalarOutputShape(3, context);
return tensorflow::Status::OK();
})
.Doc(R"doc(
An op that returns the number and domain sizes of parser features.
......@@ -210,6 +265,10 @@ REGISTER_OP("FeatureVocab")
.Attr("arg_prefix: string='brain_parser'")
.Attr("embedding_name: string='words'")
.Output("vocab: string")
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
VectorOutputShape(0, context);
return tensorflow::Status::OK();
})
.Doc(R"doc(
Returns the vocabulary of the parser features for a particular named channel.
For "words" this would would be the entire vocabulary, plus any special tokens
......@@ -227,6 +286,12 @@ REGISTER_OP("UnpackSyntaxNetSparseFeatures")
.Output("indices: int32")
.Output("ids: int64")
.Output("weights: float")
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
VectorOutputShape(0, context);
VectorOutputShape(1, context);
VectorOutputShape(2, context);
return VectorInputShape(0, context);
})
.Doc(R"doc(
Converts a vector of strings with SparseFeatures to tensors.
......@@ -249,11 +314,16 @@ REGISTER_OP("WordEmbeddingInitializer")
.Attr("vectors: string")
.Attr("task_context: string = ''")
.Attr("vocabulary: string = ''")
.Attr("override_num_embeddings: int = -1")
.Attr("cache_vectors_locally: bool = true")
.Attr("num_special_embeddings: int = 3")
.Attr("embedding_init: float = 1.0")
.Attr("seed: int = 0")
.Attr("seed2: int = 0")
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
MatrixOutputShape(0, context);
return tensorflow::Status::OK();
})
.Doc(R"doc(
Reads word embeddings from an sstable of dist_belief.TokenEmbedding protos for
every word specified in a text vocabulary file.
......@@ -264,6 +334,10 @@ task_context: file path at which to read the task context, for its "word-map"
input. Exactly one of `task_context` or `vocabulary` must be specified.
vocabulary: path to vocabulary file, which contains one unique word per line, in
order. Exactly one of `task_context` or `vocabulary` must be specified.
override_num_embeddings: Number of rows in the returned embedding matrix. If
override_num_embeddings is larger than 0, then the returned embedding matrix
has override_num_embeddings rows. Otherwise, the number of rows of the
returned embedding matrix is |vocabulary| + num_special_embeddings.
cache_vectors_locally: Whether to cache the vectors file to a local temp file
before parsing it. This greatly reduces initialization time when the vectors
are stored remotely, but requires that "/tmp" has sufficient space.
......@@ -286,6 +360,11 @@ REGISTER_OP("DocumentSource")
.Attr("corpus_name: string='documents'")
.Attr("batch_size: int")
.SetIsStateful()
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
VectorOutputShape(0, context);
ScalarOutputShape(1, context);
return tensorflow::Status::OK();
})
.Doc(R"doc(
Reads documents from documents_path and outputs them.
......@@ -301,6 +380,9 @@ REGISTER_OP("DocumentSink")
.Attr("task_context: string=''")
.Attr("task_context_str: string=''")
.Attr("corpus_name: string='documents'")
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
return VectorInputShape(0, context);
})
.Doc(R"doc(
Writes documents to documents_path.
......@@ -312,6 +394,10 @@ task_context_str: a task context in text format, used if task_context is empty.
REGISTER_OP("SegmenterTrainingDataConstructor")
.Input("documents: string")
.Output("char_doc: string")
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
VectorOutputShape(0, context);
return VectorInputShape(0, context);
})
.Doc(R"doc(
Constructs segmentation training data from documents with gold segmentation.
......@@ -322,6 +408,10 @@ char_doc: a vector of documents as serialized protos.
REGISTER_OP("CharTokenGenerator")
.Input("documents: string")
.Output("char_doc: string")
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
VectorOutputShape(0, context);
return VectorInputShape(0, context);
})
.Doc(R"doc(
Converts the token field of the input documents so that each token in the
output doc is a UTF-8 character from that doc's text.
......@@ -337,6 +427,10 @@ REGISTER_OP("WellFormedFilter")
.Attr("task_context_str: string=''")
.Attr("corpus_name: string='documents'")
.Attr("keep_malformed_documents: bool = False")
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
VectorOutputShape(0, context);
return VectorInputShape(0, context);
})
.Doc(R"doc(
Removes sentences with malformed parse trees, i.e., those that contain cycles.
......@@ -353,6 +447,10 @@ REGISTER_OP("ProjectivizeFilter")
.Attr("task_context_str: string=''")
.Attr("corpus_name: string='documents'")
.Attr("discard_non_projective: bool = False")
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
VectorOutputShape(0, context);
return VectorInputShape(0, context);
})
.Doc(R"doc(
Modifies input parse trees to make them projective.
......
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Shape inference functions for SyntaxNet ops.
#ifndef SYNTAXNET_OPS_SHAPE_HELPERS_H_
#define SYNTAXNET_OPS_SHAPE_HELPERS_H_
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
// Returns OK if the |input_index|'th input is a tensor of the |rank| with
// unknown dimensions.
inline tensorflow::Status TensorInputShape(
int input_index, int rank,
tensorflow::shape_inference::InferenceContext *context) {
tensorflow::shape_inference::ShapeHandle unused;
return context->WithRank(context->input(input_index), rank, &unused);
}
// Returns OK if the |input_index|'th input is a scalar.
inline tensorflow::Status ScalarInputShape(
int input_index, tensorflow::shape_inference::InferenceContext *context) {
return TensorInputShape(input_index, 0, context);
}
// Returns OK if the |input_index|'th input is a vector of unknown dimension.
inline tensorflow::Status VectorInputShape(
int input_index, tensorflow::shape_inference::InferenceContext *context) {
return TensorInputShape(input_index, 1, context);
}
// Returns OK if the |input_index|'th input is a matrix of unknown dimensions.
inline tensorflow::Status MatrixInputShape(
int input_index, tensorflow::shape_inference::InferenceContext *context) {
return TensorInputShape(input_index, 2, context);
}
// Sets the |output_index|'th output to a scalar.
inline void ScalarOutputShape(
int output_index, tensorflow::shape_inference::InferenceContext *context) {
context->set_output(output_index, context->Scalar());
}
// Sets the |output_index|'th output to a vector of unknown dimension.
inline void VectorOutputShape(
int output_index, tensorflow::shape_inference::InferenceContext *context) {
context->set_output(output_index, context->UnknownShapeOfRank(1));
}
// Sets the |output_index|'th output to a matrix of unknown dimensions.
inline void MatrixOutputShape(
int output_index, tensorflow::shape_inference::InferenceContext *context) {
context->set_output(output_index, context->UnknownShapeOfRank(2));
}
} // namespace syntaxnet
#endif // SYNTAXNET_OPS_SHAPE_HELPERS_H_
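To make the helpers concrete, here is a minimal hedged sketch of an op registration that combines them; the op "EchoTokens" is hypothetical and simply mirrors the SetShapeFn pattern used by the SyntaxNet ops earlier in this diff.

// Hypothetical op used only to illustrate the shape helpers above.
#include "syntaxnet/ops/shape_helpers.h"
#include "tensorflow/core/framework/op.h"

namespace syntaxnet {

REGISTER_OP("EchoTokens")
    .Input("documents: string")
    .Output("tokens: string")
    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
      // Declare output 0 as a vector of unknown length.
      VectorOutputShape(0, context);
      // Validate that input 0 is rank-1; any rank mismatch becomes the
      // returned error status.
      return VectorInputShape(0, context);
    })
    .Doc("Copies its input; exists only to demonstrate the shape helpers.");

}  // namespace syntaxnet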
......@@ -19,6 +19,8 @@
import os
import os.path
import time
from absl import app
from absl import flags
import tempfile
import tensorflow as tf
......@@ -33,7 +35,6 @@ from syntaxnet import structured_graph_builder
from syntaxnet.ops import gen_parser_ops
from syntaxnet import task_spec_pb2
flags = tf.app.flags
FLAGS = flags.FLAGS
......@@ -158,4 +159,4 @@ def main(unused_argv):
if __name__ == '__main__':
tf.app.run()
app.run(main)
......@@ -331,24 +331,6 @@ class LastActionFeatureFunction : public ParserFeatureFunction {
REGISTER_PARSER_FEATURE_FUNCTION("last-action", LastActionFeatureFunction);
class Constant : public ParserFeatureFunction {
public:
void Init(TaskContext *context) override {
value_ = this->GetIntParameter("value", 0);
this->set_feature_type(new NumericFeatureType(this->name(), value_ + 1));
}
// Returns the constant's value.
FeatureValue Compute(const WorkspaceSet &workspaces, const ParserState &state,
const FeatureVector *result) const override {
return value_;
}
private:
int value_ = 0;
};
REGISTER_PARSER_FEATURE_FUNCTION("constant", Constant);
// Register the generic parser features.
typedef GenericFeatures<ParserState> GenericParserFeature;
REGISTER_SYNTAXNET_GENERIC_FEATURES(GenericParserFeature);
......
......@@ -20,6 +20,8 @@
import os
import os.path
import time
from absl import app
from absl import flags
import tensorflow as tf
from tensorflow.python.platform import gfile
......@@ -32,7 +34,6 @@ from syntaxnet import structured_graph_builder
from syntaxnet.ops import gen_parser_ops
from syntaxnet import task_spec_pb2
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('tf_master', '',
......@@ -299,4 +300,4 @@ def main(unused_argv):
if __name__ == '__main__':
tf.app.run()
app.run(main)
......@@ -453,6 +453,8 @@ class WordEmbeddingInitializer : public OpKernel {
&cache_vectors_locally_));
OP_REQUIRES_OK(context, context->GetAttr("num_special_embeddings",
&num_special_embeddings_));
OP_REQUIRES_OK(context, context->GetAttr("override_num_embeddings",
&override_num_embeddings_));
OP_REQUIRES_OK(context,
context->GetAttr("embedding_init", &embedding_init_));
......@@ -569,7 +571,13 @@ class WordEmbeddingInitializer : public OpKernel {
const std::unordered_map<string, int64> &vocabulary,
const TokenEmbedding &embedding, OpKernelContext *context,
Tensor **embedding_matrix) const {
const int rows = vocabulary.size() + num_special_embeddings_;
const int rows = override_num_embeddings_ > 0 ? override_num_embeddings_ :
(vocabulary.size() + num_special_embeddings_);
if (rows < vocabulary.size()) {
return InvalidArgument(
"Embedding matrix row number ", rows,
" is less than vocabulary size ", vocabulary.size());
}
const int columns = embedding.vector().values_size();
TF_RETURN_IF_ERROR(context->allocate_output(0, TensorShape({rows, columns}),
embedding_matrix));
......@@ -637,6 +645,11 @@ class WordEmbeddingInitializer : public OpKernel {
// Number of special embeddings to allocate.
int num_special_embeddings_ = 3;
// If override_num_embeddings_ is larger than zero, then the returned
// embedding matrix has override_num_embeddings_ rows. Otherwise, the
// number of rows equals |vocabulary| + num_special_embeddings_.
int override_num_embeddings_ = -1;
// Seed for random initialization.
uint64 seed_ = 0;
......
......@@ -20,35 +20,27 @@ import os.path
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
from tensorflow.python.platform import tf_logging as logging
from syntaxnet import dictionary_pb2
from syntaxnet import graph_builder
from syntaxnet import sparse_pb2
from syntaxnet import test_flags
from syntaxnet.ops import gen_parser_ops
FLAGS = tf.app.flags.FLAGS
if not hasattr(FLAGS, 'test_srcdir'):
FLAGS.test_srcdir = ''
if not hasattr(FLAGS, 'test_tmpdir'):
FLAGS.test_tmpdir = tf.test.get_temp_dir()
class ParsingReaderOpsTest(test_util.TensorFlowTestCase):
class ParsingReaderOpsTest(tf.test.TestCase):
def setUp(self):
# Creates a task context with the correct testing paths.
initial_task_context = os.path.join(FLAGS.test_srcdir,
initial_task_context = os.path.join(test_flags.source_root(),
'syntaxnet/'
'testdata/context.pbtxt')
self._task_context = os.path.join(FLAGS.test_tmpdir, 'context.pbtxt')
self._task_context = os.path.join(test_flags.temp_dir(), 'context.pbtxt')
with open(initial_task_context, 'r') as fin:
with open(self._task_context, 'w') as fout:
fout.write(fin.read().replace('SRCDIR', FLAGS.test_srcdir)
.replace('OUTPATH', FLAGS.test_tmpdir))
fout.write(fin.read().replace('SRCDIR', test_flags.source_root())
.replace('OUTPATH', test_flags.temp_dir()))
# Creates necessary term maps.
with self.test_session() as sess:
......@@ -175,7 +167,7 @@ class ParsingReaderOpsTest(test_util.TensorFlowTestCase):
def testWordEmbeddingInitializer(self):
# Provide embeddings for the first three words in the word map.
records_path = os.path.join(FLAGS.test_tmpdir, 'records1')
records_path = os.path.join(test_flags.temp_dir(), 'records1')
writer = tf.python_io.TFRecordWriter(records_path)
writer.write(self._token_embedding('.', [1, 2]))
writer.write(self._token_embedding(',', [3, 4]))
......@@ -193,7 +185,7 @@ class ParsingReaderOpsTest(test_util.TensorFlowTestCase):
embeddings[:3,])
def testWordEmbeddingInitializerRepeatability(self):
records_path = os.path.join(FLAGS.test_tmpdir, 'records2')
records_path = os.path.join(test_flags.temp_dir(), 'records2')
writer = tf.python_io.TFRecordWriter(records_path)
writer.write(self._token_embedding('.', [1, 2, 3])) # 3 dims
del writer
......@@ -234,7 +226,7 @@ class ParsingReaderOpsTest(test_util.TensorFlowTestCase):
vocabulary='/dev/null').eval()
def testWordEmbeddingInitializerVocabularyFile(self):
records_path = os.path.join(FLAGS.test_tmpdir, 'records3')
records_path = os.path.join(test_flags.temp_dir(), 'records3')
writer = tf.python_io.TFRecordWriter(records_path)
writer.write(self._token_embedding('a', [1, 2, 3]))
writer.write(self._token_embedding('b', [2, 3, 4]))
......@@ -243,7 +235,7 @@ class ParsingReaderOpsTest(test_util.TensorFlowTestCase):
writer.write(self._token_embedding('e', [5, 6, 7]))
del writer
vocabulary_path = os.path.join(FLAGS.test_tmpdir, 'vocabulary3')
vocabulary_path = os.path.join(test_flags.temp_dir(), 'vocabulary3')
with open(vocabulary_path, 'w') as vocabulary_file:
vocabulary_file.write('a\nc\ne\nx\n') # 'x' not in pretrained embeddings
......@@ -271,8 +263,50 @@ class ParsingReaderOpsTest(test_util.TensorFlowTestCase):
[5.0 / norm_e, 6.0 / norm_e, 7.0 / norm_e]],
embeddings[:3].eval())
def testWordEmbeddingInitializerPresetRowNumber(self):
records_path = os.path.join(test_flags.temp_dir(), 'records3')
writer = tf.python_io.TFRecordWriter(records_path)
writer.write(self._token_embedding('a', [1, 2, 3]))
writer.write(self._token_embedding('b', [2, 3, 4]))
writer.write(self._token_embedding('c', [3, 4, 5]))
writer.write(self._token_embedding('d', [4, 5, 6]))
writer.write(self._token_embedding('e', [5, 6, 7]))
del writer
vocabulary_path = os.path.join(test_flags.temp_dir(), 'vocabulary3')
with open(vocabulary_path, 'w') as vocabulary_file:
vocabulary_file.write('a\nc\ne\nx\n') # 'x' not in pretrained embeddings
# Enumerate a variety of configurations.
for cache_vectors_locally in [False, True]:
for num_special_embeddings in [None, 1, 2, 5]: # None = use default of 3
for override_num_embeddings in [-1, 8, 10]:
with self.test_session():
embeddings = gen_parser_ops.word_embedding_initializer(
vectors=records_path,
vocabulary=vocabulary_path,
override_num_embeddings=override_num_embeddings,
cache_vectors_locally=cache_vectors_locally,
num_special_embeddings=num_special_embeddings)
# Expect 4 embeddings from the vocabulary plus special embeddings.
expected_num_embeddings = 4 + (num_special_embeddings or 3)
if override_num_embeddings > 0:
expected_num_embeddings = override_num_embeddings
self.assertAllEqual([expected_num_embeddings, 3],
tf.shape(embeddings).eval())
# The first 3 embeddings should be pretrained.
norm_a = (1.0 + 4.0 + 9.0)**0.5
norm_c = (9.0 + 16.0 + 25.0)**0.5
norm_e = (25.0 + 36.0 + 49.0)**0.5
self.assertAllClose([[1.0 / norm_a, 2.0 / norm_a, 3.0 / norm_a], [
3.0 / norm_c, 4.0 / norm_c, 5.0 / norm_c
], [5.0 / norm_e, 6.0 / norm_e, 7.0 / norm_e]],
embeddings[:3].eval())
def testWordEmbeddingInitializerVocabularyFileWithDuplicates(self):
records_path = os.path.join(FLAGS.test_tmpdir, 'records4')
records_path = os.path.join(test_flags.temp_dir(), 'records4')
writer = tf.python_io.TFRecordWriter(records_path)
writer.write(self._token_embedding('a', [1, 2, 3]))
writer.write(self._token_embedding('b', [2, 3, 4]))
......@@ -281,7 +315,7 @@ class ParsingReaderOpsTest(test_util.TensorFlowTestCase):
writer.write(self._token_embedding('e', [5, 6, 7]))
del writer
vocabulary_path = os.path.join(FLAGS.test_tmpdir, 'vocabulary4')
vocabulary_path = os.path.join(test_flags.temp_dir(), 'vocabulary4')
with open(vocabulary_path, 'w') as vocabulary_file:
vocabulary_file.write('a\nc\ne\nx\ny\nx') # 'x' duplicated
......@@ -292,4 +326,4 @@ class ParsingReaderOpsTest(test_util.TensorFlowTestCase):
if __name__ == '__main__':
googletest.main()
tf.test.main()
......@@ -15,6 +15,12 @@ limitations under the License.
#include "syntaxnet/registry.h"
#include <set>
#include <string>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace syntaxnet {
// Global list of all component registries.
......@@ -25,4 +31,35 @@ void RegistryMetadata::Register(RegistryMetadata *registry) {
global_registry_list = registry;
}
string ComponentMetadata::DebugString() const {
return tensorflow::strings::StrCat("Registered '", name_, "' as class ",
class_name_, " at ", file_, ":", line_);
}
tensorflow::Status RegistryMetadata::Validate() {
static const tensorflow::Status *const status =
new tensorflow::Status(ValidateImpl());
return *status;
}
tensorflow::Status RegistryMetadata::ValidateImpl() {
// Iterates over the registries for each type.
for (RegistryMetadata *registry = global_registry_list; registry != nullptr;
registry = static_cast<RegistryMetadata *>(registry->link())) {
std::set<string> names;
// Searches for duplicate names within each component registry.
for (ComponentMetadata *component = *(registry->components_);
component != nullptr; component = component->link()) {
if (!names.insert(component->name()).second) {
return tensorflow::errors::InvalidArgument(
"Multiple classes named '", component->name(),
"' have been registered as ", registry->name(), ": ",
component->DebugString());
}
}
}
return tensorflow::Status::OK();
}
} // namespace syntaxnet
......@@ -54,10 +54,13 @@ limitations under the License.
#define SYNTAXNET_REGISTRY_H_
#include <string.h>
#include <memory>
#include <string>
#include <vector>
#include "syntaxnet/utils.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
namespace syntaxnet {
......@@ -75,6 +78,9 @@ class ComponentMetadata {
// Returns component name.
const char *name() const { return name_; }
// Returns a human-readable description of this.
string DebugString() const;
// Metadata objects can be linked in a list.
ComponentMetadata *link() const { return link_; }
void set_link(ComponentMetadata *link) { link_ = link; }
......@@ -107,7 +113,16 @@ class RegistryMetadata : public ComponentMetadata {
// Registers a component registry in the master registry.
static void Register(RegistryMetadata *registry);
// Validates the registry; returns non-OK if there are duplicate component
// names of the same type. Situations where this can happen include accidental
// class name collisions and linking in two different multiarch versions
// of the same component. Repeated calls return the original result.
static tensorflow::Status Validate();
private:
// Implementation for validating the registry.
static tensorflow::Status ValidateImpl();
// Location of list of components in registry.
ComponentMetadata **components_;
};
......@@ -157,14 +172,21 @@ struct ComponentRegistry {
T *object_;
};
// Finds registrar for named component in registry, returning null if not
// found.
const Registrar *GetComponentOrNull(const char *type) const {
  Registrar *r = components;
  while (r != nullptr && strcmp(type, r->type()) != 0) r = r->next();
  return r;
}

// Finds registrar for named component in registry, raising errors on failure.
const Registrar *GetComponent(const char *type) const {
  const Registrar *result = GetComponentOrNull(type);
  if (result == nullptr) {
    LOG(FATAL) << "Unknown " << name << " component: '" << type << "'.";
  }
  return result;
}
// Finds a named component in the registry.
......@@ -196,7 +218,24 @@ class RegisterableClass {
typedef ComponentRegistry<Factory> Registry;
// Creates a new component instance.
static T *Create(const string &type) { return registry()->Lookup(type)(); }
static T *Create(const string &type) {
TF_CHECK_OK(syntaxnet::RegistryMetadata::Validate());
return registry()->Lookup(type)();
}
static tensorflow::Status CreateOrError(const string &type,
std::unique_ptr<T> *result) {
TF_RETURN_IF_ERROR(syntaxnet::RegistryMetadata::Validate());
const typename Registry::Registrar *component =
registry()->GetComponentOrNull(type.c_str());
if (component == nullptr) {
return tensorflow::errors::NotFound("Unknown ", registry()->name, ": ",
type);
} else {
result->reset(component->object()());
return tensorflow::Status::OK();
}
}
// Returns registry for class.
static Registry *registry() { return &registry_; }
......
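A hedged sketch of the new creation entry points; MyThing and "my thing" are hypothetical names, mirroring the ThingDoer pattern in the registry test that follows.

// Sketch: CreateOrError() reports unknown names and duplicate registrations
// as a tensorflow::Status, whereas Create() still dies on failure.
#include <memory>
#include "syntaxnet/registry.h"

class MyThing : public syntaxnet::RegisterableClass<MyThing> {};
DECLARE_SYNTAXNET_CLASS_REGISTRY("my thing", MyThing);
REGISTER_SYNTAXNET_CLASS_REGISTRY("my thing", MyThing);

void CreateExample() {
  std::unique_ptr<MyThing> object;
  const tensorflow::Status status =
      MyThing::CreateOrError("some-name", &object);
  if (!status.ok()) {
    LOG(WARNING) << "Creation failed: " << status;
  }
}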
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "syntaxnet/registry.h"
#include <memory>
#include "dragnn/core/test/generic.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
class ThingDoer : public RegisterableClass<ThingDoer> {};
DECLARE_SYNTAXNET_CLASS_REGISTRY("Thing doer", ThingDoer);
REGISTER_SYNTAXNET_CLASS_REGISTRY("Thing doer", ThingDoer);
class Foo : public ThingDoer {};
class Bar : public ThingDoer {};
class Bar2 : public ThingDoer {};
REGISTER_SYNTAXNET_CLASS_COMPONENT(ThingDoer, "foo", Foo);
REGISTER_SYNTAXNET_CLASS_COMPONENT(ThingDoer, "bar", Bar);
#if DRAGNN_REGISTRY_TEST_WITH_DUPLICATE
REGISTER_SYNTAXNET_CLASS_COMPONENT(ThingDoer, "bar", Bar2); // bad
constexpr char kDuplicateError[] =
"Multiple classes named 'bar' have been registered as Thing doer";
#endif
namespace {
#if !DRAGNN_REGISTRY_TEST_WITH_DUPLICATE
// Tests that CreateOrError() is successful for a properly registered component.
TEST(RegistryTest, CreateOrErrorSuccess) {
std::unique_ptr<ThingDoer> object;
TF_ASSERT_OK(ThingDoer::CreateOrError("foo", &object));
ASSERT_NE(object, nullptr);
}
#else
// Tests that CreateOrError() fails if the registry is misconfigured.
TEST(RegistryTest, CreateOrErrorFailure) {
std::unique_ptr<ThingDoer> object;
EXPECT_THAT(ThingDoer::CreateOrError("bar", &object),
test::IsErrorWithSubstr(kDuplicateError));
ASSERT_EQ(object, nullptr);
// Any subsequent creation call returns the same error.
EXPECT_THAT(ThingDoer::CreateOrError("foo", &object),
test::IsErrorWithSubstr(kDuplicateError));
}
// Tests that Create() dies if the registry is misconfigured.
TEST(RegistryTest, CreateFailure) {
EXPECT_DEATH(ThingDoer::Create("bar"), kDuplicateError);
}
#endif
// Tests that CreateOrError() returns error if the component is unknown.
TEST(RegistryTest, CreateOrErrorUnknown) {
std::unique_ptr<ThingDoer> object;
EXPECT_FALSE(ThingDoer::CreateOrError("unknown", &object).ok());
}
// Tests that Validate() returns OK only when the registry is fine.
TEST(RegistryTest, Validate) {
#if DRAGNN_REGISTRY_TEST_WITH_DUPLICATE
EXPECT_THAT(RegistryMetadata::Validate(),
test::IsErrorWithSubstr(kDuplicateError));
#else
TF_EXPECT_OK(RegistryMetadata::Validate());
#endif
}
} // namespace
} // namespace syntaxnet
......@@ -39,7 +39,7 @@ class SharedStore {
static const T *Get(const string &name,
Args &&...args); // NOLINT(build/c++11)
// Like Get(), but creates the object with "closure->Run()". If the closure
// Like Get(), but creates the object with "(*closure)()". If the closure
// returns null, we store a null in the SharedStore, but note that Release()
// cannot be used to remove it. This is because Release() finds the object
// by associative lookup, and there may be more than one null value, so we
......
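A hedged illustration of the closure-based creation path described in the comment above; the accessor name ClosureGet and its std::function-pointer signature are assumptions inferred from "(*closure)()", not shown in this hunk.

// Sketch only: the store invokes (*closure)() the first time "word-map" is
// requested; later requests under the same name share the cached object.
#include <functional>
#include "syntaxnet/shared_store.h"
#include "syntaxnet/term_frequency_map.h"

const syntaxnet::TermFrequencyMap *GetSharedWordMap() {
  std::function<syntaxnet::TermFrequencyMap *()> closure = []() {
    auto *map = new syntaxnet::TermFrequencyMap();
    map->Load("/path/to/word-map", /*min_frequency=*/0, /*max_num_terms=*/0);
    return map;
  };
  return syntaxnet::SharedStore::ClosureGet<syntaxnet::TermFrequencyMap>(
      "word-map", &closure);
}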
......@@ -115,9 +115,8 @@ class StructuredGraphBuilder(graph_builder.GreedyParser):
return tf.logical_and(args[1] < max_steps, tf.reduce_any(args[3]))
step = tf.constant(0, tf.int32, [])
scores_array = tensor_array_ops.TensorArray(dtype=tf.float32,
size=0,
dynamic_size=True)
scores_array = tensor_array_ops.TensorArray(
dtype=tf.float32, size=0, infer_shape=False, dynamic_size=True)
alive = tf.constant(True, tf.bool, [batch_size])
alive_steps = tf.constant(0, tf.int32, [batch_size])
t = tf.while_loop(
......
......@@ -12,99 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
load("@protobuf_archive//:protobuf.bzl", "cc_proto_library")
load("@protobuf_archive//:protobuf.bzl", "py_proto_library")
def if_cuda(if_true, if_false = []):
"""Shorthand for select()'ing on whether we're building with CUDA."""
return select({
"@local_config_cuda//cuda:using_nvcc": if_true,
"@local_config_cuda//cuda:using_clang": if_true,
"//conditions:default": if_false
})
def tf_copts():
return (["-fno-exceptions", "-DEIGEN_AVOID_STL_ARRAY",] +
if_cuda(["-DGOOGLE_CUDA=1"]) +
select({"@org_tensorflow//tensorflow:darwin": [],
"//conditions:default": ["-pthread"]}))
def tf_proto_library(name, srcs=[], has_services=False,
deps=[], visibility=None, testonly=0,
cc_api_version=2, go_api_version=2,
java_api_version=2,
py_api_version=2):
native.filegroup(name=name + "_proto_srcs",
srcs=srcs,
testonly=testonly,)
cc_proto_library(name=name,
srcs=srcs,
deps=deps,
cc_libs = ["@protobuf_archive//:protobuf"],
protoc="@protobuf_archive//:protoc",
default_runtime="@protobuf_archive//:protobuf",
testonly=testonly,
visibility=visibility,)
def tf_proto_library_py(name, srcs=[], deps=[], visibility=None, testonly=0):
py_proto_library(name=name,
srcs=srcs,
srcs_version = "PY2AND3",
deps=deps,
default_runtime="@protobuf_archive//:protobuf_python",
protoc="@protobuf_archive//:protoc",
visibility=visibility,
testonly=testonly,)
# Given a list of "op_lib_names" (a list of files in the ops directory
# without their .cc extensions), generate a library for that file.
def tf_gen_op_libs(op_lib_names):
# Make library out of each op so it can also be used to generate wrappers
# for various languages.
for n in op_lib_names:
native.cc_library(name=n + "_op_lib",
copts=tf_copts(),
srcs=["ops/" + n + ".cc"],
deps=(["@org_tensorflow//tensorflow/core:framework"]),
visibility=["//visibility:public"],
alwayslink=1,
linkstatic=1,)
# Invoke this rule in .../tensorflow/python to build the wrapper library.
def tf_gen_op_wrapper_py(name, out=None, hidden=[], visibility=None, deps=[],
require_shape_functions=False):
# Construct a cc_binary containing the specified ops.
tool_name = "gen_" + name + "_py_wrappers_cc"
if not deps:
deps = ["//tensorflow/core:" + name + "_op_lib"]
native.cc_binary(
name = tool_name,
linkopts = ["-lm"],
copts = tf_copts(),
linkstatic = 1, # Faster to link this one-time-use binary dynamically
deps = (["@org_tensorflow//tensorflow/core:framework",
"@org_tensorflow//tensorflow/python:python_op_gen_main"] + deps),
)
# Invoke the previous cc_binary to generate a python file.
if not out:
out = "ops/gen_" + name + ".py"
native.genrule(
name=name + "_pygenrule",
outs=[out],
tools=[tool_name],
cmd=("$(location " + tool_name + ") " + ",".join(hidden)
+ " " + ("1" if require_shape_functions else "0") + " > $@"))
# Make a py_library out of the generated python file.
native.py_library(name=name,
srcs=[out],
srcs_version="PY2AND3",
visibility=visibility,
deps=[
"@org_tensorflow//tensorflow/python:framework_for_generated_wrappers",
],)
"""Build rules for Syntaxnet."""
load(
"@org_tensorflow//tensorflow/core:platform/default/build_config.bzl",
orig_tf_proto_library_cc = "tf_proto_library_cc",
)
load(
"@org_tensorflow//tensorflow/core:platform/default/build_config.bzl",
orig_tf_proto_library_py = "tf_proto_library_py",
)
# For some reason, tf_proto_library_cc() isn't obeying the default_visibility
# directive at the top of the build file. So just set it to public (which it is
# anyway).
def tf_proto_library_cc(name, visibility=[], **kwargs):
visibility = visibility if visibility else ["//visibility:public"]
return orig_tf_proto_library_cc(name, visibility=visibility, **kwargs)
def tf_proto_library_py(name, visibility=[], **kwargs):
visibility = visibility if visibility else ["//visibility:public"]
return orig_tf_proto_library_py(name, visibility=visibility, **kwargs)
......@@ -19,6 +19,7 @@ limitations under the License.
#include <algorithm>
#include <limits>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/buffered_inputstream.h"
#include "tensorflow/core/lib/io/random_inputstream.h"
......@@ -52,8 +53,9 @@ void TermFrequencyMap::Clear() {
term_data_.clear();
}
void TermFrequencyMap::Load(const string &filename, int min_frequency,
int max_num_terms) {
tensorflow::Status TermFrequencyMap::TryLoad(const string &filename,
int min_frequency,
int max_num_terms) {
Clear();
// If max_num_terms is non-positive, replace it with INT_MAX.
......@@ -61,46 +63,83 @@ void TermFrequencyMap::Load(const string &filename, int min_frequency,
// Read the first line (total # of terms in the mapping).
std::unique_ptr<tensorflow::RandomAccessFile> file;
TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(filename, &file));
TF_RETURN_IF_ERROR(
tensorflow::Env::Default()->NewRandomAccessFile(filename, &file));
static const int kInputBufferSize = 1 * 1024 * 1024; /* bytes */
tensorflow::io::RandomAccessInputStream stream(file.get());
tensorflow::io::BufferedInputStream buffer(&stream, kInputBufferSize);
string line;
TF_CHECK_OK(buffer.ReadLine(&line));
TF_RETURN_IF_ERROR(buffer.ReadLine(&line));
int32 total = -1;
CHECK(utils::ParseInt32(line.c_str(), &total))
<< "Unable to parse from " << filename;
CHECK_GE(total, 0);
if (!utils::ParseInt32(line.c_str(), &total)) {
return tensorflow::errors::InvalidArgument(
filename, ":0: Unable to parse term map size");
}
if (total < 0) {
return tensorflow::errors::InvalidArgument(
filename, ":0: Invalid term map size: ", total);
}
// Read the mapping.
int64 last_frequency = -1;
for (int i = 0; i < total && i < max_num_terms; ++i) {
TF_CHECK_OK(buffer.ReadLine(&line));
TF_RETURN_IF_ERROR(buffer.ReadLine(&line));
static LazyRE2 re = {"(.*) (\\d*)"};
string term;
int64 frequency = 0;
CHECK(RE2::FullMatch(line, "(.*) (\\d*)", &term, &frequency));
CHECK(!term.empty());
CHECK_GT(frequency, 0);
if (!RE2::FullMatch(line, *re, &term, &frequency)) {
return tensorflow::errors::InvalidArgument(
filename, ":", i + 1,
": Couldn't split term and frequency in line: ", line);
}
if (term.empty()) {
return tensorflow::errors::InvalidArgument(filename, ":", i + 1,
": Invalid empty term");
}
if (frequency <= 0) {
return tensorflow::errors::InvalidArgument(
filename, ":", i + 1, ": Invalid frequency: term=", term,
" frequency=", frequency);
}
// Check frequency sorting (descending order).
if (i > 0) CHECK_GE(last_frequency, frequency);
if (i > 0 && last_frequency < frequency) {
return tensorflow::errors::InvalidArgument(
filename, ":", i + 1,
": Non-descending frequencies: current=", frequency,
" previous=", last_frequency);
}
last_frequency = frequency;
// Ignore low-frequency items.
if (frequency < min_frequency) continue;
// Check uniqueness of the mapped terms.
CHECK(term_index_.find(term) == term_index_.end())
<< "File " << filename << " has duplicate term: " << term;
if (term_index_.find(term) != term_index_.end()) {
return tensorflow::errors::InvalidArgument(filename, ":", i + 1,
": Duplicate term: ", term);
}
// Assign the next available index.
const int index = term_index_.size();
term_index_[term] = index;
term_data_.push_back(std::pair<string, int64>(term, frequency));
}
CHECK_EQ(term_index_.size(), term_data_.size());
if (term_index_.size() != term_data_.size()) {
return tensorflow::errors::Internal(
"Unexpected size mismatch between term index (", term_index_.size(),
") and term data (", term_data_.size(), ")");
}
LOG(INFO) << "Loaded " << term_index_.size() << " terms from " << filename
<< ".";
return tensorflow::Status::OK();
}
void TermFrequencyMap::Load(const string &filename, int min_frequency,
int max_num_terms) {
TF_CHECK_OK(TryLoad(filename, min_frequency, max_num_terms));
}
struct TermFrequencyMap::SortByFrequencyThenTerm {
......
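A hedged usage sketch of the refactored entry points: TryLoad() lets callers handle malformed term-map files gracefully, while Load() keeps the old crash-on-error contract via TF_CHECK_OK. The header path below is an assumption based on the class name.

// Sketch: prefer TryLoad() where a bad input file should not abort the
// process. Errors now carry "filename:line:" context, e.g.
// "word-map:3: Invalid frequency: term=foo frequency=0".
#include "syntaxnet/term_frequency_map.h"
#include "tensorflow/core/lib/core/status.h"

tensorflow::Status LoadWordMap(const string &path,
                               syntaxnet::TermFrequencyMap *map) {
  return map->TryLoad(path, /*min_frequency=*/0, /*max_num_terms=*/0);
}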