Unverified Commit 80178fc6 authored by Mark Omernick's avatar Mark Omernick Committed by GitHub
Browse files

Merge pull request #4153 from terryykoo/master

Export @195097388.
parents a84e1ef9 edea2b67
...@@ -24,6 +24,8 @@ limitations under the License. ...@@ -24,6 +24,8 @@ limitations under the License.
#include <vector> #include <vector>
#include "syntaxnet/utils.h" #include "syntaxnet/utils.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet { namespace syntaxnet {
...@@ -51,6 +53,9 @@ class TermFrequencyMap { ...@@ -51,6 +53,9 @@ class TermFrequencyMap {
// Returns the term associated with the given index. // Returns the term associated with the given index.
const string &GetTerm(int index) const { return term_data_[index].first; } const string &GetTerm(int index) const { return term_data_[index].first; }
// Returns the frequency associated with the given index.
int64 GetFrequency(int index) const { return term_data_[index].second; }
// Increases the frequency of the given term by 1, creating a new entry if // Increases the frequency of the given term by 1, creating a new entry if
// necessary, and returns the index of the term. // necessary, and returns the index of the term.
int Increment(const string &term); int Increment(const string &term);
...@@ -59,14 +64,19 @@ class TermFrequencyMap { ...@@ -59,14 +64,19 @@ class TermFrequencyMap {
void Clear(); void Clear();
// Loads a frequency mapping from the given file, which must have been created // Loads a frequency mapping from the given file, which must have been created
// by an earlier call to Save(). After loading, the term indices are // by an earlier call to Save(). On error, returns non-OK.
// guaranteed to be ordered in descending order of frequency (breaking ties //
// arbitrarily). However, any new terms inserted after loading do not // After loading, the term indices are guaranteed to be ordered in descending
// maintain this sorting invariant. // order of frequency (breaking ties arbitrarily). However, any new terms
// inserted after loading do not maintain this sorting invariant.
// //
// Only loads terms with frequency >= min_frequency. If max_num_terms <= 0, // Only loads terms with frequency >= min_frequency. If max_num_terms <= 0,
// then all qualifying terms are loaded; otherwise, max_num_terms terms with // then all qualifying terms are loaded; otherwise, max_num_terms terms with
// maximal frequency are loaded (breaking ties arbitrarily). // maximal frequency are loaded (breaking ties arbitrarily).
tensorflow::Status TryLoad(const string &filename, int min_frequency,
int max_num_terms);
// Like TryLoad(), but fails on error.
void Load(const string &filename, int min_frequency, int max_num_terms); void Load(const string &filename, int min_frequency, int max_num_terms);
// Saves a frequency mapping to the given file. // Saves a frequency mapping to the given file.
...@@ -74,7 +84,8 @@ class TermFrequencyMap { ...@@ -74,7 +84,8 @@ class TermFrequencyMap {
private: private:
// Hashtable for term-to-index mapping. // Hashtable for term-to-index mapping.
typedef std::unordered_map<string, int> TermIndex; using TermIndex = std::unordered_map<string, int>;
// Sorting functor for term data. // Sorting functor for term data.
struct SortByFrequencyThenTerm; struct SortByFrequencyThenTerm;
......
/* Copyright 2017 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "syntaxnet/term_frequency_map.h"
#include "syntaxnet/base.h"
#include <gmock/gmock.h>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace {
// Matches an error status whose message matches |substr|.
MATCHER(IsError, string(negation ? "isn't" : "is") + " an error Status") {
return !arg.ok();
}
// Matches an error status whose message matches |substr|.
MATCHER_P(IsErrorWithSubstr, substr,
string(negation ? "isn't" : "is") +
" an error Status whose message matches the substring '" +
::testing::PrintToString(substr) + "'") {
return !arg.ok() && arg.error_message().find(substr) != string::npos;
}
// Writes the |content| to a temporary file and returns its path.
string AsTempFile(const string &content) {
static int counter = 0;
const string basename = tensorflow::strings::StrCat("temp_", counter++);
const string path =
tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), basename);
TF_CHECK_OK(
tensorflow::WriteStringToFile(tensorflow::Env::Default(), path, content));
return path;
}
// Tests that TermFrequencyMap::TryLoad() fails on an invalid path.
TEST(TermFrequencyMapTest, TryLoadInvalidPath) {
const string kInvalidPath = "/some/invalid/path";
TermFrequencyMap term_map;
EXPECT_THAT(term_map.TryLoad(kInvalidPath, 0, 0), IsError());
}
// Tests that TermFrequencyMap::TryLoad() fails on an empty file.
TEST(TermFrequencyMapTest, TryLoadEmptyFile) {
const string path = AsTempFile("");
TermFrequencyMap term_map;
EXPECT_THAT(term_map.TryLoad(path, 0, 0), IsError());
}
// Tests that TermFrequencyMap::TryLoad() fails if the term count in the first
// line is not parsable as an integer.
TEST(TermFrequencyMapTest, TryLoadFileWithMalformedCount) {
const string path = AsTempFile("asdf");
TermFrequencyMap term_map;
EXPECT_THAT(term_map.TryLoad(path, 0, 0),
IsErrorWithSubstr(tensorflow::strings::StrCat(
path, ":0: Unable to parse term map size")));
}
// Tests that TermFrequencyMap::TryLoad() fails if the term count in the first
// line is negative.
TEST(TermFrequencyMapTest, TryLoadFileWithNegativeCount) {
const string path = AsTempFile("-1");
TermFrequencyMap term_map;
EXPECT_THAT(term_map.TryLoad(path, 0, 0),
IsErrorWithSubstr(tensorflow::strings::StrCat(
path, ":0: Invalid term map size: -1")));
}
// Tests that TermFrequencyMap::TryLoad() is OK if there are no terms.
TEST(TermFrequencyMapTest, TryLoadFileWithNoTerms) {
const string path = AsTempFile("0");
TermFrequencyMap term_map;
TF_ASSERT_OK(term_map.TryLoad(path, 0, 0));
EXPECT_EQ(term_map.Size(), 0);
}
// Tests that TermFrequencyMap::TryLoad() fails if there is a malformed line.
TEST(TermFrequencyMapTest, TryLoadFileWithMalformedLine) {
const string path = AsTempFile(
"2\n"
"valid term with spaces 1\n"
"bad term\n");
TermFrequencyMap term_map;
EXPECT_THAT(
term_map.TryLoad(path, 0, 0),
IsErrorWithSubstr(tensorflow::strings::StrCat(
path, ":2: Couldn't split term and frequency in line: bad term")));
}
// Tests that TermFrequencyMap::TryLoad() fails if there is an empty term.
TEST(TermFrequencyMapTest, TryLoadFileWithEmptyTerm) {
const string path = AsTempFile(
"2\n"
" 1\n"
"some_term 1\n");
TermFrequencyMap term_map;
EXPECT_THAT(term_map.TryLoad(path, 0, 0),
IsErrorWithSubstr(
tensorflow::strings::StrCat(path, ":1: Invalid empty term")));
}
// Tests that TermFrequencyMap::TryLoad() fails if there is a term with zero
// frequency.
TEST(TermFrequencyMapTest, TryLoadFileWithZeroFrequency) {
const string path = AsTempFile(
"2\n"
"good_term 1\n"
"bad_term 0\n");
TermFrequencyMap term_map;
EXPECT_THAT(term_map.TryLoad(path, 0, 0),
IsErrorWithSubstr(tensorflow::strings::StrCat(
path, ":2: Invalid frequency: term=bad_term frequency=0")));
}
// Tests that TermFrequencyMap::TryLoad() fails if terms are not in descending
// order of frequency.
TEST(TermFrequencyMapTest, TryLoadFileWithOutOfOrderTerms) {
const string path = AsTempFile(
"2\n"
"good_term 1\n"
"bad_term 2\n");
TermFrequencyMap term_map;
EXPECT_THAT(
term_map.TryLoad(path, 0, 0),
IsErrorWithSubstr(tensorflow::strings::StrCat(
path, ":2: Non-descending frequencies: current=2 previous=1")));
}
// Tests that TermFrequencyMap::TryLoad() fails if there are duplicate terms.
TEST(TermFrequencyMapTest, TryLoadFileWithDuplicateTerms) {
const string path = AsTempFile(
"2\n"
"duplicate 1\n"
"duplicate 1\n");
TermFrequencyMap term_map;
EXPECT_THAT(term_map.TryLoad(path, 0, 0),
IsErrorWithSubstr(tensorflow::strings::StrCat(
path, ":2: Duplicate term: duplicate")));
}
// Tests that TermFrequencyMap contains the specified terms and frequencies.
TEST(TermFrequencyMapTest, LoadAndCheckContents) {
const string path = AsTempFile(
"3\n"
"foo 100\n"
"bar 10\n"
"baz 1\n");
TermFrequencyMap term_map;
TF_ASSERT_OK(term_map.TryLoad(path, 0, 0));
EXPECT_EQ(term_map.Size(), 3);
EXPECT_EQ(term_map.GetTerm(0), "foo");
EXPECT_EQ(term_map.GetTerm(1), "bar");
EXPECT_EQ(term_map.GetTerm(2), "baz");
EXPECT_EQ(term_map.GetFrequency(0), 100);
EXPECT_EQ(term_map.GetFrequency(1), 10);
EXPECT_EQ(term_map.GetFrequency(2), 1);
}
} // namespace
} // namespace syntaxnet
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Accessors for test flags, with fallback for missing flags."""
from absl import flags
import tensorflow as tf
FLAGS = flags.FLAGS
def temp_dir():
"""Returns a temporary directory for tests."""
return getattr(FLAGS, 'test_tmpdir', tf.test.get_temp_dir())
def source_root():
"""Returns the path to the root of the source directory tree for tests."""
return getattr(FLAGS, 'test_srcdir', '')
...@@ -22,26 +22,19 @@ import tensorflow as tf ...@@ -22,26 +22,19 @@ import tensorflow as tf
import syntaxnet.load_parser_ops import syntaxnet.load_parser_ops
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from syntaxnet import sentence_pb2 from syntaxnet import sentence_pb2
from syntaxnet import task_spec_pb2 from syntaxnet import task_spec_pb2
from syntaxnet import test_flags
from syntaxnet.ops import gen_parser_ops from syntaxnet.ops import gen_parser_ops
FLAGS = tf.app.flags.FLAGS
class TextFormatsTest(tf.test.TestCase):
class TextFormatsTest(test_util.TensorFlowTestCase):
def setUp(self): def setUp(self):
if not hasattr(FLAGS, 'test_srcdir'): self.corpus_file = os.path.join(test_flags.temp_dir(), 'documents.conll')
FLAGS.test_srcdir = '' self.context_file = os.path.join(test_flags.temp_dir(), 'context.pbtxt')
if not hasattr(FLAGS, 'test_tmpdir'):
FLAGS.test_tmpdir = tf.test.get_temp_dir()
self.corpus_file = os.path.join(FLAGS.test_tmpdir, 'documents.conll')
self.context_file = os.path.join(FLAGS.test_tmpdir, 'context.pbtxt')
def AddInput(self, name, file_pattern, record_format, context): def AddInput(self, name, file_pattern, record_format, context):
inp = context.input.add() inp = context.input.add()
...@@ -60,7 +53,8 @@ class TextFormatsTest(test_util.TensorFlowTestCase): ...@@ -60,7 +53,8 @@ class TextFormatsTest(test_util.TensorFlowTestCase):
for name in ('word-map', 'lcword-map', 'tag-map', 'category-map', for name in ('word-map', 'lcword-map', 'tag-map', 'category-map',
'label-map', 'prefix-table', 'suffix-table', 'label-map', 'prefix-table', 'suffix-table',
'tag-to-category'): 'tag-to-category'):
self.AddInput(name, os.path.join(FLAGS.test_tmpdir, name), '', context) self.AddInput(name, os.path.join(test_flags.temp_dir(), name), '',
context)
logging.info('Writing context to: %s', self.context_file) logging.info('Writing context to: %s', self.context_file)
with open(self.context_file, 'w') as f: with open(self.context_file, 'w') as f:
f.write(str(context)) f.write(str(context))
...@@ -254,4 +248,4 @@ token { ...@@ -254,4 +248,4 @@ token {
if __name__ == '__main__': if __name__ == '__main__':
googletest.main() tf.test.main()
Subproject commit c52cdc03a67ceae9ecc8c00025d3c60f54833e2d Subproject commit 8753e2ebde6c58b56675cc19ab7ff83072824a62
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment