"test/vscode:/vscode.git/clone" did not exist on "f774299be269a19217d8ad9d80f4ca2eb61de1d4"
Commit 4364390a authored by Ivan Bogatyy's avatar Ivan Bogatyy Committed by calberti
Browse files

Release DRAGNN bulk networks (#2785)

* Release DRAGNN bulk networks
parent 638fd759
......@@ -258,7 +258,7 @@ REGISTER_OP("WordEmbeddingInitializer")
Reads word embeddings from an sstable of dist_belief.TokenEmbedding protos for
every word specified in a text vocabulary file.
word_embeddings: a tensor containing word embeddings from the specified table.
word_embeddings: a tensor containing word embeddings from the specified sstable.
vectors: path to TF record file of word embedding vectors.
task_context: file path at which to read the task context, for its "word-map"
input. Exactly one of `task_context` or `vocabulary` must be specified.
......
......@@ -17,8 +17,10 @@ limitations under the License.
#include <string>
#include "syntaxnet/generic_features.h"
#include "syntaxnet/registry.h"
#include "syntaxnet/sentence_features.h"
#include "syntaxnet/whole_sentence_features.h"
#include "syntaxnet/workspace.h"
namespace syntaxnet {
......@@ -347,4 +349,8 @@ class Constant : public ParserFeatureFunction {
REGISTER_PARSER_FEATURE_FUNCTION("constant", Constant);
// Register the generic parser features.
typedef GenericFeatures<ParserState> GenericParserFeature;
REGISTER_SYNTAXNET_GENERIC_FEATURES(GenericParserFeature);
} // namespace syntaxnet
......@@ -146,6 +146,14 @@ class BasicParserSentenceFeatureFunction :
}
};
// Registry for the parser feature functions.
DECLARE_SYNTAXNET_CLASS_REGISTRY("parser feature function",
ParserFeatureFunction);
// Registry for the parser state + token index feature functions.
DECLARE_SYNTAXNET_CLASS_REGISTRY("parser+index feature function",
ParserIndexFeatureFunction);
} // namespace syntaxnet
#endif // SYNTAXNET_PARSER_FEATURES_H_
......@@ -84,7 +84,6 @@ class ParserFeatureFunctionTest : public ::testing::Test {
// Prepares a feature for computations.
string ExtractFeature(const string &feature_name) {
context_.mutable_spec()->mutable_input()->Clear();
context_.mutable_spec()->mutable_output()->Clear();
feature_extractor_.reset(new ParserFeatureExtractor());
feature_extractor_->Parse(feature_name);
feature_extractor_->Setup(&context_);
......@@ -152,4 +151,10 @@ TEST_F(ParserFeatureFunctionTest, GoldHeadFeatureFunction) {
EXPECT_EQ("1", ExtractFeature("input(7).gold-head"));
}
// Verifies that the "pair" and "triple" compound feature functions combine
// the values of their nested features into a single tuple-formatted string:
// gold-head index 1 plus tag "PRP" (and label "ROOT" for the triple).
TEST_F(ParserFeatureFunctionTest, PairFeatureFunction) {
EXPECT_EQ("(1,PRP)", ExtractFeature("pair { input.gold-head input.tag }"));
EXPECT_EQ("(1,PRP,ROOT)",
ExtractFeature("triple { input.gold-head input.tag input.label }"));
}
} // namespace syntaxnet
......@@ -22,6 +22,8 @@ namespace syntaxnet {
// Transition system registry.
REGISTER_SYNTAXNET_CLASS_REGISTRY("transition system", ParserTransitionSystem);
constexpr int ParserTransitionSystem::kDynamicNumActions;
void ParserTransitionSystem::PerformAction(ParserAction action,
ParserState *state) const {
if (state->keep_history()) {
......
......@@ -74,6 +74,9 @@ class ParserTransitionState {
class ParserTransitionSystem
: public RegisterableClass<ParserTransitionSystem> {
public:
// Sentinel value that represents a dynamic action set.
static constexpr int kDynamicNumActions = -1;
// Construction and cleanup.
ParserTransitionSystem() {}
virtual ~ParserTransitionSystem() {}
......@@ -94,7 +97,8 @@ class ParserTransitionSystem
// Returns the number of action types.
virtual int NumActionTypes() const = 0;
// Returns the number of actions.
// Returns the number of actions, or |kDynamicNumActions| if the action set is
// dynamic (i.e., varies per instance).
virtual int NumActions(int num_labels) const = 0;
// Internally creates the set of outcomes (when transition systems support a
......@@ -196,6 +200,9 @@ class ParserTransitionSystem
#define REGISTER_TRANSITION_SYSTEM(type, component) \
REGISTER_SYNTAXNET_CLASS_COMPONENT(ParserTransitionSystem, type, component)
// Transition system registry.
DECLARE_SYNTAXNET_CLASS_REGISTRY("transition system", ParserTransitionSystem);
} // namespace syntaxnet
#endif // SYNTAXNET_PARSER_TRANSITIONS_H_
......@@ -66,6 +66,8 @@ class ProtoRecordReader {
CHECK(proto->ParseFromString(buffer));
return tensorflow::Status::OK();
} else {
CHECK_EQ(status.code(), tensorflow::error::OUT_OF_RANGE)
<< "Non-OK and non-out-of-range (EOF) status: " << status;
return status;
}
}
......
......@@ -15,7 +15,6 @@
"""Tests for reader_ops."""
# pylint: disable=no-name-in-module,unused-import,g-bad-import-order,maybe-no-member,no-member,g-importing-member
import os.path
import numpy as np
......@@ -30,6 +29,7 @@ from syntaxnet import graph_builder
from syntaxnet import sparse_pb2
from syntaxnet.ops import gen_parser_ops
FLAGS = tf.app.flags.FLAGS
if not hasattr(FLAGS, 'test_srcdir'):
FLAGS.test_srcdir = ''
......
......@@ -229,6 +229,10 @@ class RegisterableInstance {
classname::Registry RegisterableClass<classname>::registry_ = { \
type, #classname, __FILE__, __LINE__, NULL}
#define DECLARE_SYNTAXNET_CLASS_REGISTRY(type, classname) \
template <> \
classname::Registry RegisterableClass<classname>::registry_;
#define REGISTER_SYNTAXNET_INSTANCE_COMPONENT(base, type, component) \
static base::Registry::Registrar __##component##__##registrar( \
base::registry(), type, #component, __FILE__, __LINE__, new component)
......@@ -238,6 +242,10 @@ class RegisterableInstance {
classname::Registry RegisterableInstance<classname>::registry_ = { \
type, #classname, __FILE__, __LINE__, NULL}
#define DECLARE_SYNTAXNET_INSTANCE_REGISTRY(type, classname) \
template <> \
classname::Registry RegisterableInstance<classname>::registry_;
} // namespace syntaxnet
#endif // SYNTAXNET_REGISTRY_H_
......@@ -663,6 +663,10 @@ typedef FeatureExtractor<Sentence, int> SentenceExtractor;
#define REGISTER_SENTENCE_IDX_FEATURE(name, type) \
REGISTER_SYNTAXNET_FEATURE_FUNCTION(SentenceFeature, name, type)
// Registry for the Sentence + token index feature functions.
DECLARE_SYNTAXNET_CLASS_REGISTRY("sentence+index feature function",
SentenceFeature);
} // namespace syntaxnet
#endif // SYNTAXNET_SENTENCE_FEATURES_H_
......@@ -51,7 +51,6 @@ class SentenceFeaturesTest : public ::testing::Test {
// anything in info_ field into the LexiFuse repository.
virtual void PrepareFeature(const string &fml) {
context_.mutable_spec()->mutable_input()->Clear();
context_.mutable_spec()->mutable_output()->Clear();
extractor_.reset(new SentenceExtractor());
extractor_->Parse(fml);
extractor_->Setup(&context_);
......@@ -78,6 +77,7 @@ class SentenceFeaturesTest : public ::testing::Test {
FeatureVector result;
extractor_->ExtractFeatures(workspaces_, sentence_, index,
&result);
values.reserve(result.size());
for (int i = 0; i < result.size(); ++i) {
values.push_back(result.type(i)->GetFeatureValueName(result.value(i)));
}
......@@ -99,6 +99,7 @@ class SentenceFeaturesTest : public ::testing::Test {
void CheckVectorWorkspace(const VectorIntWorkspace &workspace,
std::vector<int> target) {
std::vector<int> src;
src.reserve(workspace.size());
for (int i = 0; i < workspace.size(); ++i) {
src.push_back(workspace.element(i));
}
......
......@@ -16,6 +16,6 @@
"""Imports the SyntaxNet ops and their C++ implementations."""
from syntaxnet.ops.gen_parser_ops import * # pylint: disable=wildcard-import
from syntaxnet.ops.gen_parser_ops import *
import syntaxnet.load_parser_ops
......@@ -35,39 +35,8 @@ message TaskInput {
}
}
// Task output descriptor. Describes one output resource produced by a task,
// including its file/record format and optional sharding.
message TaskOutput {
// Name of output resource.
required string name = 1;
// File format for output resource.
optional string file_format = 2;
// Record format for output resource.
optional string record_format = 3;
// Number of shards in output. If it is different from zero this output is
// sharded. If the number of shards is set to -1 this means that the output is
// sharded, but the number of shards is unknown. The files are then named
// 'base-*-of-*'.
optional int32 shards = 4 [default = 0];
// Base file name for output resource. If this is not set by the task
// component it is set to a default value by the workflow engine.
optional string file_base = 5;
// Optional extension added to the file name.
optional string file_extension = 6;
}
// A task specification is used for describing executing parameters.
message TaskSpec {
// Name of task.
optional string task_name = 1;
// Workflow task type.
optional string task_type = 2;
// Task parameters.
repeated group Parameter = 3 {
required string name = 4;
......@@ -77,6 +46,6 @@ message TaskSpec {
// Task inputs.
repeated TaskInput input = 6;
// Task outputs.
repeated TaskOutput output = 7;
reserved 1, 2, 7;
reserved "task_name", "task_type", "output";
}
......@@ -68,7 +68,8 @@ void TermFrequencyMap::Load(const string &filename, int min_frequency,
string line;
TF_CHECK_OK(buffer.ReadLine(&line));
int32 total = -1;
CHECK(utils::ParseInt32(line.c_str(), &total));
CHECK(utils::ParseInt32(line.c_str(), &total))
<< "Unable to parse from " << filename;
CHECK_GE(total, 0);
// Read the mapping.
......
......@@ -157,7 +157,12 @@ class CoNLLSyntaxFormat : public DocumentFormat {
const int start = text.size();
const int end = start + word.size() - 1;
text.append(word);
add_space_to_text = fields[9] != "SpaceAfter=No";
// Determine whether a space should be added to sentence text.
std::vector<string> sub_fields = utils::Split(fields[9], '|');
auto no_space = [](const string &str) { return str == "SpaceAfter=No"; };
add_space_to_text =
!std::any_of(sub_fields.begin(), sub_fields.end(), no_space);
// Add token to sentence.
Token *token = sentence->add_token();
......@@ -329,28 +334,28 @@ REGISTER_SYNTAXNET_DOCUMENT_FORMAT("conll-sentence", CoNLLSyntaxFormat);
//
// Examples:
// To create a training example for sentence with raw text:
// That's a good point.
// "That's a good point."
// and the corresponding gold segmentation:
// That 's a good point .
// "That" "\'s" "a" "good" "point" "."
// Then the correct input is:
// That NO_SPACE
// 's SPACE
// a SPACE
// good SPACE
// point NO_SPACE
// . NO_SPACE
// "That\tNO_SPACE"
// "'s\tSPACE"
// "a\tSPACE"
// "good\tSPACE"
// "point\tNO_SPACE"
// ".\tNO_SPACE"
//
// Yet another example:
// To create a training example for sentence with raw text:
// 这是一个测试
// "这是一个测试"
// and the corresponding gold segmentation:
// 这 是 一 个 测试
// "这" "是" "一" "个" "测试"
// Then the correct input is:
// NO_SPACE
// NO_SPACE
// NO_SPACE
// NO_SPACE
// 测试 NO_SPACE
// "这\tNO_SPACE"
// "是\tNO_SPACE"
// "一\tNO_SPACE"
// "个\tNO_SPACE"
// "测试\tNO_SPACE"
class SegmentationTrainingDataFormat : public CoNLLSyntaxFormat {
public:
// Converts to segmentation training data by breaking those word in the input
......
......@@ -113,13 +113,13 @@ class TextFormatsTest(test_util.TensorFlowTestCase):
# This test sentence includes a multiword token and an empty node,
# both of which are to be ignored.
test_sentence = """
1-2 We've _
1 We we PRON PRP Case=Nom 3 nsubj _ SpaceAfter=No
2 've have AUX VBP Mood=Ind 3 aux _ _
3 moved move VERB VBN Tense=Past 0 root _ _
4 on on ADV RB _ 3 advmod _ SpaceAfter=No
4.1 ignored ignore VERB VBN Tense=Past 0 _ _ _
5 . . PUNCT . _ 3 punct _ _
1-2\tWe've\t_
1\tWe\twe\tPRON\tPRP\tCase=Nom\t3\tnsubj\t_\tSpaceAfter=No
2\t've\thave\tAUX\tVBP\tMood=Ind\t3\taux\t_\t_
3\tmoved\tmove\tVERB\tVBN\tTense=Past\t0\troot\t_\t_
4\ton\ton\tADV\tRB\t_\t3\tadvmod\t_\tSpaceAfter=No|foobar=baz
4.1\tignored\tignore\tVERB\tVBN\tTense=Past\t0\t_\t_\t_
5\t.\t.\tPUNCT\t.\t_\t3\tpunct\t_\t_
"""
# Prepare test sentence.
......@@ -191,13 +191,13 @@ token {
self.assertEqual(expected_ends, [t.end for t in sentence_doc.token])
def testSegmentationTrainingData(self):
doc1_lines = ['测试 NO_SPACE\n', '的 NO_SPACE\n', '句子 NO_SPACE']
doc1_lines = ['测试\tNO_SPACE\n', '的\tNO_SPACE\n', '句子\tNO_SPACE']
doc1_text = '测试的句子'
doc1_tokens = ['测', '试', '的', '句', '子']
doc1_break_levles = [1, 0, 1, 1, 0]
doc2_lines = [
'That NO_SPACE\n', '\'s SPACE\n', 'a SPACE\n', 'good SPACE\n',
'point NO_SPACE\n', '. NO_SPACE'
'That\tNO_SPACE\n', '\'s\tSPACE\n', 'a\tSPACE\n', 'good\tSPACE\n',
'point\tNO_SPACE\n', '.\tNO_SPACE'
]
doc2_text = 'That\'s a good point.'
doc2_tokens = [
......
......@@ -16,6 +16,14 @@ py_library(
srcs = ["check.py"],
)
# Library for locating bundled data resources; publicly visible so other
# packages can depend on it. No deps of its own.
py_library(
name = "resources",
srcs = ["resources.py"],
visibility = ["//visibility:public"],
deps = [
],
)
py_library(
name = "pyregistry_test_base",
testonly = 1,
......@@ -56,3 +64,15 @@ py_test(
"@org_tensorflow//tensorflow/core:protos_all_py",
],
)
# Unit test for :resources; uses a small checked-in text file as test data.
py_test(
name = "resources_test",
srcs = ["resources_test.py"],
data = [
"//syntaxnet:testdata/hello.txt",
],
deps = [
":resources",
"@org_tensorflow//tensorflow:tensorflow_py",
],
)
"""A component registry, similar to nlp_saft::RegisteredClass<>.
Like nlp_saft::RegisteredClass<>, one does not need to explicitly import the
module containing each subclass. It is sufficient to add subclasses as build
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A component registry, similar to RegisterableClass<>.
Like RegisterableClass<>, one does not need to explicitly import the module
containing each subclass. It is sufficient to add subclasses as build
dependencies.
Unlike nlp_saft::RegisteredClass<>, which allows subclasses to be registered
under arbitrary names, subclasses must be looked up based on their type name.
This restriction allows the registry to dynamically import the module containing
the desired subclass.
Unlike RegisterableClass<>, which allows subclasses to be registered under
arbitrary names, subclasses must be looked up based on their type name. This
restriction allows the registry to dynamically import the module containing the
desired subclass.
Example usage:
......@@ -82,7 +97,7 @@ def _GetClass(name):
# Need at least "module.Class".
if len(elements) < 2:
logging.debug('Malformed type: "%s"', name)
logging.info('Malformed type: "%s"', name)
return None
module_path = '.'.join(elements[:-1])
class_name = elements[-1]
......@@ -91,20 +106,19 @@ def _GetClass(name):
try:
__import__(module_path)
except ImportError as e:
logging.debug('Unable to find module "%s": "%s"', module_path, e)
logging.info('Unable to find module "%s": "%s"', module_path, e)
return None
module = sys.modules[module_path]
# Look up the class.
if not hasattr(module, class_name):
logging.debug('Name "%s" not found in module: "%s"', class_name,
module_path)
logging.info('Name "%s" not found in module: "%s"', class_name, module_path)
return None
class_obj = getattr(module, class_name)
# Check that it is actually a class.
if not inspect.isclass(class_obj):
logging.debug('Name does not refer to a class: "%s"', name)
logging.info('Name does not refer to a class: "%s"', name)
return None
return class_obj
......@@ -125,8 +139,8 @@ def _Create(baseclass, subclass_name, *args, **kwargs):
if subclass is None:
return None # _GetClass() already logged an error
if not issubclass(subclass, baseclass):
logging.debug('Class "%s" is not a subclass of "%s"', subclass_name,
baseclass.__name__)
logging.info('Class "%s" is not a subclass of "%s"', subclass_name,
baseclass.__name__)
return None
return subclass(*args, **kwargs)
......@@ -135,13 +149,13 @@ def _ResolveAndCreate(baseclass, path, subclass_name, *args, **kwargs):
"""Resolves the name of a subclass and creates an instance of it.
The subclass is resolved with respect to a package path in an inside-out
manner. For example, if |path| is 'google3.foo.bar' and |subclass_name| is
manner. For example, if |path| is 'syntaxnet.foo.bar' and |subclass_name| is
'baz.ClassName', then attempts are made to create instances of the following
fully-qualified class names:
'google3.foo.bar.baz.ClassName'
'google3.foo.baz.ClassName'
'google3.baz.ClassName'
'syntaxnet.foo.bar.baz.ClassName'
'syntaxnet.foo.baz.ClassName'
'syntaxnet.baz.ClassName'
'baz.ClassName'
An instance corresponding to the first successful attempt is returned.
......@@ -163,9 +177,12 @@ def _ResolveAndCreate(baseclass, path, subclass_name, *args, **kwargs):
elements = path.split('.')
while True:
resolved_subclass_name = '.'.join(elements + [subclass_name])
logging.info('Attempting to instantiate "%s"', resolved_subclass_name)
subclass = _Create(baseclass, resolved_subclass_name, *args, **kwargs)
if subclass: return subclass # success
if not elements: break # no more paths to try
if subclass:
return subclass # success
if not elements:
break # no more paths to try
elements.pop() # try resolving against the next-outer path
raise ValueError(
'Failed to create subclass "%s" of base class %s using path %s' %
......
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for registry system.
This test uses two other modules:
......@@ -104,9 +119,11 @@ class RegistryTest(googletest.TestCase):
def testCannotResolveRelativeName(self):
"""Tests that Create fails if a relative path cannot be resolved."""
for name in [
'nlp.saft.opensource.syntaxnet.util.registry_test_base.Impl',
'saft.bad.registry_test_impl.Impl', 'missing.registry_test_impl.Impl',
'registry_test_impl.Bad', 'Impl'
'bad.syntaxnet.util.registry_test_base.Impl',
'syntaxnet.bad.registry_test_impl.Impl',
'missing.registry_test_impl.Impl',
'registry_test_impl.Bad',
'Impl'
]:
with self.assertRaisesRegexp(ValueError, 'Failed to create'):
registry_test_base.Base.Create(name, 'hello world')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment