Commit 4364390a authored by Ivan Bogatyy, committed by calberti

Release DRAGNN bulk networks (#2785)

* Release DRAGNN bulk networks
parent 638fd759
@@ -258,7 +258,7 @@ REGISTER_OP("WordEmbeddingInitializer")
 Reads word embeddings from an sstable of dist_belief.TokenEmbedding protos for
 every word specified in a text vocabulary file.
-word_embeddings: a tensor containing word embeddings from the specified table.
+word_embeddings: a tensor containing word embeddings from the specified sstable.
 vectors: path to TF record file of word embedding vectors.
 task_context: file path at which to read the task context, for its "word-map"
   input. Exactly one of `task_context` or `vocabulary` must be specified.
...
@@ -17,8 +17,10 @@ limitations under the License.
 #include <string>

+#include "syntaxnet/generic_features.h"
 #include "syntaxnet/registry.h"
 #include "syntaxnet/sentence_features.h"
+#include "syntaxnet/whole_sentence_features.h"
 #include "syntaxnet/workspace.h"

 namespace syntaxnet {
@@ -347,4 +349,8 @@ class Constant : public ParserFeatureFunction {
 REGISTER_PARSER_FEATURE_FUNCTION("constant", Constant);

+// Register the generic parser features.
+typedef GenericFeatures<ParserState> GenericParserFeature;
+REGISTER_SYNTAXNET_GENERIC_FEATURES(GenericParserFeature);
+
 }  // namespace syntaxnet
@@ -146,6 +146,14 @@ class BasicParserSentenceFeatureFunction :
   }
 };

+// Registry for the parser feature functions.
+DECLARE_SYNTAXNET_CLASS_REGISTRY("parser feature function",
+                                 ParserFeatureFunction);
+
+// Registry for the parser state + token index feature functions.
+DECLARE_SYNTAXNET_CLASS_REGISTRY("parser+index feature function",
+                                 ParserIndexFeatureFunction);
+
 }  // namespace syntaxnet

 #endif  // SYNTAXNET_PARSER_FEATURES_H_
@@ -84,7 +84,6 @@ class ParserFeatureFunctionTest : public ::testing::Test {
   // Prepares a feature for computations.
   string ExtractFeature(const string &feature_name) {
     context_.mutable_spec()->mutable_input()->Clear();
-    context_.mutable_spec()->mutable_output()->Clear();
     feature_extractor_.reset(new ParserFeatureExtractor());
     feature_extractor_->Parse(feature_name);
     feature_extractor_->Setup(&context_);
@@ -152,4 +151,10 @@ TEST_F(ParserFeatureFunctionTest, GoldHeadFeatureFunction) {
   EXPECT_EQ("1", ExtractFeature("input(7).gold-head"));
 }

+TEST_F(ParserFeatureFunctionTest, PairFeatureFunction) {
+  EXPECT_EQ("(1,PRP)", ExtractFeature("pair { input.gold-head input.tag }"));
+  EXPECT_EQ("(1,PRP,ROOT)",
+            ExtractFeature("triple { input.gold-head input.tag input.label }"));
+}
+
 }  // namespace syntaxnet
@@ -22,6 +22,8 @@ namespace syntaxnet {
 // Transition system registry.
 REGISTER_SYNTAXNET_CLASS_REGISTRY("transition system", ParserTransitionSystem);

+constexpr int ParserTransitionSystem::kDynamicNumActions;
+
 void ParserTransitionSystem::PerformAction(ParserAction action,
                                            ParserState *state) const {
   if (state->keep_history()) {
...
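The out-of-line definition added above is easy to misread as redundant. Before C++17, a static constexpr data member that is odr-used (for example, bound to a const reference by a logging or EXPECT macro) still requires a namespace-scope definition; the in-class initializer alone is only a declaration. A minimal sketch of the pattern, with illustrative names:

```cpp
// In the header: declaration with the initializer inside the class.
struct ExampleSystem {
  static constexpr int kSentinel = -1;
};

// In the .cc file: pre-C++17 out-of-line definition; no initializer here.
constexpr int ExampleSystem::kSentinel;
```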
@@ -74,6 +74,9 @@ class ParserTransitionState {
 class ParserTransitionSystem
     : public RegisterableClass<ParserTransitionSystem> {
  public:
+  // Sentinel value that represents a dynamic action set.
+  static constexpr int kDynamicNumActions = -1;
+
   // Construction and cleanup.
   ParserTransitionSystem() {}
   virtual ~ParserTransitionSystem() {}
@@ -94,7 +97,8 @@ class ParserTransitionSystem
   // Returns the number of action types.
   virtual int NumActionTypes() const = 0;

-  // Returns the number of actions.
+  // Returns the number of actions, or |kDynamicNumActions| if the action set is
+  // dynamic (i.e., varies per instance).
   virtual int NumActions(int num_labels) const = 0;

   // Internally creates the set of outcomes (when transition systems support a
@@ -196,6 +200,9 @@ class ParserTransitionSystem
 #define REGISTER_TRANSITION_SYSTEM(type, component) \
   REGISTER_SYNTAXNET_CLASS_COMPONENT(ParserTransitionSystem, type, component)

+// Transition system registry.
+DECLARE_SYNTAXNET_CLASS_REGISTRY("transition system", ParserTransitionSystem);
+
 }  // namespace syntaxnet

 #endif  // SYNTAXNET_PARSER_TRANSITIONS_H_
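For callers, the |kDynamicNumActions| sentinel means the action count can no longer be assumed fixed per transition system. A hedged sketch of the check a consumer might perform (the helper name is illustrative, not part of this commit):

```cpp
// Hypothetical consumer: size an output layer from the transition system.
int num_actions = transition_system->NumActions(num_labels);
if (num_actions == ParserTransitionSystem::kDynamicNumActions) {
  // The action set varies per instance, so the output dimension must be
  // computed dynamically rather than from a fixed action count.
  num_actions = ComputeInstanceSpecificActionCount();  // illustrative helper
}
```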
@@ -66,6 +66,8 @@ class ProtoRecordReader {
       CHECK(proto->ParseFromString(buffer));
       return tensorflow::Status::OK();
     } else {
+      CHECK_EQ(status.code(), tensorflow::error::OUT_OF_RANGE)
+          << "Non-OK and non-out-of-range (EOF) status: " << status;
       return status;
     }
   }
...
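With this CHECK in place, the only non-OK status Read() can return is OUT_OF_RANGE, i.e. end of input; any other I/O error now fails fast instead of being silently propagated to callers. A minimal read loop under that contract (a sketch only; it assumes ProtoRecordReader is constructible from a file path and fills one proto per Read() call):

```cpp
// Iterate records until Read() reports OUT_OF_RANGE (end of file).
ProtoRecordReader reader("/path/to/documents.tfrecord");
Sentence sentence;
while (reader.Read(&sentence).ok()) {
  // Process one Sentence proto per record. A malformed record or an
  // unexpected I/O error would have CHECK-failed inside Read().
}
```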
@@ -15,7 +15,6 @@
 """Tests for reader_ops."""

-# pylint: disable=no-name-in-module,unused-import,g-bad-import-order,maybe-no-member,no-member,g-importing-member
 import os.path

 import numpy as np
@@ -30,6 +29,7 @@ from syntaxnet import graph_builder
 from syntaxnet import sparse_pb2
 from syntaxnet.ops import gen_parser_ops

+
 FLAGS = tf.app.flags.FLAGS
 if not hasattr(FLAGS, 'test_srcdir'):
   FLAGS.test_srcdir = ''
...
@@ -229,6 +229,10 @@ class RegisterableInstance {
   classname::Registry RegisterableClass<classname>::registry_ = { \
       type, #classname, __FILE__, __LINE__, NULL}

+#define DECLARE_SYNTAXNET_CLASS_REGISTRY(type, classname) \
+  template <>                                             \
+  classname::Registry RegisterableClass<classname>::registry_;
+
 #define REGISTER_SYNTAXNET_INSTANCE_COMPONENT(base, type, component) \
   static base::Registry::Registrar __##component##__##registrar(     \
       base::registry(), type, #component, __FILE__, __LINE__, new component)
@@ -238,6 +242,10 @@ class RegisterableInstance {
   classname::Registry RegisterableInstance<classname>::registry_ = { \
       type, #classname, __FILE__, __LINE__, NULL}

+#define DECLARE_SYNTAXNET_INSTANCE_REGISTRY(type, classname) \
+  template <>                                                \
+  classname::Registry RegisterableInstance<classname>::registry_;
+
 }  // namespace syntaxnet

 #endif  // SYNTAXNET_REGISTRY_H_
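The new DECLARE_* macros pair with the existing REGISTER_* macros: the header declares the explicit specialization of the static registry_ member, so translation units that merely include the header link against the single definition emitted by the REGISTER_* macro in one .cc file instead of implicitly instantiating their own. The transition-system registry in this commit shows the intended split:

```cpp
// In the header (parser_transitions.h): declare the registry.
DECLARE_SYNTAXNET_CLASS_REGISTRY("transition system", ParserTransitionSystem);

// In exactly one .cc file (parser_transitions.cc): define it.
REGISTER_SYNTAXNET_CLASS_REGISTRY("transition system", ParserTransitionSystem);
```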
@@ -663,6 +663,10 @@ typedef FeatureExtractor<Sentence, int> SentenceExtractor;
 #define REGISTER_SENTENCE_IDX_FEATURE(name, type) \
   REGISTER_SYNTAXNET_FEATURE_FUNCTION(SentenceFeature, name, type)

+// Registry for the Sentence + token index feature functions.
+DECLARE_SYNTAXNET_CLASS_REGISTRY("sentence+index feature function",
+                                 SentenceFeature);
+
 }  // namespace syntaxnet

 #endif  // SYNTAXNET_SENTENCE_FEATURES_H_
@@ -51,7 +51,6 @@ class SentenceFeaturesTest : public ::testing::Test {
   // anything in info_ field into the LexiFuse repository.
   virtual void PrepareFeature(const string &fml) {
     context_.mutable_spec()->mutable_input()->Clear();
-    context_.mutable_spec()->mutable_output()->Clear();
     extractor_.reset(new SentenceExtractor());
     extractor_->Parse(fml);
     extractor_->Setup(&context_);
@@ -78,6 +77,7 @@ class SentenceFeaturesTest : public ::testing::Test {
     FeatureVector result;
     extractor_->ExtractFeatures(workspaces_, sentence_, index,
                                 &result);
+    values.reserve(result.size());
     for (int i = 0; i < result.size(); ++i) {
       values.push_back(result.type(i)->GetFeatureValueName(result.value(i)));
     }
@@ -99,6 +99,7 @@ class SentenceFeaturesTest : public ::testing::Test {
   void CheckVectorWorkspace(const VectorIntWorkspace &workspace,
                             std::vector<int> target) {
     std::vector<int> src;
+    src.reserve(workspace.size());
     for (int i = 0; i < workspace.size(); ++i) {
       src.push_back(workspace.element(i));
     }
...
@@ -16,6 +16,6 @@
 """Imports the SyntaxNet ops and their C++ implementations."""

-from syntaxnet.ops.gen_parser_ops import *  # pylint: disable=wildcard-import
+from syntaxnet.ops.gen_parser_ops import *
 import syntaxnet.load_parser_ops
@@ -35,39 +35,8 @@ message TaskInput {
   }
 }

-// Task output descriptor.
-message TaskOutput {
-  // Name of output resource.
-  required string name = 1;
-
-  // File format for output resource.
-  optional string file_format = 2;
-
-  // Record format for output resource.
-  optional string record_format = 3;
-
-  // Number of shards in output. If it is different from zero this output is
-  // sharded. If the number of shards is set to -1 this means that the output is
-  // sharded, but the number of shard is unknown. The files are then named
-  // 'base-*-of-*'.
-  optional int32 shards = 4 [default = 0];
-
-  // Base file name for output resource. If this is not set by the task
-  // component it is set to a default value by the workflow engine.
-  optional string file_base = 5;
-
-  // Optional extension added to the file name.
-  optional string file_extension = 6;
-}
-
 // A task specification is used for describing executing parameters.
 message TaskSpec {
-  // Name of task.
-  optional string task_name = 1;
-
-  // Workflow task type.
-  optional string task_type = 2;
-
   // Task parameters.
   repeated group Parameter = 3 {
     required string name = 4;
@@ -77,6 +46,6 @@ message TaskSpec {
   // Task inputs.
   repeated TaskInput input = 6;

-  // Task outputs.
-  repeated TaskOutput output = 7;
+  reserved 1, 2, 7;
+  reserved "task_name", "task_type", "output";
 }
@@ -68,7 +68,8 @@ void TermFrequencyMap::Load(const string &filename, int min_frequency,
   string line;
   TF_CHECK_OK(buffer.ReadLine(&line));
   int32 total = -1;
-  CHECK(utils::ParseInt32(line.c_str(), &total));
+  CHECK(utils::ParseInt32(line.c_str(), &total))
+      << "Unable to parse from " << filename;
   CHECK_GE(total, 0);

   // Read the mapping.
...
@@ -157,7 +157,12 @@ class CoNLLSyntaxFormat : public DocumentFormat {
     const int start = text.size();
     const int end = start + word.size() - 1;
     text.append(word);
-    add_space_to_text = fields[9] != "SpaceAfter=No";
+
+    // Determine whether a space should be added to sentence text.
+    std::vector<string> sub_fields = utils::Split(fields[9], '|');
+    auto no_space = [](const string &str) { return str == "SpaceAfter=No"; };
+    add_space_to_text =
+        !std::any_of(sub_fields.begin(), sub_fields.end(), no_space);

     // Add token to sentence.
     Token *token = sentence->add_token();
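The old equality test only recognized a MISC column that was exactly SpaceAfter=No, but in CoNLL-U the MISC column is a |-separated attribute list (the updated test below uses SpaceAfter=No|foobar=baz), so the new code splits the column and scans the parts. A self-contained sketch of the same logic using only the standard library (the actual code uses syntaxnet's utils::Split):

```cpp
#include <algorithm>
#include <sstream>
#include <string>
#include <vector>

// Returns false if any |-separated attribute in the MISC column is
// "SpaceAfter=No"; returns true otherwise.
bool AddSpaceAfter(const std::string &misc) {
  std::vector<std::string> sub_fields;
  std::istringstream stream(misc);
  for (std::string field; std::getline(stream, field, '|');) {
    sub_fields.push_back(field);
  }
  return !std::any_of(sub_fields.begin(), sub_fields.end(),
                      [](const std::string &s) { return s == "SpaceAfter=No"; });
}

// AddSpaceAfter("SpaceAfter=No|foobar=baz") -> false (no space appended)
// AddSpaceAfter("_")                        -> true
```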
@@ -329,28 +334,28 @@ REGISTER_SYNTAXNET_DOCUMENT_FORMAT("conll-sentence", CoNLLSyntaxFormat);
 //
 // Examples:
 // To create a training example for sentence with raw text:
-//   That's a good point.
+//   "That's a good point."
 // and the corresponding gold segmentation:
-//   That 's a good point .
+//   "That" "\'s" "a" "good" "point" "."
 // Then the correct input is:
-//   That NO_SPACE
-//   's SPACE
-//   a SPACE
-//   good SPACE
-//   point NO_SPACE
-//   . NO_SPACE
+//   "That\tNO_SPACE"
+//   "'s\tSPACE"
+//   "a\tSPACE"
+//   "good\tSPACE"
+//   "point\tNO_SPACE"
+//   ".\tNO_SPACE"
 //
 // Yet another example:
 // To create a training example for sentence with raw text:
-//   这是一个测试
+//   "这是一个测试"
 // and the corresponding gold segmentation:
-//   这 是 一 个 测试
+//   "这" "是" "一" "个" "测试"
 // Then the correct input is:
-//   这 NO_SPACE
-//   是 NO_SPACE
-//   一 NO_SPACE
-//   个 NO_SPACE
-//   测试 NO_SPACE
+//   "这\tNO_SPACE"
+//   "是\tNO_SPACE"
+//   "一\tNO_SPACE"
+//   "个\tNO_SPACE"
+//   "测试\tNO_SPACE"
 class SegmentationTrainingDataFormat : public CoNLLSyntaxFormat {
  public:
   // Converts to segmentation training data by breaking those word in the input
...
@@ -113,13 +113,13 @@ class TextFormatsTest(test_util.TensorFlowTestCase):
     # This test sentence includes a multiword token and an empty node,
     # both of which are to be ignored.
     test_sentence = """
-1-2 We've _
-1 We we PRON PRP Case=Nom 3 nsubj _ SpaceAfter=No
-2 've have AUX VBP Mood=Ind 3 aux _ _
-3 moved move VERB VBN Tense=Past 0 root _ _
-4 on on ADV RB _ 3 advmod _ SpaceAfter=No
-4.1 ignored ignore VERB VBN Tense=Past 0 _ _ _
-5 . . PUNCT . _ 3 punct _ _
+1-2\tWe've\t_
+1\tWe\twe\tPRON\tPRP\tCase=Nom\t3\tnsubj\t_\tSpaceAfter=No
+2\t've\thave\tAUX\tVBP\tMood=Ind\t3\taux\t_\t_
+3\tmoved\tmove\tVERB\tVBN\tTense=Past\t0\troot\t_\t_
+4\ton\ton\tADV\tRB\t_\t3\tadvmod\t_\tSpaceAfter=No|foobar=baz
+4.1\tignored\tignore\tVERB\tVBN\tTense=Past\t0\t_\t_\t_
+5\t.\t.\tPUNCT\t.\t_\t3\tpunct\t_\t_
 """

     # Prepare test sentence.
@@ -191,13 +191,13 @@ token {
     self.assertEqual(expected_ends, [t.end for t in sentence_doc.token])

   def testSegmentationTrainingData(self):
-    doc1_lines = ['测试 NO_SPACE\n', '的 NO_SPACE\n', '句子 NO_SPACE']
+    doc1_lines = ['测试\tNO_SPACE\n', '的\tNO_SPACE\n', '句子\tNO_SPACE']
     doc1_text = '测试的句子'
     doc1_tokens = ['测', '试', '的', '句', '子']
     doc1_break_levles = [1, 0, 1, 1, 0]

     doc2_lines = [
-        'That NO_SPACE\n', '\'s SPACE\n', 'a SPACE\n', 'good SPACE\n',
-        'point NO_SPACE\n', '. NO_SPACE'
+        'That\tNO_SPACE\n', '\'s\tSPACE\n', 'a\tSPACE\n', 'good\tSPACE\n',
+        'point\tNO_SPACE\n', '.\tNO_SPACE'
     ]
     doc2_text = 'That\'s a good point.'
     doc2_tokens = [
...
@@ -16,6 +16,14 @@ py_library(
     srcs = ["check.py"],
 )

+py_library(
+    name = "resources",
+    srcs = ["resources.py"],
+    visibility = ["//visibility:public"],
+    deps = [
+    ],
+)
+
 py_library(
     name = "pyregistry_test_base",
     testonly = 1,
@@ -56,3 +64,15 @@ py_test(
         "@org_tensorflow//tensorflow/core:protos_all_py",
     ],
 )
+
+py_test(
+    name = "resources_test",
+    srcs = ["resources_test.py"],
+    data = [
+        "//syntaxnet:testdata/hello.txt",
+    ],
+    deps = [
+        ":resources",
+        "@org_tensorflow//tensorflow:tensorflow_py",
+    ],
+)
"""A component registry, similar to nlp_saft::RegisteredClass<>. # Copyright 2017 Google Inc. All Rights Reserved.
#
Like nlp_saft::RegisteredClass<>, one does not need to explicitly import the # Licensed under the Apache License, Version 2.0 (the "License");
module containing each subclass. It is sufficient to add subclasses as build # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A component registry, similar to RegisterableClass<>.
Like RegisterableClass<>, one does not need to explicitly import the module
containing each subclass. It is sufficient to add subclasses as build
dependencies. dependencies.
Unlike nlp_saft::RegisteredClass<>, which allows subclasses to be registered Unlike RegisterableClass<>, which allows subclasses to be registered under
under arbitrary names, subclasses must be looked up based on their type name. arbitrary names, subclasses must be looked up based on their type name. This
This restriction allows the registry to dynamically import the module containing restriction allows the registry to dynamically import the module containing the
the desired subclass. desired subclass.
Example usage: Example usage:
@@ -82,7 +97,7 @@ def _GetClass(name):
   # Need at least "module.Class".
   if len(elements) < 2:
-    logging.debug('Malformed type: "%s"', name)
+    logging.info('Malformed type: "%s"', name)
     return None
   module_path = '.'.join(elements[:-1])
   class_name = elements[-1]
@@ -91,20 +106,19 @@ def _GetClass(name):
   try:
     __import__(module_path)
   except ImportError as e:
-    logging.debug('Unable to find module "%s": "%s"', module_path, e)
+    logging.info('Unable to find module "%s": "%s"', module_path, e)
     return None
   module = sys.modules[module_path]

   # Look up the class.
   if not hasattr(module, class_name):
-    logging.debug('Name "%s" not found in module: "%s"', class_name,
-                  module_path)
+    logging.info('Name "%s" not found in module: "%s"', class_name, module_path)
     return None
   class_obj = getattr(module, class_name)

   # Check that it is actually a class.
   if not inspect.isclass(class_obj):
-    logging.debug('Name does not refer to a class: "%s"', name)
+    logging.info('Name does not refer to a class: "%s"', name)
     return None

   return class_obj
@@ -125,8 +139,8 @@ def _Create(baseclass, subclass_name, *args, **kwargs):
   if subclass is None:
     return None  # _GetClass() already logged an error
   if not issubclass(subclass, baseclass):
-    logging.debug('Class "%s" is not a subclass of "%s"', subclass_name,
-                  baseclass.__name__)
+    logging.info('Class "%s" is not a subclass of "%s"', subclass_name,
+                 baseclass.__name__)
     return None
   return subclass(*args, **kwargs)
@@ -135,13 +149,13 @@ def _ResolveAndCreate(baseclass, path, subclass_name, *args, **kwargs):
   """Resolves the name of a subclass and creates an instance of it.

   The subclass is resolved with respect to a package path in an inside-out
-  manner. For example, if |path| is 'google3.foo.bar' and |subclass_name| is
+  manner. For example, if |path| is 'syntaxnet.foo.bar' and |subclass_name| is
   'baz.ClassName', then attempts are made to create instances of the following
   fully-qualified class names:

-    'google3.foo.bar.baz.ClassName'
-    'google3.foo.baz.ClassName'
-    'google3.baz.ClassName'
+    'syntaxnet.foo.bar.baz.ClassName'
+    'syntaxnet.foo.baz.ClassName'
+    'syntaxnet.baz.ClassName'
     'baz.ClassName'

   An instance corresponding to the first successful attempt is returned.
@@ -163,9 +177,12 @@ def _ResolveAndCreate(baseclass, path, subclass_name, *args, **kwargs):
   elements = path.split('.')
   while True:
     resolved_subclass_name = '.'.join(elements + [subclass_name])
+    logging.info('Attempting to instantiate "%s"', resolved_subclass_name)
     subclass = _Create(baseclass, resolved_subclass_name, *args, **kwargs)
-    if subclass: return subclass  # success
-    if not elements: break  # no more paths to try
+    if subclass:
+      return subclass  # success
+    if not elements:
+      break  # no more paths to try
     elements.pop()  # try resolving against the next-outer path
   raise ValueError(
       'Failed to create subclass "%s" of base class %s using path %s' %
...
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 """Tests for registry system.

 This test uses two other modules:
@@ -104,9 +119,11 @@ class RegistryTest(googletest.TestCase):

   def testCannotResolveRelativeName(self):
     """Tests that Create fails if a relative path cannot be resolved."""
     for name in [
-        'nlp.saft.opensource.syntaxnet.util.registry_test_base.Impl',
-        'saft.bad.registry_test_impl.Impl', 'missing.registry_test_impl.Impl',
-        'registry_test_impl.Bad', 'Impl'
+        'bad.syntaxnet.util.registry_test_base.Impl',
+        'syntaxnet.bad.registry_test_impl.Impl',
+        'missing.registry_test_impl.Impl',
+        'registry_test_impl.Bad',
+        'Impl'
     ]:
       with self.assertRaisesRegexp(ValueError, 'Failed to create'):
         registry_test_base.Base.Create(name, 'hello world')
...