Commit a4bb31d0 authored by Terry Koo's avatar Terry Koo
Browse files

Export @195097388.

parent dea7ecf6
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <string>
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component_transformation.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "dragnn/runtime/sequence_linker.h"
#include "dragnn/runtime/sequence_predictor.h"
#include "dragnn/runtime/transition_system_traits.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Returns true if the |component_spec| has recurrent links.
bool IsRecurrent(const ComponentSpec &component_spec) {
for (const LinkedFeatureChannel &channel : component_spec.linked_feature()) {
if (channel.source_component() == component_spec.name()) return true;
}
return false;
}
// Returns the sequence-based version of the |component_type| with specification
// |component_spec|, or an empty string if there is no sequence-based version.
string GetSequenceComponentType(const string &component_type,
const ComponentSpec &component_spec) {
// TODO(googleuser): Implement a SequenceDynamicComponent that can handle
// recurrent links. This may require changes to the NetworkUnit API.
static const char *kSupportedComponentTypes[] = {
"BulkDynamicComponent", //
"BulkLstmComponent", //
"MyelinDynamicComponent", //
};
for (const char *supported_type : kSupportedComponentTypes) {
if (component_type == supported_type) {
return tensorflow::strings::StrCat("Sequence", supported_type);
}
}
// Also support non-recurrent DynamicComponents. The BulkDynamicComponent
// requires determinism, but the SequenceBulkDynamicComponent does not, so
// it's not sufficient to only upgrade from BulkDynamicComponent.
if (component_type == "DynamicComponent" && !IsRecurrent(component_spec)) {
return "SequenceBulkDynamicComponent";
}
return string();
}
// Returns the |status| but coerces NOT_FOUND to OK. Sets |found| to false iff
// the |status| was NOT_FOUND.
tensorflow::Status AllowNotFound(const tensorflow::Status &status,
bool *found) {
*found = status.code() != tensorflow::error::NOT_FOUND;
return *found ? status : tensorflow::Status::OK();
}
// Transformer that checks whether a sequence-based component implementation
// could be used and, if compatible, modifies the ComponentSpec accordingly.
class SequenceComponentTransformer : public ComponentTransformer {
public:
// Implements ComponentTransformer.
tensorflow::Status Transform(const string &component_type,
ComponentSpec *component_spec) override;
};
tensorflow::Status SequenceComponentTransformer::Transform(
const string &component_type, ComponentSpec *component_spec) {
const int num_features = component_spec->fixed_feature_size() +
component_spec->linked_feature_size();
if (num_features == 0) return tensorflow::Status::OK();
// Look for supporting SequenceExtractors.
bool found = false;
string extractor_types;
for (const FixedFeatureChannel &channel : component_spec->fixed_feature()) {
string type;
TF_RETURN_IF_ERROR(AllowNotFound(
SequenceExtractor::Select(channel, *component_spec, &type), &found));
if (!found) return tensorflow::Status::OK();
tensorflow::strings::StrAppend(&extractor_types, type, ",");
}
if (!extractor_types.empty()) extractor_types.pop_back(); // remove comma
// Look for supporting SequenceLinkers.
string linker_types;
for (const LinkedFeatureChannel &channel : component_spec->linked_feature()) {
string type;
TF_RETURN_IF_ERROR(AllowNotFound(
SequenceLinker::Select(channel, *component_spec, &type), &found));
if (!found) return tensorflow::Status::OK();
tensorflow::strings::StrAppend(&linker_types, type, ",");
}
if (!linker_types.empty()) linker_types.pop_back(); // remove comma
// Look for a supporting SequencePredictor, if predictions are necessary.
string predictor_type;
if (!TransitionSystemTraits(*component_spec).is_deterministic) {
TF_RETURN_IF_ERROR(AllowNotFound(
SequencePredictor::Select(*component_spec, &predictor_type), &found));
if (!found) return tensorflow::Status::OK();
}
// Look for a supporting sequence-based component type.
const string sequence_component_type =
GetSequenceComponentType(component_type, *component_spec);
if (sequence_component_type.empty()) return tensorflow::Status::OK();
// Success; make modifications.
component_spec->mutable_backend()->set_registered_name("SequenceBackend");
RegisteredModuleSpec *builder = component_spec->mutable_component_builder();
builder->set_registered_name(sequence_component_type);
(*builder->mutable_parameters())["sequence_extractors"] = extractor_types;
(*builder->mutable_parameters())["sequence_linkers"] = linker_types;
(*builder->mutable_parameters())["sequence_predictor"] = predictor_type;
return tensorflow::Status::OK();
}
DRAGNN_RUNTIME_REGISTER_COMPONENT_TRANSFORMER(SequenceComponentTransformer);
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <memory>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component_transformation.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "dragnn/runtime/sequence_linker.h"
#include "dragnn/runtime/sequence_predictor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Arbitrary supported component type.
constexpr char kSupportedComponentType[] = "MyelinDynamicComponent";
// Sequence-based version of the component type.
constexpr char kTransformedComponentType[] = "SequenceMyelinDynamicComponent";
// Trivial extractor that supports components named "supported".
class SupportIfNamedSupportedExtractor : public SequenceExtractor {
public:
// Implements SequenceExtractor.
bool Supports(const FixedFeatureChannel &,
const ComponentSpec &component_spec) const override {
return component_spec.name() == "supported";
}
tensorflow::Status Initialize(const FixedFeatureChannel &,
const ComponentSpec &) override {
return tensorflow::Status::OK();
}
tensorflow::Status GetIds(InputBatchCache *,
std::vector<int32> *) const override {
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR(SupportIfNamedSupportedExtractor);
// Trivial extractor that supports components if they have a resource. This is
// used to generate a "multiple supported extractors" conflict.
class SupportIfHasResourcesExtractor : public SequenceExtractor {
public:
// Implements SequenceExtractor.
bool Supports(const FixedFeatureChannel &,
const ComponentSpec &component_spec) const override {
return component_spec.resource_size() > 0;
}
tensorflow::Status Initialize(const FixedFeatureChannel &,
const ComponentSpec &) override {
return tensorflow::Status::OK();
}
tensorflow::Status GetIds(InputBatchCache *,
std::vector<int32> *) const override {
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR(SupportIfHasResourcesExtractor);
// Trivial linker that supports components named "supported".
class SupportIfNamedSupportedLinker : public SequenceLinker {
public:
// Implements SequenceLinker.
bool Supports(const LinkedFeatureChannel &,
const ComponentSpec &component_spec) const override {
return component_spec.name() == "supported";
}
tensorflow::Status Initialize(const LinkedFeatureChannel &,
const ComponentSpec &) override {
return tensorflow::Status::OK();
}
tensorflow::Status GetLinks(size_t, InputBatchCache *,
std::vector<int32> *) const override {
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(SupportIfNamedSupportedLinker);
// Trivial predictor that supports components named "supported".
class SupportIfNamedSupportedPredictor : public SequencePredictor {
public:
// Implements SequencePredictor.
bool Supports(const ComponentSpec &component_spec) const override {
return component_spec.name() == "supported";
}
tensorflow::Status Initialize(const ComponentSpec &) override {
return tensorflow::Status::OK();
}
tensorflow::Status Predict(Matrix<float>, InputBatchCache *) const override {
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_PREDICTOR(SupportIfNamedSupportedPredictor);
// Returns a ComponentSpec that is supported by the transformer.
ComponentSpec MakeSupportedSpec() {
ComponentSpec component_spec;
component_spec.set_name("supported");
component_spec.set_num_actions(10);
component_spec.add_fixed_feature();
component_spec.add_fixed_feature();
component_spec.add_linked_feature();
component_spec.add_linked_feature();
component_spec.mutable_component_builder()->set_registered_name(
kSupportedComponentType);
return component_spec;
}
// Tests that a compatible spec is modified to use a new backend and component
// builder with SequenceExtractors, SequenceLinkers, and SequencePredictor.
TEST(SequenceComponentTransformerTest, Compatible) {
ComponentSpec component_spec = MakeSupportedSpec();
ComponentSpec modified_spec = component_spec;
modified_spec.mutable_backend()->set_registered_name("SequenceBackend");
modified_spec.mutable_component_builder()->set_registered_name(
kTransformedComponentType);
modified_spec.mutable_component_builder()->mutable_parameters()->insert(
{"sequence_extractors",
"SupportIfNamedSupportedExtractor,SupportIfNamedSupportedExtractor"});
modified_spec.mutable_component_builder()->mutable_parameters()->insert(
{"sequence_linkers",
"SupportIfNamedSupportedLinker,SupportIfNamedSupportedLinker"});
modified_spec.mutable_component_builder()->mutable_parameters()->insert(
{"sequence_predictor", "SupportIfNamedSupportedPredictor"});
TF_ASSERT_OK(ComponentTransformer::ApplyAll(&component_spec));
EXPECT_THAT(component_spec, test::EqualsProto(modified_spec));
}
// Tests that a compatible deterministic spec is modified to use a new backend
// and component builder with SequenceExtractors and SequenceLinkers only.
TEST(SequenceComponentTransformerTest, CompatibleNoPredictor) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.set_num_actions(1);
ComponentSpec modified_spec = component_spec;
modified_spec.mutable_backend()->set_registered_name("SequenceBackend");
modified_spec.mutable_component_builder()->set_registered_name(
kTransformedComponentType);
modified_spec.mutable_component_builder()->mutable_parameters()->insert(
{"sequence_extractors",
"SupportIfNamedSupportedExtractor,SupportIfNamedSupportedExtractor"});
modified_spec.mutable_component_builder()->mutable_parameters()->insert(
{"sequence_linkers",
"SupportIfNamedSupportedLinker,SupportIfNamedSupportedLinker"});
modified_spec.mutable_component_builder()->mutable_parameters()->insert(
{"sequence_predictor", ""});
TF_ASSERT_OK(ComponentTransformer::ApplyAll(&component_spec));
EXPECT_THAT(component_spec, test::EqualsProto(modified_spec));
}
// Tests that a ComponentSpec with no features is incompatible.
TEST(SequenceComponentTransformerTest, IncompatibleNoFeatures) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.clear_fixed_feature();
component_spec.clear_linked_feature();
const ComponentSpec unchanged_spec = component_spec;
TF_ASSERT_OK(ComponentTransformer::ApplyAll(&component_spec));
EXPECT_THAT(component_spec, test::EqualsProto(unchanged_spec));
}
// Tests that a ComponentSpec with the wrong component builder is incompatible.
TEST(SequenceComponentTransformerTest, IncompatibleComponentBuilder) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.mutable_component_builder()->set_registered_name("bad");
const ComponentSpec unchanged_spec = component_spec;
TF_ASSERT_OK(ComponentTransformer::ApplyAll(&component_spec));
EXPECT_THAT(component_spec, test::EqualsProto(unchanged_spec));
}
// Tests that a ComponentSpec is incompatible if it is not supported by any
// SequenceExtractor.
TEST(SequenceComponentTransformerTest,
IncompatibleNoSupportingSequenceExtractor) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.set_name("bad");
const ComponentSpec unchanged_spec = component_spec;
TF_ASSERT_OK(ComponentTransformer::ApplyAll(&component_spec));
EXPECT_THAT(component_spec, test::EqualsProto(unchanged_spec));
}
// Tests that a ComponentSpec fails if multiple SequenceExtractors support it.
TEST(SequenceComponentTransformerTest,
FailIfMultipleSupportingSequenceExtractors) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.add_resource(); // triggers SupportIfHasResourcesExtractor
EXPECT_THAT(
ComponentTransformer::ApplyAll(&component_spec),
test::IsErrorWithSubstr("Multiple SequenceExtractors support channel"));
}
// Tests that a DynamicComponent is not upgraded if it is recurrent.
TEST(SequenceComponentTransformerTest, RecurrentDynamicComponent) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.mutable_component_builder()->set_registered_name(
"DynamicComponent");
component_spec.mutable_linked_feature(0)->set_source_component(
component_spec.name());
const ComponentSpec unchanged_spec = component_spec;
TF_ASSERT_OK(ComponentTransformer::ApplyAll(&component_spec));
EXPECT_THAT(component_spec, test::EqualsProto(unchanged_spec));
}
// Tests that a DynamicComponent is upgraded to SequenceBulkDynamicComponent if
// it is non-recurrent.
TEST(SequenceComponentTransformerTest, NonRecurrentDynamicComponent) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.mutable_component_builder()->set_registered_name(
"DynamicComponent");
ComponentSpec modified_spec = component_spec;
modified_spec.mutable_backend()->set_registered_name("SequenceBackend");
modified_spec.mutable_component_builder()->set_registered_name(
"SequenceBulkDynamicComponent");
modified_spec.mutable_component_builder()->mutable_parameters()->insert(
{"sequence_extractors",
"SupportIfNamedSupportedExtractor,SupportIfNamedSupportedExtractor"});
modified_spec.mutable_component_builder()->mutable_parameters()->insert(
{"sequence_linkers",
"SupportIfNamedSupportedLinker,SupportIfNamedSupportedLinker"});
modified_spec.mutable_component_builder()->mutable_parameters()->insert(
{"sequence_predictor", "SupportIfNamedSupportedPredictor"});
TF_ASSERT_OK(ComponentTransformer::ApplyAll(&component_spec));
EXPECT_THAT(component_spec, test::EqualsProto(modified_spec));
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_extractor.h"
#include "tensorflow/core/lib/core/errors.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
tensorflow::Status SequenceExtractor::Select(
const FixedFeatureChannel &channel, const ComponentSpec &component_spec,
string *name) {
string supporting_name;
for (const Registry::Registrar *registrar = registry()->components;
registrar != nullptr; registrar = registrar->next()) {
Factory *factory_function = registrar->object();
std::unique_ptr<SequenceExtractor> current_extractor(factory_function());
if (!current_extractor->Supports(channel, component_spec)) continue;
if (!supporting_name.empty()) {
return tensorflow::errors::Internal(
"Multiple SequenceExtractors support channel ",
channel.ShortDebugString(), " of ComponentSpec (", supporting_name,
" and ", registrar->name(), "): ", component_spec.ShortDebugString());
}
supporting_name = registrar->name();
}
if (supporting_name.empty()) {
return tensorflow::errors::NotFound(
"No SequenceExtractor supports channel ", channel.ShortDebugString(),
" of ComponentSpec: ", component_spec.ShortDebugString());
}
// Success; make modifications.
*name = supporting_name;
return tensorflow::Status::OK();
}
tensorflow::Status SequenceExtractor::New(
const string &name, const FixedFeatureChannel &channel,
const ComponentSpec &component_spec,
std::unique_ptr<SequenceExtractor> *extractor) {
std::unique_ptr<SequenceExtractor> matching_extractor;
TF_RETURN_IF_ERROR(
SequenceExtractor::CreateOrError(name, &matching_extractor));
TF_RETURN_IF_ERROR(matching_extractor->Initialize(channel, component_spec));
// Success; make modifications.
*extractor = std::move(matching_extractor);
return tensorflow::Status::OK();
}
} // namespace runtime
} // namespace dragnn
REGISTER_SYNTAXNET_CLASS_REGISTRY("DRAGNN Runtime Sequence Extractor",
dragnn::runtime::SequenceExtractor);
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_SEQUENCE_EXTRACTOR_H_
#define DRAGNN_RUNTIME_SEQUENCE_EXTRACTOR_H_
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/protos/spec.pb.h"
#include "syntaxnet/base.h"
#include "syntaxnet/registry.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Interface for feature extraction for sequence inputs.
//
// This extractor can be used to avoid ComputeSession overhead in simple cases;
// for example, extracting a sequence of character or word IDs for an LSTM.
class SequenceExtractor : public RegisterableClass<SequenceExtractor> {
public:
// Sets |extractor| to an instance of the subclass named |name| initialized
// from the |channel| of the |component_spec|. On error, returns non-OK and
// modifies nothing.
static tensorflow::Status New(const string &name,
const FixedFeatureChannel &channel,
const ComponentSpec &component_spec,
std::unique_ptr<SequenceExtractor> *extractor);
SequenceExtractor(const SequenceExtractor &) = delete;
SequenceExtractor &operator=(const SequenceExtractor &) = delete;
virtual ~SequenceExtractor() = default;
// Sets |name| to the registered name of the SequenceExtractor that supports
// the |channel| of the |component_spec|. On error, returns non-OK and
// modifies nothing. The returned statuses include:
// * OK: If a supporting SequenceExtractor was found.
// * INTERNAL: If an error occurred while searching for a compatible match.
// * NOT_FOUND: If the search was error-free, but no compatible match was
// found.
static tensorflow::Status Select(const FixedFeatureChannel &channel,
const ComponentSpec &component_spec,
string *name);
// Overwrites |ids| with the sequence of features extracted from the |input|.
// On error, returns non-OK.
virtual tensorflow::Status GetIds(InputBatchCache *input,
std::vector<int32> *ids) const = 0;
protected:
SequenceExtractor() = default;
private:
// Helps prevent use of the Create() method; use New() instead.
using RegisterableClass<SequenceExtractor>::Create;
// Returns true if this supports the |channel| of the |component_spec|.
// Implementations must coordinate to ensure that at most one supports any
// given |component_spec|.
virtual bool Supports(const FixedFeatureChannel &channel,
const ComponentSpec &component_spec) const = 0;
// Initializes this from the |channel| of the |component_spec|. On error,
// returns non-OK.
virtual tensorflow::Status Initialize(
const FixedFeatureChannel &channel,
const ComponentSpec &component_spec) = 0;
};
} // namespace runtime
} // namespace dragnn
DECLARE_SYNTAXNET_CLASS_REGISTRY("DRAGNN Runtime Sequence Extractor",
dragnn::runtime::SequenceExtractor);
} // namespace syntaxnet
#define DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR(subclass) \
REGISTER_SYNTAXNET_CLASS_COMPONENT( \
::syntaxnet::dragnn::runtime::SequenceExtractor, #subclass, subclass)
#endif // DRAGNN_RUNTIME_SEQUENCE_EXTRACTOR_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_extractor.h"
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Supports components named "success" and initializes successfully.
class Success : public SequenceExtractor {
public:
// Implements SequenceExtractor.
bool Supports(const FixedFeatureChannel &,
const ComponentSpec &component_spec) const override {
return component_spec.name() == "success";
}
tensorflow::Status Initialize(const FixedFeatureChannel &,
const ComponentSpec &) override {
return tensorflow::Status::OK();
}
tensorflow::Status GetIds(InputBatchCache *,
std::vector<int32> *) const override {
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR(Success);
// Supports components named "failure" and fails to initialize.
class Failure : public SequenceExtractor {
public:
// Implements SequenceExtractor.
bool Supports(const FixedFeatureChannel &,
const ComponentSpec &component_spec) const override {
return component_spec.name() == "failure";
}
tensorflow::Status Initialize(const FixedFeatureChannel &,
const ComponentSpec &) override {
return tensorflow::errors::Internal("Boom!");
}
tensorflow::Status GetIds(InputBatchCache *,
std::vector<int32> *) const override {
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR(Failure);
// Supports components named "duplicate" and initializes successfully.
class Duplicate : public SequenceExtractor {
public:
// Implements SequenceExtractor.
bool Supports(const FixedFeatureChannel &,
const ComponentSpec &component_spec) const override {
return component_spec.name() == "duplicate";
}
tensorflow::Status Initialize(const FixedFeatureChannel &,
const ComponentSpec &) override {
return tensorflow::Status::OK();
}
tensorflow::Status GetIds(InputBatchCache *,
std::vector<int32> *) const override {
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR(Duplicate);
// Duplicate of the above.
using Duplicate2 = Duplicate;
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR(Duplicate2);
// Tests that a component can be successfully created.
TEST(SequenceExtractorTest, Success) {
string name;
std::unique_ptr<SequenceExtractor> extractor;
ComponentSpec component_spec;
component_spec.set_name("success");
TF_ASSERT_OK(SequenceExtractor::Select({}, component_spec, &name));
ASSERT_EQ(name, "Success");
TF_EXPECT_OK(SequenceExtractor::New(name, {}, component_spec, &extractor));
EXPECT_NE(extractor, nullptr);
}
// Tests that errors in Initialize() are reported.
TEST(SequenceExtractorTest, FailToInitialize) {
string name;
std::unique_ptr<SequenceExtractor> extractor;
ComponentSpec component_spec;
component_spec.set_name("failure");
TF_ASSERT_OK(SequenceExtractor::Select({}, component_spec, &name));
EXPECT_EQ(name, "Failure");
EXPECT_THAT(SequenceExtractor::New(name, {}, component_spec, &extractor),
test::IsErrorWithSubstr("Boom!"));
EXPECT_EQ(extractor, nullptr);
}
// Tests that unsupported specs are reported as NOT_FOUND errors.
TEST(SequenceExtractorTest, UnsupportedSpec) {
string name = "not overwritten";
ComponentSpec component_spec;
component_spec.set_name("unsupported");
EXPECT_THAT(SequenceExtractor::Select({}, component_spec, &name),
test::IsErrorWithCodeAndSubstr(
tensorflow::error::NOT_FOUND,
"No SequenceExtractor supports channel"));
EXPECT_EQ(name, "not overwritten");
}
// Tests that unsupported subclass names are reported as errors.
TEST(SequenceExtractorTest, UnsupportedSubclass) {
std::unique_ptr<SequenceExtractor> extractor;
ComponentSpec component_spec;
EXPECT_THAT(
SequenceExtractor::New("Unsupported", {}, component_spec, &extractor),
test::IsErrorWithSubstr("Unknown DRAGNN Runtime Sequence Extractor"));
EXPECT_EQ(extractor, nullptr);
}
// Tests that multiple supporting extractors are reported as INTERNAL errors.
TEST(SequenceExtractorTest, Duplicate) {
string name = "not overwritten";
ComponentSpec component_spec;
component_spec.set_name("duplicate");
EXPECT_THAT(SequenceExtractor::Select({}, component_spec, &name),
test::IsErrorWithCodeAndSubstr(
tensorflow::error::INTERNAL,
"Multiple SequenceExtractors support channel"));
EXPECT_EQ(name, "not overwritten");
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_features.h"
#include <utility>
#include "tensorflow/core/lib/core/errors.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
tensorflow::Status SequenceFeatureManager::Reset(
const FixedEmbeddingManager *fixed_embedding_manager,
const ComponentSpec &component_spec,
const std::vector<string> &sequence_extractor_types) {
const size_t num_channels = fixed_embedding_manager->channel_configs_.size();
if (component_spec.fixed_feature_size() != num_channels) {
return tensorflow::errors::InvalidArgument(
"Channel mismatch between FixedEmbeddingManager (", num_channels,
") and ComponentSpec (", component_spec.fixed_feature_size(), ")");
}
if (sequence_extractor_types.size() != num_channels) {
return tensorflow::errors::InvalidArgument(
"Channel mismatch between FixedEmbeddingManager (", num_channels,
") and SequenceExtractors (", sequence_extractor_types.size(), ")");
}
for (const FixedFeatureChannel &channel : component_spec.fixed_feature()) {
if (channel.size() > 1) {
return tensorflow::errors::InvalidArgument(
"Multi-embedding fixed features are not supported for channel: ",
channel.ShortDebugString());
}
}
std::vector<ChannelConfig> local_configs; // avoid modification on error
for (size_t channel_id = 0; channel_id < num_channels; ++channel_id) {
local_configs.emplace_back();
ChannelConfig &channel_config = local_configs.back();
const FixedEmbeddingManager::ChannelConfig &wrapped_config =
fixed_embedding_manager->channel_configs_[channel_id];
channel_config.is_embedded = wrapped_config.is_embedded;
channel_config.embedding_matrix = wrapped_config.embedding_matrix;
TF_RETURN_IF_ERROR(
SequenceExtractor::New(sequence_extractor_types[channel_id],
component_spec.fixed_feature(channel_id),
component_spec, &channel_config.extractor));
}
// Success; make modifications.
zeros_ = fixed_embedding_manager->zeros_.view();
channel_configs_ = std::move(local_configs);
return tensorflow::Status::OK();
}
tensorflow::Status SequenceFeatures::Reset(
const SequenceFeatureManager *manager, InputBatchCache *input) {
manager_ = manager;
zeros_ = manager->zeros_;
num_channels_ = manager->channel_configs_.size();
num_steps_ = 0;
// Make sure |channels_| is big enough. Note that |channels_| never shrinks,
// so the Channel.ids sub-vector is never deallocated.
if (num_channels_ > channels_.size()) channels_.resize(num_channels_);
for (int channel_id = 0; channel_id < num_channels_; ++channel_id) {
Channel &channel = channels_[channel_id];
const SequenceFeatureManager::ChannelConfig &channel_config =
manager->channel_configs_[channel_id];
channel.embedding_matrix = channel_config.embedding_matrix;
TF_RETURN_IF_ERROR(channel_config.extractor->GetIds(input, &channel.ids));
if (channel_id == 0) {
num_steps_ = channel.ids.size();
} else if (channel.ids.size() != num_steps_) {
return tensorflow::errors::FailedPrecondition(
"Inconsistent feature sequence lengths at channel ID ", channel_id,
": got ", channel.ids.size(), " but expected ", num_steps_);
}
}
return tensorflow::Status::OK();
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Utils for configuring and extracting fixed embeddings for sequence-based
// models. Analogous to FixedEmbeddingManager and FixedEmbeddings, but uses
// SequenceExtractor instead of ComputeSession.
#ifndef DRAGNN_RUNTIME_SEQUENCE_FEATURES_H_
#define DRAGNN_RUNTIME_SEQUENCE_FEATURES_H_
#include <stddef.h>
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/fixed_embeddings.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Manager for fixed embeddings for sequence-based models. This is a wrapper
// around the FixedEmbeddingManager.
class SequenceFeatureManager {
public:
// Creates an empty manager.
SequenceFeatureManager() = default;
// Resets this to wrap the |fixed_embedding_manager|, which must outlive this.
// The |sequence_extractor_types| should name one SequenceExtractor subclass
// per channel; e.g., "SyntaxNetCharacterSequenceExtractor". This initializes
// each SequenceExtractor from the |component_spec|. On error, returns non-OK
// and does not modify this.
tensorflow::Status Reset(
const FixedEmbeddingManager *fixed_embedding_manager,
const ComponentSpec &component_spec,
const std::vector<string> &sequence_extractor_types);
// Accessors.
size_t num_channels() const { return channel_configs_.size(); }
private:
friend class SequenceFeatures;
// Configuration for a single fixed embedding channel.
struct ChannelConfig {
// Whether this channel is embedded.
bool is_embedded = true;
// Embedding matrix of this channel. Only used if |is_embedded| is true.
Matrix<float> embedding_matrix;
// Extractor for sequences of feature IDs.
std::unique_ptr<SequenceExtractor> extractor;
};
// Array of zeros that can be substituted for missing feature IDs. This is a
// reference to the corresponding array in the FixedEmbeddingManager.
AlignedView zeros_;
// Ordered list of configurations for each channel.
std::vector<ChannelConfig> channel_configs_;
};
// A set of fixed embeddings for a sequence-based model. Configured by a
// SequenceFeatureManager.
class SequenceFeatures {
public:
// Creates an empty set of embeddings.
SequenceFeatures() = default;
// Resets this to the sequences of fixed features managed by the |manager| on
// the |input|. The |manager| must live until this is destroyed or Reset(),
// and should not be modified during that time. On error, returns non-OK.
tensorflow::Status Reset(const SequenceFeatureManager *manager,
InputBatchCache *input);
// Returns the feature ID or embedding for the |target_index|'th element of
// the |channel_id|'th channel. Each method is only valid for a non-embedded
// or embedded channel, respectively.
int32 GetId(size_t channel_id, size_t target_index) const;
Vector<float> GetEmbedding(size_t channel_id, size_t target_index) const;
// Accessors.
size_t num_channels() const { return num_channels_; }
size_t num_steps() const { return num_steps_; }
private:
// Data associated with a single fixed embedding channel.
struct Channel {
// Embedding matrix of this channel. Only used for embedded channels.
Matrix<float> embedding_matrix;
// Feature IDs for each step.
std::vector<int32> ids;
};
// Manager from the most recent Reset().
const SequenceFeatureManager *manager_ = nullptr;
// Zero vector from the most recent Reset().
AlignedView zeros_;
// Number of channels and steps from the most recent Reset().
size_t num_channels_ = 0;
size_t num_steps_ = 0;
// Ordered list of fixed embedding channels. This may contain more than
// |num_channels_| entries, to avoid deallocation/reallocation cycles, but
// only the first |num_channels_| entries are valid.
std::vector<Channel> channels_;
};
// Implementation details below.
inline int32 SequenceFeatures::GetId(size_t channel_id,
size_t target_index) const {
DCHECK_LT(channel_id, num_channels());
DCHECK_LT(target_index, num_steps());
DCHECK(!manager_->channel_configs_[channel_id].is_embedded);
const Channel &channel = channels_[channel_id];
return channel.ids[target_index];
}
inline Vector<float> SequenceFeatures::GetEmbedding(size_t channel_id,
size_t target_index) const {
DCHECK_LT(channel_id, num_channels());
DCHECK_LT(target_index, num_steps());
DCHECK(manager_->channel_configs_[channel_id].is_embedded);
const Channel &channel = channels_[channel_id];
const int32 id = channel.ids[target_index];
return id < 0 ? Vector<float>(zeros_, channel.embedding_matrix.num_columns())
: channel.embedding_matrix.row(id);
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_SEQUENCE_FEATURES_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_features.h"
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/fixed_embeddings.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Number of transition steps to take in each component in the network.
const size_t kNumSteps = 10;
// A working one-channel ComponentSpec. This is intentionally identical to the
// first channel of |kMultiSpec|, so they can use the same embedding matrix.
const char kSingleSpec[] = R"(fixed_feature {
vocabulary_size: 13
embedding_dim: 11
size: 1
})";
const size_t kSingleRows = 13;
const size_t kSingleColumns = 11;
constexpr float kSingleValue = 1.25;
// A working multi-channel ComponentSpec.
const char kMultiSpec[] = R"(fixed_feature {
vocabulary_size: 13
embedding_dim: 11
size: 1
}
fixed_feature {
embedding_dim: -1
size: 1
}
fixed_feature {
embedding_dim: -1
size: 1
})";
// Fails to initialize.
class FailToInitialize : public SequenceExtractor {
public:
// Implements SequenceExtractor.
bool Supports(const FixedFeatureChannel &,
const ComponentSpec &component_spec) const override {
LOG(FATAL) << "Should never be called.";
}
tensorflow::Status Initialize(const FixedFeatureChannel &,
const ComponentSpec &) override {
return tensorflow::errors::Internal("No initialization for you!");
}
tensorflow::Status GetIds(InputBatchCache *,
std::vector<int32> *) const override {
LOG(FATAL) << "Should never be called.";
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR(FailToInitialize);
// Initializes OK, then fails to extract features.
class FailToGetIds : public FailToInitialize {
public:
// Implements SequenceExtractor.
tensorflow::Status Initialize(const FixedFeatureChannel &,
const ComponentSpec &) override {
return tensorflow::Status::OK();
}
tensorflow::Status GetIds(InputBatchCache *,
std::vector<int32> *) const override {
return tensorflow::errors::Internal("No features for you!");
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR(FailToGetIds);
// Initializes OK and extracts the previous step.
class ExtractPrevious : public FailToGetIds {
public:
// Implements SequenceExtractor.
tensorflow::Status GetIds(InputBatchCache *,
std::vector<int32> *ids) const override {
ids->resize(kNumSteps);
for (int i = 0; i < kNumSteps; ++i) (*ids)[i] = i - 1;
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR(ExtractPrevious);
// Initializes OK but produces the wrong number of features.
class WrongNumberOfIds : public FailToGetIds {
public:
// Implements SequenceExtractor.
tensorflow::Status GetIds(InputBatchCache *input,
std::vector<int32> *ids) const override {
ids->resize(kNumSteps + 1);
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR(WrongNumberOfIds);
class SequenceFeatureManagerTest : public NetworkTestBase {
protected:
// Creates a SequenceFeatureManager and returns the result of Reset()-ing it
// using the |component_spec_text|.
tensorflow::Status ResetManager(
const string &component_spec_text,
const std::vector<string> &sequence_extractor_types) {
ComponentSpec component_spec;
CHECK(TextFormat::ParseFromString(component_spec_text, &component_spec));
component_spec.set_name(kTestComponentName);
AddFixedEmbeddingMatrix(0, kSingleRows, kSingleColumns, kSingleValue);
AddComponent(kTestComponentName);
TF_RETURN_IF_ERROR(fixed_embedding_manager_.Reset(
component_spec, &variable_store_, &network_state_manager_));
return manager_.Reset(&fixed_embedding_manager_, component_spec,
sequence_extractor_types);
}
FixedEmbeddingManager fixed_embedding_manager_;
SequenceFeatureManager manager_;
};
// Tests that SequenceFeatureManager is empty by default.
TEST_F(SequenceFeatureManagerTest, EmptyByDefault) {
EXPECT_EQ(manager_.num_channels(), 0);
}
// Tests that SequenceFeatureManager is empty when reset to an empty spec.
TEST_F(SequenceFeatureManagerTest, EmptySpec) {
TF_EXPECT_OK(ResetManager("", {}));
EXPECT_EQ(manager_.num_channels(), 0);
}
// Tests that SequenceFeatureManager works with a single channel.
TEST_F(SequenceFeatureManagerTest, OneChannel) {
TF_EXPECT_OK(ResetManager(kSingleSpec, {"ExtractPrevious"}));
EXPECT_EQ(manager_.num_channels(), 1);
}
// Tests that SequenceFeatureManager works with multiple channels.
TEST_F(SequenceFeatureManagerTest, MultipleChannels) {
TF_EXPECT_OK(ResetManager(
kMultiSpec, {"ExtractPrevious", "ExtractPrevious", "ExtractPrevious"}));
EXPECT_EQ(manager_.num_channels(), 3);
}
// Tests that SequenceFeatureManager fails if the FixedEmbeddingManager and
// ComponentSpec are mismatched.
TEST_F(SequenceFeatureManagerTest, MismatchedFixedManagerAndComponentSpec) {
ComponentSpec component_spec;
CHECK(TextFormat::ParseFromString(kMultiSpec, &component_spec));
component_spec.set_name(kTestComponentName);
AddFixedEmbeddingMatrix(0, kSingleRows, kSingleColumns, kSingleValue);
AddComponent(kTestComponentName);
TF_ASSERT_OK(fixed_embedding_manager_.Reset(component_spec, &variable_store_,
&network_state_manager_));
// Remove one fixed feature, resulting in a mismatch.
component_spec.mutable_fixed_feature()->RemoveLast();
EXPECT_THAT(
manager_.Reset(&fixed_embedding_manager_, component_spec,
{"ExtractPrevious", "ExtractPrevious", "ExtractPrevious"}),
test::IsErrorWithSubstr("Channel mismatch between FixedEmbeddingManager "
"(3) and ComponentSpec (2)"));
}
// Tests that SequenceFeatureManager fails if the FixedEmbeddingManager and
// SequenceExtractors are mismatched.
TEST_F(SequenceFeatureManagerTest,
MismatchedFixedManagerAndSequenceExtractors) {
EXPECT_THAT(
ResetManager(kMultiSpec, {"ExtractPrevious", "ExtractPrevious"}),
test::IsErrorWithSubstr("Channel mismatch between FixedEmbeddingManager "
"(3) and SequenceExtractors (2)"));
}
// Tests that SequenceFeatureManager fails if a channel has multiple embeddings.
TEST_F(SequenceFeatureManagerTest, UnsupportedMultiEmbeddingChannel) {
const string kBadSpec = R"(fixed_feature {
vocabulary_size: 13
embedding_dim: 11
size: 2 # bad
})";
EXPECT_THAT(ResetManager(kBadSpec, {"ExtractPrevious"}),
test::IsErrorWithSubstr(
"Multi-embedding fixed features are not supported"));
}
// Tests that SequenceFeatureManager fails if one of the SequenceExtractors
// fails to initialize.
TEST_F(SequenceFeatureManagerTest, FailToInitializeSequenceExtractor) {
EXPECT_THAT(ResetManager(kMultiSpec, {"ExtractPrevious", "FailToInitialize",
"ExtractPrevious"}),
test::IsErrorWithSubstr("No initialization for you!"));
}
// Tests that SequenceFeatureManager is OK even if the SequenceExtractors would
// fail in GetIds().
TEST_F(SequenceFeatureManagerTest, ManagerDoesntCareAboutGetIds) {
TF_EXPECT_OK(ResetManager(
kMultiSpec, {"FailToGetIds", "FailToGetIds", "FailToGetIds"}));
}
class SequenceFeaturesTest : public SequenceFeatureManagerTest {
protected:
// Resets the |sequence_features_| on the |manager_| and |input_batch_cache_|
// and returns the resulting status.
tensorflow::Status ResetFeatures() {
return sequence_features_.Reset(&manager_, &input_batch_cache_);
}
InputBatchCache input_batch_cache_;
SequenceFeatures sequence_features_;
};
// Tests that SequenceFeatures is empty by default.
TEST_F(SequenceFeaturesTest, EmptyByDefault) {
EXPECT_EQ(sequence_features_.num_channels(), 0);
EXPECT_EQ(sequence_features_.num_steps(), 0);
}
// Tests that SequenceFeatures is empty when reset by an empty manager.
TEST_F(SequenceFeaturesTest, EmptyManager) {
TF_ASSERT_OK(ResetManager("", {}));
TF_EXPECT_OK(ResetFeatures());
EXPECT_EQ(sequence_features_.num_channels(), 0);
EXPECT_EQ(sequence_features_.num_steps(), 0);
}
// Tests that SequenceFeatures fails when one of the SequenceExtractors fails.
TEST_F(SequenceFeaturesTest, FailToGetIds) {
TF_ASSERT_OK(ResetManager(
kMultiSpec, {"ExtractPrevious", "ExtractPrevious", "FailToGetIds"}));
EXPECT_THAT(ResetFeatures(), test::IsErrorWithSubstr("No features for you!"));
}
// Tests that SequenceFeatures fails when the SequenceExtractors produce
// different numbers of features.
TEST_F(SequenceFeaturesTest, MismatchedNumbersOfFeatures) {
TF_ASSERT_OK(ResetManager(
kMultiSpec, {"ExtractPrevious", "ExtractPrevious", "WrongNumberOfIds"}));
EXPECT_THAT(ResetFeatures(), test::IsErrorWithSubstr(
"Inconsistent feature sequence lengths at "
"channel ID 2: got 11 but expected 10"));
}
// Tests that SequenceFeatures works as expected on one channel.
TEST_F(SequenceFeaturesTest, SingleChannel) {
TF_ASSERT_OK(ResetManager(kSingleSpec, {"ExtractPrevious"}));
TF_ASSERT_OK(ResetFeatures());
ASSERT_EQ(sequence_features_.num_channels(), 1);
ASSERT_EQ(sequence_features_.num_steps(), kNumSteps);
// ExtractPrevious extracts -1 for the 0'th target index, which indicates a
// missing ID and should be mapped to a zero vector.
ExpectVector(sequence_features_.GetEmbedding(0, 0), kSingleColumns, 0.0);
EXPECT_DEBUG_DEATH(sequence_features_.GetId(0, 0), "is_embedded");
// The remaining feature IDs map to valid embedding rows.
for (int i = 1; i < kNumSteps; ++i) {
ExpectVector(sequence_features_.GetEmbedding(0, i), kSingleColumns,
kSingleValue);
EXPECT_DEBUG_DEATH(sequence_features_.GetId(0, i), "is_embedded");
}
}
// Tests that SequenceFeatures works as expected on multiple channels.
TEST_F(SequenceFeaturesTest, ManyChannels) {
TF_ASSERT_OK(ResetManager(
kMultiSpec, {"ExtractPrevious", "ExtractPrevious", "ExtractPrevious"}));
TF_ASSERT_OK(ResetFeatures());
ASSERT_EQ(sequence_features_.num_channels(), 3);
ASSERT_EQ(sequence_features_.num_steps(), kNumSteps);
// ExtractPrevious extracts -1 for the 0'th target index, which indicates a
// missing ID and should be mapped to a zero vector.
ExpectVector(sequence_features_.GetEmbedding(0, 0), kSingleColumns, 0.0);
EXPECT_EQ(sequence_features_.GetId(1, 0), -1);
EXPECT_EQ(sequence_features_.GetId(2, 0), -1);
EXPECT_DEBUG_DEATH(sequence_features_.GetId(0, 0), "is_embedded");
EXPECT_DEBUG_DEATH(sequence_features_.GetEmbedding(1, 0), "is_embedded");
EXPECT_DEBUG_DEATH(sequence_features_.GetEmbedding(2, 0), "is_embedded");
// The remaining features point to the previous item.
for (int i = 1; i < kNumSteps; ++i) {
ExpectVector(sequence_features_.GetEmbedding(0, i), kSingleColumns,
kSingleValue);
EXPECT_EQ(sequence_features_.GetId(1, i), i - 1);
EXPECT_EQ(sequence_features_.GetId(2, i), i - 1);
EXPECT_DEBUG_DEATH(sequence_features_.GetId(0, i), "is_embedded");
EXPECT_DEBUG_DEATH(sequence_features_.GetEmbedding(1, i), "is_embedded");
EXPECT_DEBUG_DEATH(sequence_features_.GetEmbedding(2, i), "is_embedded");
}
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_linker.h"
#include "tensorflow/core/lib/core/errors.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
tensorflow::Status SequenceLinker::Select(const LinkedFeatureChannel &channel,
const ComponentSpec &component_spec,
string *name) {
string supporting_name;
for (const Registry::Registrar *registrar = registry()->components;
registrar != nullptr; registrar = registrar->next()) {
Factory *factory_function = registrar->object();
std::unique_ptr<SequenceLinker> current_linker(factory_function());
if (!current_linker->Supports(channel, component_spec)) continue;
if (!supporting_name.empty()) {
return tensorflow::errors::Internal(
"Multiple SequenceLinkers support channel ",
channel.ShortDebugString(), " of ComponentSpec (", supporting_name,
" and ", registrar->name(), "): ", component_spec.ShortDebugString());
}
supporting_name = registrar->name();
}
if (supporting_name.empty()) {
return tensorflow::errors::NotFound(
"No SequenceLinker supports channel ", channel.ShortDebugString(),
" of ComponentSpec: ", component_spec.ShortDebugString());
}
// Success; make modifications.
*name = supporting_name;
return tensorflow::Status::OK();
}
tensorflow::Status SequenceLinker::New(
const string &name, const LinkedFeatureChannel &channel,
const ComponentSpec &component_spec,
std::unique_ptr<SequenceLinker> *linker) {
std::unique_ptr<SequenceLinker> matching_linker;
TF_RETURN_IF_ERROR(SequenceLinker::CreateOrError(name, &matching_linker));
TF_RETURN_IF_ERROR(matching_linker->Initialize(channel, component_spec));
// Success; make modifications.
*linker = std::move(matching_linker);
return tensorflow::Status::OK();
}
} // namespace runtime
} // namespace dragnn
REGISTER_SYNTAXNET_CLASS_REGISTRY("DRAGNN Runtime Sequence Linker",
dragnn::runtime::SequenceLinker);
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_SEQUENCE_LINKER_H_
#define DRAGNN_RUNTIME_SEQUENCE_LINKER_H_
#include <stddef.h>
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/protos/spec.pb.h"
#include "syntaxnet/base.h"
#include "syntaxnet/registry.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Interface for link extraction for sequence inputs.
//
// This can be used to avoid ComputeSession overhead in simple cases; for
// example, extracting a sequence of identity or reverse-identity links.
class SequenceLinker : public RegisterableClass<SequenceLinker> {
public:
// Sets |linker| to an instance of the subclass named |name| initialized from
// the |channel| of the |component_spec|. On error, returns non-OK and
// modifies nothing.
static tensorflow::Status New(const string &name,
const LinkedFeatureChannel &channel,
const ComponentSpec &component_spec,
std::unique_ptr<SequenceLinker> *linker);
SequenceLinker(const SequenceLinker &) = delete;
SequenceLinker &operator=(const SequenceLinker &) = delete;
virtual ~SequenceLinker() = default;
// Sets |name| to the registered name of the SequenceLinker that supports the
// |channel| of the |component_spec|. On error, returns non-OK and modifies
// nothing. The returned statuses include:
// * OK: If a supporting SequenceLinker was found.
// * INTERNAL: If an error occurred while searching for a compatible match.
// * NOT_FOUND: If the search was error-free, but no compatible match was
// found.
static tensorflow::Status Select(const LinkedFeatureChannel &channel,
const ComponentSpec &component_spec,
string *name);
// Overwrites |links| with the sequence of translated link step indices for
// the |input|. Specifically, sets links[i] to the (possibly out-of-bounds)
// step index to fetch from the source component for the i'th element of the
// target sequence. Assumes that |source_num_steps| is the number of steps
// taken by the source component. On error, returns non-OK.
virtual tensorflow::Status GetLinks(size_t source_num_steps,
InputBatchCache *input,
std::vector<int32> *links) const = 0;
protected:
SequenceLinker() = default;
private:
// Helps prevent use of the Create() method; use New() instead.
using RegisterableClass<SequenceLinker>::Create;
// Returns true if this supports the |channel| of the |component_spec|.
// Implementations must coordinate to ensure that at most one supports any
// given |component_spec|.
virtual bool Supports(const LinkedFeatureChannel &channel,
const ComponentSpec &component_spec) const = 0;
// Initializes this from the |channel| of the |component_spec|. On error,
// returns non-OK.
virtual tensorflow::Status Initialize(
const LinkedFeatureChannel &channel,
const ComponentSpec &component_spec) = 0;
};
} // namespace runtime
} // namespace dragnn
DECLARE_SYNTAXNET_CLASS_REGISTRY("DRAGNN Runtime Sequence Linker",
dragnn::runtime::SequenceLinker);
} // namespace syntaxnet
#define DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(subclass) \
REGISTER_SYNTAXNET_CLASS_COMPONENT( \
::syntaxnet::dragnn::runtime::SequenceLinker, #subclass, subclass)
#endif // DRAGNN_RUNTIME_SEQUENCE_LINKER_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_linker.h"
#include <stddef.h>
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Supports components named "success" and initializes successfully.
class Success : public SequenceLinker {
public:
// Implements SequenceLinker.
bool Supports(const LinkedFeatureChannel &,
const ComponentSpec &component_spec) const override {
return component_spec.name() == "success";
}
tensorflow::Status Initialize(const LinkedFeatureChannel &,
const ComponentSpec &) override {
return tensorflow::Status::OK();
}
tensorflow::Status GetLinks(size_t, InputBatchCache *,
std::vector<int32> *) const override {
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(Success);
// Supports components named "failure" and fails to initialize.
class Failure : public SequenceLinker {
public:
// Implements SequenceLinker.
bool Supports(const LinkedFeatureChannel &,
const ComponentSpec &component_spec) const override {
return component_spec.name() == "failure";
}
tensorflow::Status Initialize(const LinkedFeatureChannel &,
const ComponentSpec &) override {
return tensorflow::errors::Internal("Boom!");
}
tensorflow::Status GetLinks(size_t, InputBatchCache *,
std::vector<int32> *) const override {
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(Failure);
// Supports components named "duplicate" and initializes successfully.
class Duplicate : public SequenceLinker {
public:
// Implements SequenceLinker.
bool Supports(const LinkedFeatureChannel &,
const ComponentSpec &component_spec) const override {
return component_spec.name() == "duplicate";
}
tensorflow::Status Initialize(const LinkedFeatureChannel &,
const ComponentSpec &) override {
return tensorflow::Status::OK();
}
tensorflow::Status GetLinks(size_t, InputBatchCache *,
std::vector<int32> *) const override {
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(Duplicate);
// Duplicate of the above.
using Duplicate2 = Duplicate;
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(Duplicate2);
// Tests that a component can be successfully created.
TEST(SequenceLinkerTest, Success) {
string name;
std::unique_ptr<SequenceLinker> linker;
ComponentSpec component_spec;
component_spec.set_name("success");
TF_ASSERT_OK(SequenceLinker::Select({}, component_spec, &name));
ASSERT_EQ(name, "Success");
TF_EXPECT_OK(SequenceLinker::New(name, {}, component_spec, &linker));
EXPECT_NE(linker, nullptr);
}
// Tests that errors in Initialize() are reported.
TEST(SequenceLinkerTest, FailToInitialize) {
string name;
std::unique_ptr<SequenceLinker> linker;
ComponentSpec component_spec;
component_spec.set_name("failure");
TF_ASSERT_OK(SequenceLinker::Select({}, component_spec, &name));
EXPECT_EQ(name, "Failure");
EXPECT_THAT(SequenceLinker::New(name, {}, component_spec, &linker),
test::IsErrorWithSubstr("Boom!"));
EXPECT_EQ(linker, nullptr);
}
// Tests that unsupported specs are reported as NOT_FOUND errors.
TEST(SequenceLinkerTest, UnsupportedSpec) {
string name = "not overwritten";
ComponentSpec component_spec;
component_spec.set_name("unsupported");
EXPECT_THAT(
SequenceLinker::Select({}, component_spec, &name),
test::IsErrorWithCodeAndSubstr(tensorflow::error::NOT_FOUND,
"No SequenceLinker supports channel"));
EXPECT_EQ(name, "not overwritten");
}
// Tests that unsupported subclass names are reported as errors.
TEST(SequenceLinkerTest, UnsupportedSubclass) {
std::unique_ptr<SequenceLinker> linker;
ComponentSpec component_spec;
EXPECT_THAT(
SequenceLinker::New("Unsupported", {}, component_spec, &linker),
test::IsErrorWithSubstr("Unknown DRAGNN Runtime Sequence Linker"));
EXPECT_EQ(linker, nullptr);
}
// Tests that multiple supporting linkers are reported as INTERNAL errors.
TEST(SequenceLinkerTest, Duplicate) {
string name = "not overwritten";
ComponentSpec component_spec;
component_spec.set_name("duplicate");
EXPECT_THAT(SequenceLinker::Select({}, component_spec, &name),
test::IsErrorWithCodeAndSubstr(
tensorflow::error::INTERNAL,
"Multiple SequenceLinkers support channel"));
EXPECT_EQ(name, "not overwritten");
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_links.h"
#include <utility>
#include "tensorflow/core/lib/core/errors.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
tensorflow::Status SequenceLinkManager::Reset(
const LinkedEmbeddingManager *linked_embedding_manager,
const ComponentSpec &component_spec,
const std::vector<string> &sequence_linker_types) {
const size_t num_channels = linked_embedding_manager->channel_configs_.size();
if (component_spec.linked_feature_size() != num_channels) {
return tensorflow::errors::InvalidArgument(
"Channel mismatch between LinkedEmbeddingManager (", num_channels,
") and ComponentSpec (", component_spec.linked_feature_size(), ")");
}
if (sequence_linker_types.size() != num_channels) {
return tensorflow::errors::InvalidArgument(
"Channel mismatch between LinkedEmbeddingManager (", num_channels,
") and SequenceLinkers (", sequence_linker_types.size(), ")");
}
for (const LinkedFeatureChannel &channel : component_spec.linked_feature()) {
if (channel.embedding_dim() >= 0) {
return tensorflow::errors::Unimplemented(
"Transformed linked features are not supported for channel: ",
channel.ShortDebugString());
}
}
std::vector<ChannelConfig> local_configs; // avoid modification on error
for (size_t channel_id = 0; channel_id < num_channels; ++channel_id) {
const LinkedFeatureChannel &channel =
component_spec.linked_feature(channel_id);
local_configs.emplace_back();
ChannelConfig &channel_config = local_configs.back();
channel_config.is_recurrent =
channel.source_component() == component_spec.name();
channel_config.handle =
linked_embedding_manager->channel_configs_[channel_id].source_handle;
TF_RETURN_IF_ERROR(
SequenceLinker::New(sequence_linker_types[channel_id],
component_spec.linked_feature(channel_id),
component_spec, &channel_config.linker));
}
// Success; make modifications.
zeros_ = linked_embedding_manager->zeros_.view();
channel_configs_ = std::move(local_configs);
return tensorflow::Status::OK();
}
tensorflow::Status SequenceLinks::Reset(bool add_steps,
const SequenceLinkManager *manager,
NetworkStates *network_states,
InputBatchCache *input) {
zeros_ = manager->zeros_;
num_channels_ = manager->channel_configs_.size();
num_steps_ = 0;
bool have_num_steps = false; // true if |num_steps_| was assigned
// Make sure |channels_| is big enough. Note that |channels_| never shrinks,
// so the Channel.links sub-vector is never deallocated.
if (num_channels_ > channels_.size()) channels_.resize(num_channels_);
// Process non-recurrent links first.
for (int channel_id = 0; channel_id < num_channels_; ++channel_id) {
const SequenceLinkManager::ChannelConfig &channel_config =
manager->channel_configs_[channel_id];
if (channel_config.is_recurrent) continue;
Channel &channel = channels_[channel_id];
channel.layer = network_states->GetLayer(channel_config.handle);
TF_RETURN_IF_ERROR(channel_config.linker->GetLinks(channel.layer.num_rows(),
input, &channel.links));
if (!have_num_steps) {
num_steps_ = channel.links.size();
have_num_steps = true;
} else if (channel.links.size() != num_steps_) {
return tensorflow::errors::FailedPrecondition(
"Inconsistent link sequence lengths at channel ID ", channel_id,
": got ", channel.links.size(), " but expected ", num_steps_);
}
}
// Add steps to the |network_states|, if requested.
if (add_steps) {
if (!have_num_steps) {
return tensorflow::errors::FailedPrecondition(
"Cannot infer the number of steps to add because there are no "
"non-recurrent links");
}
network_states->AddSteps(num_steps_);
}
// Process recurrent links. These require that the current component in the
// |network_states| has been sized to the proper number of steps.
for (int channel_id = 0; channel_id < num_channels_; ++channel_id) {
const SequenceLinkManager::ChannelConfig &channel_config =
manager->channel_configs_[channel_id];
if (!channel_config.is_recurrent) continue;
Channel &channel = channels_[channel_id];
channel.layer = network_states->GetLayer(channel_config.handle);
TF_RETURN_IF_ERROR(channel_config.linker->GetLinks(channel.layer.num_rows(),
input, &channel.links));
if (!have_num_steps) {
num_steps_ = channel.links.size();
have_num_steps = true;
} else if (channel.links.size() != num_steps_) {
return tensorflow::errors::FailedPrecondition(
"Inconsistent link sequence lengths at channel ID ", channel_id,
": got ", channel.links.size(), " but expected ", num_steps_);
}
}
return tensorflow::Status::OK();
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Utils for configuring and extracting linked embeddings for sequence-based
// models. Analogous to LinkedEmbeddingManager and LinkedEmbeddings, but uses
// SequenceLinker instead of ComputeSession.
#ifndef DRAGNN_RUNTIME_SEQUENCE_LINKS_H_
#define DRAGNN_RUNTIME_SEQUENCE_LINKS_H_
#include <stddef.h>
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/linked_embeddings.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/sequence_linker.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Manager for linked embeddings for sequence-based models. This is a wrapper
// around the LinkedEmbeddingManager.
class SequenceLinkManager {
public:
// Creates an empty manager.
SequenceLinkManager() = default;
// Resets this to wrap the |linked_embedding_manager|, which must outlive
// this. The |sequence_linker_types| should name one SequenceLinker subclass
// per channel; e.g., {"IdentitySequenceLinker", "ReversedSequenceLinker"}.
// This initializes each SequenceLinker from the |component_spec|. On error,
// returns non-OK and does not modify this.
tensorflow::Status Reset(
const LinkedEmbeddingManager *linked_embedding_manager,
const ComponentSpec &component_spec,
const std::vector<string> &sequence_linker_types);
// Accessors.
size_t num_channels() const { return channel_configs_.size(); }
private:
friend class SequenceLinks;
// Configuration for a single linked embedding channel.
struct ChannelConfig {
// Whether this link is recurrent.
bool is_recurrent = false;
// Handle to the source layer in the relevant NetworkStates.
LayerHandle<float> handle;
// Extractor for sequences of translated link indices.
std::unique_ptr<SequenceLinker> linker;
};
// Array of zeros that can be substituted for out-of-bounds embeddings. This
// is a reference to the corresponding array in the LinkedEmbeddingManager.
// See the large comment in linked_embeddings.cc for reference.
AlignedView zeros_;
// Ordered list of configurations for each channel.
std::vector<ChannelConfig> channel_configs_;
};
// A set of linked embeddings for a sequence-based model. Configured by a
// SequenceLinkManager.
class SequenceLinks {
public:
// Creates an empty set of embeddings.
SequenceLinks() = default;
// Resets this to the sequences of linked embeddings managed by the |manager|
// on the |input|. Retrieves layers from the |network_states|. The |manager|
// must live until this is destroyed or Reset(), and should not be modified
// during that time. If |add_steps| is true, then infers the number of steps
// from the non-recurrent links and adds steps to the |network_states| before
// processing the recurrent links. On error, returns non-OK.
//
// NB: Recurrent links are tricky, because the |network_states| must be filled
// with steps before processing recurrent links. There are two approaches:
// 1. Add steps to the |network_states| before calling Reset(). This only
// works if the component also has fixed features, which can be used to
// infer the number of steps.
// 2. Set |add_steps| to true, so steps are added during Reset(). This only
// works if the component also has non-recurrent links, which can be used
// to infer the number of steps.
// If a component only has recurrent links then neither of the above works,
// but such a component would be nonsensical: it recurses on itself with no
// external input.
tensorflow::Status Reset(bool add_steps, const SequenceLinkManager *manager,
NetworkStates *network_states,
InputBatchCache *input);
// Retrieves the linked embedding for the |target_index|'th element of the
// |channel_id|'th channel. Sets |embedding| to the linked embedding vector
// and sets |is_out_of_bounds| to true if the link is out of bounds.
void Get(size_t channel_id, size_t target_index, Vector<float> *embedding,
bool *is_out_of_bounds) const;
// Accessors.
size_t num_channels() const { return num_channels_; }
size_t num_steps() const { return num_steps_; }
private:
// Data associated with a single linked embedding channel.
struct Channel {
// Source layer activations.
Matrix<float> layer;
// Translated link indices for each step.
std::vector<int32> links;
};
// Zero vector from the most recent Reset().
AlignedView zeros_;
// Number of channels and steps from the most recent Reset().
size_t num_channels_ = 0;
size_t num_steps_ = 0;
// Ordered list of linked embedding channels. This may contain more than
// |num_channels_| entries, to avoid deallocation/reallocation cycles, but
// only the first |num_channels_| entries are valid.
std::vector<Channel> channels_;
};
// Implementation details below.
inline void SequenceLinks::Get(size_t channel_id, size_t target_index,
Vector<float> *embedding,
bool *is_out_of_bounds) const {
DCHECK_LT(channel_id, num_channels());
DCHECK_LT(target_index, num_steps());
const Channel &channel = channels_[channel_id];
const int32 link = channel.links[target_index];
*is_out_of_bounds = (link < 0 || link >= channel.layer.num_rows());
if (*is_out_of_bounds) {
*embedding = Vector<float>(zeros_, channel.layer.num_columns());
} else {
*embedding = channel.layer.row(link);
}
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_SEQUENCE_LINKS_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_links.h"
#include <stddef.h>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/linked_embeddings.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/sequence_linker.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Dimensions of the layers in the network (see ResetManager() below).
const size_t kPrevious1LayerDim = 16;
const size_t kPrevious2LayerDim = 32;
const size_t kRecurrentLayerDim = 48;
// Number of transition steps to take in each component in the network.
const size_t kNumSteps = 10;
// A working one-channel ComponentSpec.
const char kSingleSpec[] = R"(linked_feature {
embedding_dim: -1
source_component: 'source_component_1'
source_layer: 'previous_1'
size: 1
})";
// A working multi-channel ComponentSpec.
const char kMultiSpec[] = R"(linked_feature {
embedding_dim: -1
source_component: 'source_component_1'
source_layer: 'previous_1'
size: 1
}
linked_feature {
embedding_dim: -1
source_component: 'source_component_2'
source_layer: 'previous_2'
size: 1
}
linked_feature {
embedding_dim: -1
source_component: 'test_component'
source_layer: 'recurrent'
size: 1
})";
// A recurrent-only ComponentSpec.
const char kRecurrentSpec[] = R"(linked_feature {
embedding_dim: -1
source_component: 'test_component'
source_layer: 'recurrent'
size: 1
})";
// Fails to initialize.
class FailToInitialize : public SequenceLinker {
public:
// Implements SequenceLinker.
bool Supports(const LinkedFeatureChannel &,
const ComponentSpec &component_spec) const override {
LOG(FATAL) << "Should never be called.";
}
tensorflow::Status Initialize(const LinkedFeatureChannel &,
const ComponentSpec &) override {
return tensorflow::errors::Internal("No initialization for you!");
}
tensorflow::Status GetLinks(size_t, InputBatchCache *,
std::vector<int32> *) const override {
LOG(FATAL) << "Should never be called.";
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(FailToInitialize);
// Initializes OK, then fails to extract links.
class FailToGetLinks : public FailToInitialize {
public:
// Implements SequenceLinker.
tensorflow::Status Initialize(const LinkedFeatureChannel &,
const ComponentSpec &) override {
return tensorflow::Status::OK();
}
tensorflow::Status GetLinks(size_t, InputBatchCache *,
std::vector<int32> *) const override {
return tensorflow::errors::Internal("No links for you!");
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(FailToGetLinks);
// Initializes OK and links to the previous step.
class LinkToPrevious : public FailToGetLinks {
public:
// Implements SequenceLinker.
tensorflow::Status GetLinks(size_t source_num_steps, InputBatchCache *,
std::vector<int32> *links) const override {
links->resize(source_num_steps);
for (int i = 0; i < links->size(); ++i) (*links)[i] = i - 1;
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(LinkToPrevious);
// Initializes OK but produces the wrong number of links.
class WrongNumberOfLinks : public FailToGetLinks {
public:
// Implements SequenceLinker.
tensorflow::Status GetLinks(size_t, InputBatchCache *,
std::vector<int32> *links) const override {
links->resize(kNumSteps + 1);
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(WrongNumberOfLinks);
class SequenceLinkManagerTest : public NetworkTestBase {
protected:
// Sets up previous components and layers.
void AddComponentsAndLayers() {
AddComponent("source_component_0");
AddComponent("source_component_1");
AddLayer("previous_1", kPrevious1LayerDim);
AddComponent("source_component_2");
AddLayer("previous_2", kPrevious2LayerDim);
AddComponent(kTestComponentName);
AddLayer("recurrent", kRecurrentLayerDim);
}
// Creates a SequenceLinkManager and returns the result of Reset()-ing it
// using the |component_spec_text|.
tensorflow::Status ResetManager(
const string &component_spec_text,
const std::vector<string> &sequence_linker_types) {
ComponentSpec component_spec;
CHECK(TextFormat::ParseFromString(component_spec_text, &component_spec));
component_spec.set_name(kTestComponentName);
AddComponentsAndLayers();
TF_RETURN_IF_ERROR(linked_embedding_manager_.Reset(
component_spec, &variable_store_, &network_state_manager_));
return manager_.Reset(&linked_embedding_manager_, component_spec,
sequence_linker_types);
}
LinkedEmbeddingManager linked_embedding_manager_;
SequenceLinkManager manager_;
};
// Tests that SequenceLinkManager is empty by default.
TEST_F(SequenceLinkManagerTest, EmptyByDefault) {
EXPECT_EQ(manager_.num_channels(), 0);
}
// Tests that SequenceLinkManager is empty when reset to an empty spec.
TEST_F(SequenceLinkManagerTest, EmptySpec) {
TF_EXPECT_OK(ResetManager("", {}));
EXPECT_EQ(manager_.num_channels(), 0);
}
// Tests that SequenceLinkManager works with a single channel.
TEST_F(SequenceLinkManagerTest, OneChannel) {
TF_EXPECT_OK(ResetManager(kSingleSpec, {"LinkToPrevious"}));
EXPECT_EQ(manager_.num_channels(), 1);
}
// Tests that SequenceLinkManager works with multiple channels.
TEST_F(SequenceLinkManagerTest, MultipleChannels) {
TF_EXPECT_OK(ResetManager(
kMultiSpec, {"LinkToPrevious", "LinkToPrevious", "LinkToPrevious"}));
EXPECT_EQ(manager_.num_channels(), 3);
}
// Tests that SequenceLinkManager fails if the LinkedEmbeddingManager and
// ComponentSpec are mismatched.
TEST_F(SequenceLinkManagerTest, MismatchedLinkedManagerAndComponentSpec) {
ComponentSpec component_spec;
CHECK(TextFormat::ParseFromString(kMultiSpec, &component_spec));
component_spec.set_name(kTestComponentName);
AddComponentsAndLayers();
TF_ASSERT_OK(linked_embedding_manager_.Reset(component_spec, &variable_store_,
&network_state_manager_));
// Remove one linked feature, resulting in a mismatch.
component_spec.mutable_linked_feature()->RemoveLast();
EXPECT_THAT(
manager_.Reset(&linked_embedding_manager_, component_spec,
{"LinkToPrevious", "LinkToPrevious", "LinkToPrevious"}),
test::IsErrorWithSubstr("Channel mismatch between LinkedEmbeddingManager "
"(3) and ComponentSpec (2)"));
}
// Tests that SequenceLinkManager fails if the LinkedEmbeddingManager and
// SequenceLinkers are mismatched.
TEST_F(SequenceLinkManagerTest, MismatchedLinkedManagerAndSequenceLinkers) {
EXPECT_THAT(
ResetManager(kMultiSpec, {"LinkToPrevious", "LinkToPrevious"}),
test::IsErrorWithSubstr("Channel mismatch between LinkedEmbeddingManager "
"(3) and SequenceLinkers (2)"));
}
// Tests that SequenceLinkManager fails when the link is transformed.
TEST_F(SequenceLinkManagerTest, UnsupportedTransformedLink) {
const string kBadSpec = R"(linked_feature {
embedding_dim: 16 # bad
source_component: 'source_component_1'
source_layer: 'previous_1'
size: 1
})";
AddLinkedWeightMatrix(0, kPrevious1LayerDim, 16, 0.0);
AddLinkedOutOfBoundsVector(0, 16, 0.0);
EXPECT_THAT(
ResetManager(kBadSpec, {"LinkToPrevious"}),
test::IsErrorWithSubstr("Transformed linked features are not supported"));
}
// Tests that SequenceLinkManager fails if one of the SequenceLinkers fails to
// initialize.
TEST_F(SequenceLinkManagerTest, FailToInitializeSequenceLinker) {
EXPECT_THAT(ResetManager(kMultiSpec, {"LinkToPrevious", "FailToInitialize",
"LinkToPrevious"}),
test::IsErrorWithSubstr("No initialization for you!"));
}
// Tests that SequenceLinkManager is OK even if the SequenceLinkers would fail
// in GetLinks().
TEST_F(SequenceLinkManagerTest, ManagerDoesntCareAboutGetLinks) {
TF_EXPECT_OK(ResetManager(
kMultiSpec, {"FailToGetLinks", "FailToGetLinks", "FailToGetLinks"}));
}
// Values to fill each layer with.
const float kPrevious1LayerValue = 1.0;
const float kPrevious2LayerValue = 2.0;
const float kRecurrentLayerValue = 3.0;
class SequenceLinksTest : public SequenceLinkManagerTest {
protected:
// Resets the |sequence_links_| using the |manager_|, |network_states_|, and
// |input_batch_cache_|, and returns the resulting status. Passes |add_steps|
// to Reset() and advances the current component by |num_steps|.
tensorflow::Status ResetLinks(bool add_steps = false,
size_t num_steps = kNumSteps) {
network_states_.Reset(&network_state_manager_);
// Fill components with steps.
StartComponent(kNumSteps); // source_component_0
StartComponent(kNumSteps); // source_component_1
StartComponent(kNumSteps); // source_component_2
StartComponent(num_steps); // current component
// Fill layers with values.
FillLayer("source_component_1", "previous_1", kPrevious1LayerValue);
FillLayer("source_component_2", "previous_2", kPrevious2LayerValue);
FillLayer(kTestComponentName, "recurrent", kRecurrentLayerValue);
return sequence_links_.Reset(add_steps, &manager_, &network_states_,
&input_batch_cache_);
}
InputBatchCache input_batch_cache_;
SequenceLinks sequence_links_;
};
// Tests that SequenceLinks is empty by default.
TEST_F(SequenceLinksTest, EmptyByDefault) {
EXPECT_EQ(sequence_links_.num_channels(), 0);
EXPECT_EQ(sequence_links_.num_steps(), 0);
}
// Tests that SequenceLinks is empty when reset by an empty manager.
TEST_F(SequenceLinksTest, EmptyManager) {
TF_ASSERT_OK(ResetManager("", {}));
TF_EXPECT_OK(ResetLinks());
EXPECT_EQ(sequence_links_.num_channels(), 0);
EXPECT_EQ(sequence_links_.num_steps(), 0);
}
// Tests that SequenceLinks fails when one of the non-recurrent SequenceLinkers
// fails.
TEST_F(SequenceLinksTest, FailToGetNonRecurrentLinks) {
TF_ASSERT_OK(ResetManager(
kMultiSpec, {"LinkToPrevious", "FailToGetLinks", "LinkToPrevious"}));
EXPECT_THAT(ResetLinks(), test::IsErrorWithSubstr("No links for you!"));
}
// Tests that SequenceLinks fails when one of the recurrent SequenceLinkers
// fails.
TEST_F(SequenceLinksTest, FailToGetRecurrentLinks) {
TF_ASSERT_OK(ResetManager(
kMultiSpec, {"LinkToPrevious", "LinkToPrevious", "FailToGetLinks"}));
EXPECT_THAT(ResetLinks(), test::IsErrorWithSubstr("No links for you!"));
}
// Tests that SequenceLinks fails when the non-recurrent SequenceLinkers produce
// different numbers of links.
TEST_F(SequenceLinksTest, MismatchedNumbersOfNonRecurrentLinks) {
TF_ASSERT_OK(ResetManager(
kMultiSpec, {"LinkToPrevious", "WrongNumberOfLinks", "LinkToPrevious"}));
EXPECT_THAT(ResetLinks(),
test::IsErrorWithSubstr("Inconsistent link sequence lengths at "
"channel ID 1: got 11 but expected 10"));
}
// Tests that SequenceLinks fails when the recurrent SequenceLinkers produce
// different numbers of links.
TEST_F(SequenceLinksTest, MismatchedNumbersOfRecurrentLinks) {
TF_ASSERT_OK(ResetManager(
kMultiSpec, {"LinkToPrevious", "LinkToPrevious", "WrongNumberOfLinks"}));
EXPECT_THAT(ResetLinks(),
test::IsErrorWithSubstr("Inconsistent link sequence lengths at "
"channel ID 2: got 11 but expected 10"));
}
// Tests that SequenceLinks works as expected on one channel.
TEST_F(SequenceLinksTest, SingleChannel) {
TF_ASSERT_OK(ResetManager(kSingleSpec, {"LinkToPrevious"}));
TF_ASSERT_OK(ResetLinks());
ASSERT_EQ(sequence_links_.num_channels(), 1);
ASSERT_EQ(sequence_links_.num_steps(), kNumSteps);
const Matrix<float> previous1(GetLayer("source_component_1", "previous_1"));
Vector<float> embedding;
bool is_out_of_bounds = false;
// LinkToPrevious links the 0'th index to -1, which is out of bounds.
sequence_links_.Get(0, 0, &embedding, &is_out_of_bounds);
EXPECT_TRUE(is_out_of_bounds);
ExpectVector(embedding, kPrevious1LayerDim, 0.0);
// The remaining links point to the previous item.
for (int i = 1; i < kNumSteps; ++i) {
sequence_links_.Get(0, i, &embedding, &is_out_of_bounds);
EXPECT_FALSE(is_out_of_bounds);
ExpectVector(embedding, kPrevious1LayerDim, kPrevious1LayerValue);
EXPECT_EQ(embedding.data(), previous1.row(i - 1).data());
}
}
// Tests that SequenceLinks works as expected on multiple channels.
TEST_F(SequenceLinksTest, ManyChannels) {
TF_ASSERT_OK(ResetManager(
kMultiSpec, {"LinkToPrevious", "LinkToPrevious", "LinkToPrevious"}));
TF_ASSERT_OK(ResetLinks());
ASSERT_EQ(sequence_links_.num_channels(), 3);
ASSERT_EQ(sequence_links_.num_steps(), kNumSteps);
const Matrix<float> previous1(GetLayer("source_component_1", "previous_1"));
const Matrix<float> previous2(GetLayer("source_component_2", "previous_2"));
const Matrix<float> recurrent(GetLayer(kTestComponentName, "recurrent"));
Vector<float> embedding;
bool is_out_of_bounds = false;
// LinkToPrevious links the 0'th index to -1, which is out of bounds.
sequence_links_.Get(0, 0, &embedding, &is_out_of_bounds);
EXPECT_TRUE(is_out_of_bounds);
ExpectVector(embedding, kPrevious1LayerDim, 0.0);
sequence_links_.Get(1, 0, &embedding, &is_out_of_bounds);
EXPECT_TRUE(is_out_of_bounds);
ExpectVector(embedding, kPrevious2LayerDim, 0.0);
sequence_links_.Get(2, 0, &embedding, &is_out_of_bounds);
EXPECT_TRUE(is_out_of_bounds);
ExpectVector(embedding, kRecurrentLayerDim, 0.0);
// The remaining links point to the previous item.
for (int i = 1; i < kNumSteps; ++i) {
sequence_links_.Get(0, i, &embedding, &is_out_of_bounds);
EXPECT_FALSE(is_out_of_bounds);
ExpectVector(embedding, kPrevious1LayerDim, kPrevious1LayerValue);
EXPECT_EQ(embedding.data(), previous1.row(i - 1).data());
sequence_links_.Get(1, i, &embedding, &is_out_of_bounds);
EXPECT_FALSE(is_out_of_bounds);
ExpectVector(embedding, kPrevious2LayerDim, kPrevious2LayerValue);
EXPECT_EQ(embedding.data(), previous2.row(i - 1).data());
sequence_links_.Get(2, i, &embedding, &is_out_of_bounds);
EXPECT_FALSE(is_out_of_bounds);
ExpectVector(embedding, kRecurrentLayerDim, kRecurrentLayerValue);
EXPECT_EQ(embedding.data(), recurrent.row(i - 1).data());
}
}
// Tests that SequenceLinks is emptied when resetting to an empty manager after
// being reset to a non-empty manager.
TEST_F(SequenceLinksTest, ResetToEmptyAfterNonEmpty) {
TF_ASSERT_OK(ResetManager(kSingleSpec, {"LinkToPrevious"}));
TF_ASSERT_OK(ResetLinks());
ASSERT_EQ(sequence_links_.num_channels(), 1);
ASSERT_EQ(sequence_links_.num_steps(), kNumSteps);
SequenceLinkManager manager;
TF_ASSERT_OK(sequence_links_.Reset(/*add_steps=*/false, &manager,
&network_states_, &input_batch_cache_));
ASSERT_EQ(sequence_links_.num_channels(), 0);
ASSERT_EQ(sequence_links_.num_steps(), 0);
}
// Tests that SequenceLinks fails when adding steps to a component with no
// non-recurrent links.
TEST_F(SequenceLinksTest, AddStepsWithNoNonRecurrentLinks) {
TF_ASSERT_OK(ResetManager(kRecurrentSpec, {"LinkToPrevious"}));
EXPECT_THAT(
ResetLinks(/*add_steps=*/true),
test::IsErrorWithSubstr("Cannot infer the number of steps to add because "
"there are no non-recurrent links"));
}
// Tests that SequenceLinks produces no links when processing a component with
// only recurrent links, and when the NetworkStates has no steps.
TEST_F(SequenceLinksTest, RecurrentLinksWithNoSteps) {
TF_ASSERT_OK(ResetManager(kRecurrentSpec, {"LinkToPrevious"}));
TF_ASSERT_OK(ResetLinks(/*add_steps=*/false, /*num_steps=*/0));
ASSERT_EQ(sequence_links_.num_channels(), 1);
ASSERT_EQ(sequence_links_.num_steps(), 0);
}
// Tests that SequenceLinks properly infers the number of steps and adds them
// when processing a component with both non-recurrent and recurrent links.
TEST_F(SequenceLinksTest, AddStepsWithNonRecurrentAndRecurrentLinks) {
TF_ASSERT_OK(ResetManager(
kMultiSpec, {"LinkToPrevious", "LinkToPrevious", "LinkToPrevious"}));
TF_ASSERT_OK(ResetLinks(/*add_steps=*/true, /*num_steps=*/0));
ASSERT_EQ(sequence_links_.num_channels(), 3);
ASSERT_EQ(sequence_links_.num_steps(), kNumSteps);
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_model.h"
#include <vector>
#include "dragnn/runtime/attributes.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/sequence_backend.h"
#include "dragnn/runtime/transition_system_traits.h"
#include "tensorflow/core/lib/core/errors.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Proper backend for sequence-based models.
constexpr char kSupportedBackend[] = "SequenceBackend";
// Attributes for sequence-based comopnents, attached to the component builder.
// See SequenceComponentTransformer.
struct ComponentBuilderAttributes : public Attributes {
// Registered names of the sequence extractors to use.
Mandatory<std::vector<string>> sequence_extractors{"sequence_extractors",
this};
// Registered names of the sequence linkers to use per channel, in order.
Mandatory<std::vector<string>> sequence_linkers{"sequence_linkers", this};
// Registered name of the sequence predictor to use.
Mandatory<string> sequence_predictor{"sequence_predictor", this};
};
} // namespace
bool SequenceModel::Supports(const ComponentSpec &component_spec) {
// Require single-embedding fixed and linked features.
for (const FixedFeatureChannel &channel : component_spec.fixed_feature()) {
if (channel.size() != 1) return false;
}
for (const LinkedFeatureChannel &channel : component_spec.linked_feature()) {
if (channel.size() != 1) return false;
}
const bool has_fixed_feature = component_spec.fixed_feature_size() > 0;
bool has_recurrent_link = false;
bool has_non_recurrent_link = false;
for (const LinkedFeatureChannel &channel : component_spec.linked_feature()) {
if (channel.source_component() == component_spec.name()) {
has_recurrent_link = true;
} else {
has_non_recurrent_link = true;
}
}
// Recurrent links must be accompanied by fixed features or non-recurrent
// links, so the number of recurrent steps can be pre-computed.
if (has_recurrent_link && !has_fixed_feature && !has_non_recurrent_link) {
return false;
}
const int num_features = component_spec.fixed_feature_size() +
component_spec.linked_feature_size();
return component_spec.backend().registered_name() == kSupportedBackend &&
num_features > 0;
}
tensorflow::Status SequenceModel::Initialize(
const ComponentSpec &component_spec, const string &logits_name,
const FixedEmbeddingManager *fixed_embedding_manager,
const LinkedEmbeddingManager *linked_embedding_manager,
NetworkStateManager *network_state_manager) {
component_name_ = component_spec.name();
if (component_spec.backend().registered_name() != kSupportedBackend) {
return tensorflow::errors::InvalidArgument(
"Invalid component backend: ",
component_spec.backend().registered_name());
}
TransitionSystemTraits traits(component_spec);
deterministic_ = traits.is_deterministic;
left_to_right_ = traits.is_left_to_right;
ComponentBuilderAttributes component_builder_attributes;
TF_RETURN_IF_ERROR(component_builder_attributes.Reset(
component_spec.component_builder().parameters()));
TF_RETURN_IF_ERROR(sequence_feature_manager_.Reset(
fixed_embedding_manager, component_spec,
component_builder_attributes.sequence_extractors()));
TF_RETURN_IF_ERROR(sequence_link_manager_.Reset(
linked_embedding_manager, component_spec,
component_builder_attributes.sequence_linkers()));
have_fixed_features_ = sequence_feature_manager_.num_channels() > 0;
have_linked_features_ = sequence_link_manager_.num_channels() > 0;
if (!have_fixed_features_ && !have_linked_features_) {
return tensorflow::errors::InvalidArgument("No fixed or linked features");
}
if (!deterministic_) {
size_t dimension = 0;
TF_RETURN_IF_ERROR(network_state_manager->LookupLayer(
component_name_, logits_name, &dimension, &logits_handle_));
if (dimension != component_spec.num_actions()) {
return tensorflow::errors::InvalidArgument(
"Logits dimension mismatch between NetworkStates (", dimension,
") and ComponentSpec (", component_spec.num_actions(), ")");
}
TF_RETURN_IF_ERROR(SequencePredictor::New(
component_builder_attributes.sequence_predictor(), component_spec,
&sequence_predictor_));
}
return tensorflow::Status::OK();
}
tensorflow::Status SequenceModel::Preprocess(
SessionState *session_state, ComputeSession *compute_session,
EvaluateState *evaluate_state) const {
InputBatchCache *input_batch_cache = compute_session->GetInputBatchCache();
if (input_batch_cache == nullptr) {
return tensorflow::errors::InvalidArgument("Null input batch");
}
// The feature handling below is complicated by the need to support recurrent
// links. See the comment on SequenceLinks::Reset().
NetworkStates &network_states = session_state->network_states;
TF_RETURN_IF_ERROR(evaluate_state->features.Reset(&sequence_feature_manager_,
input_batch_cache));
if (have_fixed_features_) {
network_states.AddSteps(evaluate_state->features.num_steps());
}
TF_RETURN_IF_ERROR(evaluate_state->links.Reset(
/*add_steps=*/!have_fixed_features_, &sequence_link_manager_,
&network_states, input_batch_cache));
// Initialize() ensures that there is at least one fixed or linked feature;
// use it to determine the number of steps.
size_t num_steps = 0;
if (have_fixed_features_ && have_linked_features_) {
num_steps = evaluate_state->features.num_steps();
if (num_steps != evaluate_state->links.num_steps()) {
return tensorflow::errors::FailedPrecondition(
"Sequence length mismatch between fixed features (", num_steps,
") and linked features (", evaluate_state->links.num_steps(), ")");
}
} else if (have_fixed_features_) {
num_steps = evaluate_state->features.num_steps();
} else {
num_steps = evaluate_state->links.num_steps();
}
// Tell the backend the current input size, so it can handle requests for
// linked features from downstream components.
static_cast<SequenceBackend *>(
compute_session->GetReadiedComponent(component_name_))
->SetSequenceSize(num_steps);
evaluate_state->num_steps = num_steps;
evaluate_state->input = input_batch_cache;
return tensorflow::Status::OK();
}
tensorflow::Status SequenceModel::Predict(const NetworkStates &network_states,
EvaluateState *evaluate_state) const {
if (!deterministic_) {
const Matrix<float> logits(network_states.GetLayer(logits_handle_));
TF_RETURN_IF_ERROR(
sequence_predictor_->Predict(logits, evaluate_state->input));
}
return tensorflow::Status::OK();
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_SEQUENCE_MODEL_H_
#define DRAGNN_RUNTIME_SEQUENCE_MODEL_H_
#include <stddef.h>
#include <memory>
#include <string>
#include "dragnn/core/compute_session.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/fixed_embeddings.h"
#include "dragnn/runtime/linked_embeddings.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/sequence_features.h"
#include "dragnn/runtime/sequence_links.h"
#include "dragnn/runtime/sequence_predictor.h"
#include "dragnn/runtime/session_state.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// A class that configures and helps evaluate a sequence-based model.
//
// This class requires the SequenceBackend component backend and elides most of
// the ComputeSession feature extraction and transition system overhead.
class SequenceModel {
public:
// State associated with a single evaluation of the model.
struct EvaluateState {
// Number of transition steps in the current sequence.
size_t num_steps = 0;
// Current input batch.
InputBatchCache *input = nullptr;
// Sequence-based fixed features.
SequenceFeatures features;
// Sequence-based linked embeddings.
SequenceLinks links;
};
// Creates an uninitialized model. Call Initialize() before use.
SequenceModel() = default;
// Returns true if the |component_spec| is compatible with a sequence model.
static bool Supports(const ComponentSpec &component_spec);
// Initalizes this from the configuration in the |component_spec|. Wraps the
// |fixed_embedding_manager| and |linked_embedding_manager| in sequence-based
// versions, and requests layers from the |network_state_manager|. All of the
// managers must outlive this. If the transition system is non-deterministic,
// uses the layer named |logits_name| to make predictions later in Predict();
// otherwise, |logits_name| is ignored and Predict() does nothing. On error,
// returns non-OK.
tensorflow::Status Initialize(
const ComponentSpec &component_spec, const string &logits_name,
const FixedEmbeddingManager *fixed_embedding_manager,
const LinkedEmbeddingManager *linked_embedding_manager,
NetworkStateManager *network_state_manager);
// Resets the |evaluate_state| to values derived from the |session_state| and
// |compute_session|. Also updates the NetworkStates in the |session_state|
// and the current component of the |compute_session| with the length of the
// current sequence. Call this before producing output layers. On error,
// returns non-OK.
tensorflow::Status Preprocess(SessionState *session_state,
ComputeSession *compute_session,
EvaluateState *evaluate_state) const;
// If applicable, makes predictions based on the logits in |network_states|
// and applies them to the input in the |evaluate_state|. Call this after
// producing output layers. On error, returns non-OK.
tensorflow::Status Predict(const NetworkStates &network_states,
EvaluateState *evaluate_state) const;
// Accessors.
bool deterministic() const { return deterministic_; }
bool left_to_right() const { return left_to_right_; }
const SequenceLinkManager &sequence_link_manager() const;
const SequenceFeatureManager &sequence_feature_manager() const;
private:
// Name of the component that this model is a part of.
string component_name_;
// Whether the underlying transition system is deterministic.
bool deterministic_ = false;
// Whether to process sequences from left to right.
bool left_to_right_ = true;
// Whether fixed or linked features are present.
bool have_fixed_features_ = false;
bool have_linked_features_ = false;
// Handle to the logits layer. Only used if |deterministic_| is false.
LayerHandle<float> logits_handle_;
// Manager for sequence-based feature extractors.
SequenceFeatureManager sequence_feature_manager_;
// Manager for sequence-based linked embeddings.
SequenceLinkManager sequence_link_manager_;
// Sequence-based predictor, if |deterministic_| is false.
std::unique_ptr<SequencePredictor> sequence_predictor_;
};
// Implementation details below.
inline const SequenceLinkManager &SequenceModel::sequence_link_manager() const {
return sequence_link_manager_;
}
inline const SequenceFeatureManager &SequenceModel::sequence_feature_manager()
const {
return sequence_feature_manager_;
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_SEQUENCE_MODEL_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_model.h"
#include <string>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/fixed_embeddings.h"
#include "dragnn/runtime/linked_embeddings.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/sequence_backend.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "dragnn/runtime/sequence_linker.h"
#include "dragnn/runtime/sequence_predictor.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "syntaxnet/base.h"
#include <gmock/gmock.h>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
using ::testing::Return;
constexpr int kNumSteps = 50;
constexpr int kVocabularySize = 123;
constexpr int kLinkedDim = 11;
constexpr int kLogitsDim = 17;
constexpr char kLogitsName[] = "oddly_named_logits";
constexpr char kPreviousComponentName[] = "previous_component";
constexpr char kPreviousLayerName[] = "previous_layer";
constexpr float kPreviousLayerValue = -1.0;
// Sequence extractor that extracts [0, 2, 4, ...].
class EvenNumbers : public SequenceExtractor {
public:
// Implements SequenceExtractor.
bool Supports(const FixedFeatureChannel &,
const ComponentSpec &) const override {
return true;
}
tensorflow::Status Initialize(const FixedFeatureChannel &,
const ComponentSpec &) override {
return tensorflow::Status::OK();
}
tensorflow::Status GetIds(InputBatchCache *,
std::vector<int32> *ids) const override {
ids->clear();
for (int i = 0; i < num_steps_; ++i) ids->push_back(2 * i);
return tensorflow::Status::OK();
}
// Sets the number of steps to emit.
static void SetNumSteps(int num_steps) { num_steps_ = num_steps; }
private:
// The number of steps to produce.
static int num_steps_;
};
int EvenNumbers::num_steps_ = kNumSteps;
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR(EvenNumbers);
// Trivial linker that links each index to the previous one.
class LinkToPrevious : public SequenceLinker {
public:
// Implements SequenceLinker.
bool Supports(const LinkedFeatureChannel &,
const ComponentSpec &) const override {
return true;
}
tensorflow::Status Initialize(const LinkedFeatureChannel &,
const ComponentSpec &) override {
return tensorflow::Status::OK();
}
tensorflow::Status GetLinks(size_t, InputBatchCache *,
std::vector<int32> *links) const override {
links->clear();
for (int i = 0; i < num_steps_; ++i) links->push_back(i - 1);
return tensorflow::Status::OK();
}
// Sets the number of steps to emit.
static void SetNumSteps(int num_steps) { num_steps_ = num_steps; }
private:
// The number of steps to produce.
static int num_steps_;
};
int LinkToPrevious::num_steps_ = kNumSteps;
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(LinkToPrevious);
// Trivial predictor that captures the prediction logits.
class CaptureLogits : public SequencePredictor {
public:
// Implements SequenceLinker.
bool Supports(const ComponentSpec &) const override { return true; }
tensorflow::Status Initialize(const ComponentSpec &) override {
return tensorflow::Status::OK();
}
tensorflow::Status Predict(Matrix<float> logits,
InputBatchCache *) const override {
GetLogits() = logits;
return tensorflow::Status::OK();
}
// Returns the captured logits.
static Matrix<float> &GetLogits() {
static auto *logits = new Matrix<float>();
return *logits;
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_PREDICTOR(CaptureLogits);
class SequenceModelTest : public NetworkTestBase {
protected:
// Adds default call expectations. Since these are added first, they can be
// overridden by call expectations in individual tests.
SequenceModelTest() {
EXPECT_CALL(compute_session_, GetInputBatchCache())
.WillRepeatedly(Return(&input_));
EXPECT_CALL(compute_session_, GetReadiedComponent(kTestComponentName))
.WillRepeatedly(Return(&backend_));
// Some tests overwrite these; ensure that they are restored to the normal
// values at the start of each test.
EvenNumbers::SetNumSteps(kNumSteps);
LinkToPrevious::SetNumSteps(kNumSteps);
CaptureLogits::GetLogits() = Matrix<float>();
}
// Initializes the |model_| and its underlying feature managers from the
// |component_spec|, then uses the |model_| to preprocess and predict the
// |input_|. Also sets each row of the logits to twice its row index. On
// error, returns non-OK.
tensorflow::Status Run(ComponentSpec component_spec) {
component_spec.set_name(kTestComponentName);
AddComponent(kPreviousComponentName);
AddLayer(kPreviousLayerName, kLinkedDim);
AddComponent(kTestComponentName);
AddLayer(kLogitsName, kLogitsDim);
TF_RETURN_IF_ERROR(fixed_embedding_manager_.Reset(
component_spec, &variable_store_, &network_state_manager_));
TF_RETURN_IF_ERROR(linked_embedding_manager_.Reset(
component_spec, &variable_store_, &network_state_manager_));
TF_RETURN_IF_ERROR(model_.Initialize(
component_spec, kLogitsName, &fixed_embedding_manager_,
&linked_embedding_manager_, &network_state_manager_));
network_states_.Reset(&network_state_manager_);
StartComponent(kNumSteps);
FillLayer(kPreviousComponentName, kPreviousLayerName, kPreviousLayerValue);
StartComponent(0);
TF_RETURN_IF_ERROR(model_.Preprocess(&session_state_, &compute_session_,
&evaluate_state_));
MutableMatrix<float> logits = GetLayer(kTestComponentName, kLogitsName);
for (int row = 0; row < logits.num_rows(); ++row) {
for (int column = 0; column < logits.num_columns(); ++column) {
logits.row(row)[column] = 2.0 * row;
}
}
return model_.Predict(network_states_, &evaluate_state_);
}
// Returns the sequence size passed to the |backend_|.
int GetBackendSequenceSize() {
// The sequence size is not directly exposed, but can be inferred using one
// of the reverse step translators.
return backend_.GetStepLookupFunction("reverse-token")(0, 0, 0) + 1;
}
// Fixed and linked embedding managers.
FixedEmbeddingManager fixed_embedding_manager_;
LinkedEmbeddingManager linked_embedding_manager_;
// Input batch injected into Preprocess() by default.
InputBatchCache input_;
// Backend injected into Preprocess().
SequenceBackend backend_;
// Sequence-based model.
SequenceModel model_;
// Per-evaluation state.
SequenceModel::EvaluateState evaluate_state_;
};
// Returns a ComponentSpec that is supported.
ComponentSpec MakeSupportedSpec() {
ComponentSpec component_spec;
component_spec.set_num_actions(kLogitsDim);
component_spec.mutable_component_builder()->mutable_parameters()->insert(
{"sequence_extractors", "EvenNumbers"});
component_spec.mutable_component_builder()->mutable_parameters()->insert(
{"sequence_linkers", "LinkToPrevious"});
component_spec.mutable_component_builder()->mutable_parameters()->insert(
{"sequence_predictor", "CaptureLogits"});
component_spec.mutable_backend()->set_registered_name("SequenceBackend");
FixedFeatureChannel *fixed_feature = component_spec.add_fixed_feature();
fixed_feature->set_size(1);
fixed_feature->set_embedding_dim(-1);
LinkedFeatureChannel *linked_feature = component_spec.add_linked_feature();
linked_feature->set_source_component(kPreviousComponentName);
linked_feature->set_source_layer(kPreviousLayerName);
linked_feature->set_size(1);
linked_feature->set_embedding_dim(-1);
return component_spec;
}
// Tests that the model supports a supported spec.
TEST_F(SequenceModelTest, Supported) {
const ComponentSpec component_spec = MakeSupportedSpec();
EXPECT_TRUE(SequenceModel::Supports(component_spec));
}
// Tests that the model rejects a spec with the wrong backend.
TEST_F(SequenceModelTest, UnsupportedBackend) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.mutable_backend()->set_registered_name("bad");
EXPECT_FALSE(SequenceModel::Supports(component_spec));
}
// Tests that the model rejects a spec with no features.
TEST_F(SequenceModelTest, UnsupportedNoFeatures) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.clear_fixed_feature();
component_spec.clear_linked_feature();
EXPECT_FALSE(SequenceModel::Supports(component_spec));
}
// Tests that the model rejects a spec with a multi-embedding fixed feature.
TEST_F(SequenceModelTest, UnsupportedMultiEmbeddingFixedFeature) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.mutable_fixed_feature(0)->set_size(2);
EXPECT_FALSE(SequenceModel::Supports(component_spec));
}
// Tests that the model rejects a spec with a multi-embedding linked feature.
TEST_F(SequenceModelTest, UnsupportedMultiEmbeddingLinkedFeature) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.mutable_linked_feature(0)->set_size(2);
EXPECT_FALSE(SequenceModel::Supports(component_spec));
}
// Tests that the model rejects a spec with only recurrent links.
TEST_F(SequenceModelTest, UnsupportedOnlyRecurrentLinks) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.set_name("foo");
component_spec.clear_fixed_feature();
component_spec.mutable_linked_feature(0)->set_source_component("foo");
EXPECT_FALSE(SequenceModel::Supports(component_spec));
}
// Tests that Initialize() succeeds on a supported spec.
TEST_F(SequenceModelTest, InitializeSupported) {
const ComponentSpec component_spec = MakeSupportedSpec();
TF_ASSERT_OK(Run(component_spec));
EXPECT_FALSE(model_.deterministic());
EXPECT_TRUE(model_.left_to_right());
EXPECT_EQ(model_.sequence_feature_manager().num_channels(), 1);
EXPECT_EQ(model_.sequence_link_manager().num_channels(), 1);
}
// Tests that Initialize() detects deterministic components.
TEST_F(SequenceModelTest, InitializeDeterministic) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.set_num_actions(1);
TF_ASSERT_OK(Run(component_spec));
EXPECT_TRUE(model_.deterministic());
EXPECT_TRUE(model_.left_to_right());
EXPECT_EQ(model_.sequence_feature_manager().num_channels(), 1);
EXPECT_EQ(model_.sequence_link_manager().num_channels(), 1);
}
// Tests that Initialize() detects right-to-left components.
TEST_F(SequenceModelTest, InitializeLeftToRight) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.mutable_transition_system()->mutable_parameters()->insert(
{"left_to_right", "false"});
TF_ASSERT_OK(Run(component_spec));
EXPECT_FALSE(model_.deterministic());
EXPECT_FALSE(model_.left_to_right());
EXPECT_EQ(model_.sequence_feature_manager().num_channels(), 1);
EXPECT_EQ(model_.sequence_link_manager().num_channels(), 1);
}
// Tests that Initialize() fails if the backend is wrong.
TEST_F(SequenceModelTest, WrongBackend) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.mutable_backend()->set_registered_name("bad");
EXPECT_THAT(Run(component_spec),
test::IsErrorWithSubstr("Invalid component backend"));
}
// Tests that Initialize() fails if the number of actions in the ComponentSpec
// does not match the logits.
TEST_F(SequenceModelTest, WrongNumActions) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.set_num_actions(kLogitsDim + 1);
EXPECT_THAT(Run(component_spec),
test::IsErrorWithSubstr("Logits dimension mismatch"));
}
// Tests that Initialize() fails if an unknown sequence extractor is specified.
TEST_F(SequenceModelTest, UnknownSequenceExtractor) {
ComponentSpec component_spec = MakeSupportedSpec();
(*component_spec.mutable_component_builder()
->mutable_parameters())["sequence_extractors"] = "bad";
EXPECT_THAT(
Run(component_spec),
test::IsErrorWithSubstr("Unknown DRAGNN Runtime Sequence Extractor"));
}
// Tests that Initialize() fails if an unknown sequence linker is specified.
TEST_F(SequenceModelTest, UnknownSequenceLinker) {
ComponentSpec component_spec = MakeSupportedSpec();
(*component_spec.mutable_component_builder()
->mutable_parameters())["sequence_linkers"] = "bad";
EXPECT_THAT(
Run(component_spec),
test::IsErrorWithSubstr("Unknown DRAGNN Runtime Sequence Linker"));
}
// Tests that Initialize() fails if an unknown sequence predictor is specified.
TEST_F(SequenceModelTest, UnknownSequencePredictor) {
ComponentSpec component_spec = MakeSupportedSpec();
(*component_spec.mutable_component_builder()
->mutable_parameters())["sequence_predictor"] = "bad";
EXPECT_THAT(
Run(component_spec),
test::IsErrorWithSubstr("Unknown DRAGNN Runtime Sequence Predictor"));
}
// Tests that Initialize() fails on an unknown component builder parameter.
TEST_F(SequenceModelTest, UnknownComponentBuilderParameter) {
ComponentSpec component_spec = MakeSupportedSpec();
(*component_spec.mutable_component_builder()->mutable_parameters())["bad"] =
"bad";
EXPECT_THAT(Run(component_spec),
test::IsErrorWithSubstr("Unknown attribute"));
}
// Tests that Initialize() fails if there are no fixed or linked features.
TEST_F(SequenceModelTest, InitializeRequiresFeatures) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.clear_fixed_feature();
component_spec.clear_linked_feature();
(*component_spec.mutable_component_builder()
->mutable_parameters())["sequence_extractors"] = "";
(*component_spec.mutable_component_builder()
->mutable_parameters())["sequence_linkers"] = "";
EXPECT_THAT(Run(component_spec),
test::IsErrorWithSubstr("No fixed or linked features"));
}
// Tests that the model fails if a null batch is returned.
TEST_F(SequenceModelTest, NullBatch) {
EXPECT_CALL(compute_session_, GetInputBatchCache()).WillOnce(Return(nullptr));
EXPECT_THAT(Run(MakeSupportedSpec()),
test::IsErrorWithSubstr("Null input batch"));
}
// Tests that the model properly sets up the EvaluateState and logits.
TEST_F(SequenceModelTest, Success) {
TF_ASSERT_OK(Run(MakeSupportedSpec()));
EXPECT_EQ(GetBackendSequenceSize(), kNumSteps);
EXPECT_EQ(evaluate_state_.num_steps, kNumSteps);
EXPECT_EQ(evaluate_state_.input, &input_);
EXPECT_EQ(evaluate_state_.features.num_channels(), 1);
EXPECT_EQ(evaluate_state_.features.num_steps(), kNumSteps);
EXPECT_EQ(evaluate_state_.features.GetId(0, 0), 0);
EXPECT_EQ(evaluate_state_.features.GetId(0, 1), 2);
EXPECT_EQ(evaluate_state_.features.GetId(0, 2), 4);
EXPECT_EQ(evaluate_state_.links.num_channels(), 1);
EXPECT_EQ(evaluate_state_.links.num_steps(), kNumSteps);
Vector<float> embedding;
bool is_out_of_bounds = false;
evaluate_state_.links.Get(0, 0, &embedding, &is_out_of_bounds);
ExpectVector(embedding, kLinkedDim, 0.0);
EXPECT_TRUE(is_out_of_bounds);
evaluate_state_.links.Get(0, 1, &embedding, &is_out_of_bounds);
ExpectVector(embedding, kLinkedDim, kPreviousLayerValue);
EXPECT_FALSE(is_out_of_bounds);
const Matrix<float> logits = CaptureLogits::GetLogits();
ASSERT_EQ(logits.num_rows(), kNumSteps);
ASSERT_EQ(logits.num_columns(), kLogitsDim);
for (int i = 0; i < kNumSteps; ++i) {
ExpectVector(logits.row(i), kLogitsDim, 2.0 * i);
}
}
// Tests that the model works with only fixed features.
TEST_F(SequenceModelTest, FixedFeaturesOnly) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.clear_linked_feature();
(*component_spec.mutable_component_builder()
->mutable_parameters())["sequence_linkers"] = "";
TF_ASSERT_OK(Run(component_spec));
EXPECT_EQ(GetBackendSequenceSize(), kNumSteps);
EXPECT_EQ(evaluate_state_.num_steps, kNumSteps);
EXPECT_EQ(evaluate_state_.input, &input_);
EXPECT_EQ(evaluate_state_.features.num_channels(), 1);
EXPECT_EQ(evaluate_state_.features.num_steps(), kNumSteps);
EXPECT_EQ(evaluate_state_.features.GetId(0, 0), 0);
EXPECT_EQ(evaluate_state_.features.GetId(0, 1), 2);
EXPECT_EQ(evaluate_state_.features.GetId(0, 2), 4);
EXPECT_EQ(evaluate_state_.links.num_channels(), 0);
EXPECT_EQ(evaluate_state_.links.num_steps(), 0);
const Matrix<float> logits = CaptureLogits::GetLogits();
ASSERT_EQ(logits.num_rows(), kNumSteps);
ASSERT_EQ(logits.num_columns(), kLogitsDim);
for (int i = 0; i < kNumSteps; ++i) {
ExpectVector(logits.row(i), kLogitsDim, 2.0 * i);
}
}
// Tests that the model works with only linked features.
TEST_F(SequenceModelTest, LinkedFeaturesOnly) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.clear_fixed_feature();
(*component_spec.mutable_component_builder()
->mutable_parameters())["sequence_extractors"] = "";
TF_ASSERT_OK(Run(component_spec));
EXPECT_EQ(GetBackendSequenceSize(), kNumSteps);
EXPECT_EQ(evaluate_state_.num_steps, kNumSteps);
EXPECT_EQ(evaluate_state_.input, &input_);
EXPECT_EQ(evaluate_state_.features.num_channels(), 0);
EXPECT_EQ(evaluate_state_.features.num_steps(), 0);
EXPECT_EQ(evaluate_state_.links.num_channels(), 1);
EXPECT_EQ(evaluate_state_.links.num_steps(), kNumSteps);
Vector<float> embedding;
bool is_out_of_bounds = false;
evaluate_state_.links.Get(0, 0, &embedding, &is_out_of_bounds);
ExpectVector(embedding, kLinkedDim, 0.0);
EXPECT_TRUE(is_out_of_bounds);
evaluate_state_.links.Get(0, 1, &embedding, &is_out_of_bounds);
ExpectVector(embedding, kLinkedDim, kPreviousLayerValue);
EXPECT_FALSE(is_out_of_bounds);
const Matrix<float> logits = CaptureLogits::GetLogits();
ASSERT_EQ(logits.num_rows(), kNumSteps);
ASSERT_EQ(logits.num_columns(), kLogitsDim);
for (int i = 0; i < kNumSteps; ++i) {
ExpectVector(logits.row(i), kLogitsDim, 2.0 * i);
}
}
// Tests that the model fails if the fixed and linked features disagree on the
// number of steps.
TEST_F(SequenceModelTest, FixedAndLinkedDisagree) {
EvenNumbers::SetNumSteps(5);
LinkToPrevious::SetNumSteps(6);
EXPECT_THAT(Run(MakeSupportedSpec()),
test::IsErrorWithSubstr("Sequence length mismatch between fixed "
"features (5) and linked features (6)"));
}
// Tests that the model can handle an empty sequence.
TEST_F(SequenceModelTest, EmptySequence) {
EvenNumbers::SetNumSteps(0);
LinkToPrevious::SetNumSteps(0);
TF_ASSERT_OK(Run(MakeSupportedSpec()));
EXPECT_EQ(GetBackendSequenceSize(), 0);
const Matrix<float> logits = CaptureLogits::GetLogits();
ASSERT_EQ(logits.num_rows(), 0);
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_predictor.h"
#include "tensorflow/core/lib/core/errors.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
tensorflow::Status SequencePredictor::Select(
const ComponentSpec &component_spec, string *name) {
string supporting_name;
for (const Registry::Registrar *registrar = registry()->components;
registrar != nullptr; registrar = registrar->next()) {
Factory *factory_function = registrar->object();
std::unique_ptr<SequencePredictor> current_predictor(factory_function());
if (!current_predictor->Supports(component_spec)) continue;
if (!supporting_name.empty()) {
return tensorflow::errors::Internal(
"Multiple SequencePredictors support ComponentSpec (",
supporting_name, " and ", registrar->name(),
"): ", component_spec.ShortDebugString());
}
supporting_name = registrar->name();
}
if (supporting_name.empty()) {
return tensorflow::errors::NotFound(
"No SequencePredictor supports ComponentSpec: ",
component_spec.ShortDebugString());
}
// Success; make modifications.
*name = supporting_name;
return tensorflow::Status::OK();
}
tensorflow::Status SequencePredictor::New(
const string &name, const ComponentSpec &component_spec,
std::unique_ptr<SequencePredictor> *predictor) {
std::unique_ptr<SequencePredictor> matching_predictor;
TF_RETURN_IF_ERROR(
SequencePredictor::CreateOrError(name, &matching_predictor));
TF_RETURN_IF_ERROR(matching_predictor->Initialize(component_spec));
// Success; make modifications.
*predictor = std::move(matching_predictor);
return tensorflow::Status::OK();
}
} // namespace runtime
} // namespace dragnn
REGISTER_SYNTAXNET_CLASS_REGISTRY("DRAGNN Runtime Sequence Predictor",
dragnn::runtime::SequencePredictor);
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_SEQUENCE_PREDICTOR_H_
#define DRAGNN_RUNTIME_SEQUENCE_PREDICTOR_H_
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/math/types.h"
#include "syntaxnet/base.h"
#include "syntaxnet/registry.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Interface for making predictions on sequences.
//
// This predictor can be used to avoid ComputeSession overhead in simple cases;
// for example, predicting sequences of POS tags.
class SequencePredictor : public RegisterableClass<SequencePredictor> {
public:
// Sets |predictor| to an instance of the subclass named |name| initialized
// from the |component_spec|. On error, returns non-OK and modifies nothing.
static tensorflow::Status New(const string &name,
const ComponentSpec &component_spec,
std::unique_ptr<SequencePredictor> *predictor);
SequencePredictor(const SequencePredictor &) = delete;
SequencePredictor &operator=(const SequencePredictor &) = delete;
virtual ~SequencePredictor() = default;
// Sets |name| to the registered name of the SequencePredictor that supports
// the |component_spec|. On error, returns non-OK and modifies nothing. The
// returned statuses include:
// * OK: If a supporting SequencePredictor was found.
// * INTERNAL: If an error occurred while searching for a compatible match.
// * NOT_FOUND: If the search was error-free, but no compatible match was
// found.
static tensorflow::Status Select(const ComponentSpec &component_spec,
string *name);
// Makes a sequence of predictions using the per-step |logits| and writes
// annotations to the |input|.
virtual tensorflow::Status Predict(Matrix<float> logits,
InputBatchCache *input) const = 0;
protected:
SequencePredictor() = default;
private:
// Helps prevent use of the Create() method; use New() instead.
using RegisterableClass<SequencePredictor>::Create;
// Returns true if this supports the |component_spec|. Implementations must
// coordinate to ensure that at most one supports any given |component_spec|.
virtual bool Supports(const ComponentSpec &component_spec) const = 0;
// Initializes this from the |component_spec|. On error, returns non-OK.
virtual tensorflow::Status Initialize(
const ComponentSpec &component_spec) = 0;
};
} // namespace runtime
} // namespace dragnn
DECLARE_SYNTAXNET_CLASS_REGISTRY("DRAGNN Runtime Sequence Predictor",
dragnn::runtime::SequencePredictor);
} // namespace syntaxnet
#define DRAGNN_RUNTIME_REGISTER_SEQUENCE_PREDICTOR(subclass) \
REGISTER_SYNTAXNET_CLASS_COMPONENT( \
::syntaxnet::dragnn::runtime::SequencePredictor, #subclass, subclass)
#endif // DRAGNN_RUNTIME_SEQUENCE_PREDICTOR_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_predictor.h"
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/math/types.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Supports components named "success" and initializes successfully.
class Success : public SequencePredictor {
public:
// Implements SequencePredictor.
bool Supports(const ComponentSpec &component_spec) const override {
return component_spec.name() == "success";
}
tensorflow::Status Initialize(const ComponentSpec &) override {
return tensorflow::Status::OK();
}
tensorflow::Status Predict(Matrix<float>, InputBatchCache *) const override {
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_PREDICTOR(Success);
// Supports components named "failure" and fails to initialize.
class Failure : public SequencePredictor {
public:
// Implements SequencePredictor.
bool Supports(const ComponentSpec &component_spec) const override {
return component_spec.name() == "failure";
}
tensorflow::Status Initialize(const ComponentSpec &) override {
return tensorflow::errors::Internal("Boom!");
}
tensorflow::Status Predict(Matrix<float>, InputBatchCache *) const override {
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_PREDICTOR(Failure);
// Supports components named "duplicate" and initializes successfully.
class Duplicate : public SequencePredictor {
public:
// Implements SequencePredictor.
bool Supports(const ComponentSpec &component_spec) const override {
return component_spec.name() == "duplicate";
}
tensorflow::Status Initialize(const ComponentSpec &) override {
return tensorflow::Status::OK();
}
tensorflow::Status Predict(Matrix<float>, InputBatchCache *) const override {
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_PREDICTOR(Duplicate);
// Duplicate of the above.
using Duplicate2 = Duplicate;
DRAGNN_RUNTIME_REGISTER_SEQUENCE_PREDICTOR(Duplicate2);
// Tests that a component can be successfully created.
TEST(SequencePredictorTest, Success) {
string name;
std::unique_ptr<SequencePredictor> predictor;
ComponentSpec component_spec;
component_spec.set_name("success");
TF_ASSERT_OK(SequencePredictor::Select(component_spec, &name));
ASSERT_EQ(name, "Success");
TF_EXPECT_OK(SequencePredictor::New(name, component_spec, &predictor));
EXPECT_NE(predictor, nullptr);
}
// Tests that errors in Initialize() are reported.
TEST(SequencePredictorTest, FailToInitialize) {
string name;
std::unique_ptr<SequencePredictor> predictor;
ComponentSpec component_spec;
component_spec.set_name("failure");
TF_ASSERT_OK(SequencePredictor::Select(component_spec, &name));
EXPECT_EQ(name, "Failure");
EXPECT_THAT(SequencePredictor::New(name, component_spec, &predictor),
test::IsErrorWithSubstr("Boom!"));
EXPECT_EQ(predictor, nullptr);
}
// Tests that unsupported specs are reported as NOT_FOUND errors.
TEST(SequencePredictorTest, UnsupportedSpec) {
string name = "not overwritten";
ComponentSpec component_spec;
component_spec.set_name("unsupported");
EXPECT_THAT(SequencePredictor::Select(component_spec, &name),
test::IsErrorWithCodeAndSubstr(
tensorflow::error::NOT_FOUND,
"No SequencePredictor supports ComponentSpec"));
EXPECT_EQ(name, "not overwritten");
}
// Tests that unsupported subclass names are reported as errors.
TEST(SequencePredictorTest, UnsupportedSubclass) {
std::unique_ptr<SequencePredictor> predictor;
ComponentSpec component_spec;
EXPECT_THAT(
SequencePredictor::New("Unsupported", component_spec, &predictor),
test::IsErrorWithSubstr("Unknown DRAGNN Runtime Sequence Predictor"));
EXPECT_EQ(predictor, nullptr);
}
// Tests that multiple supporting predictors are reported as INTERNAL errors.
TEST(SequencePredictorTest, Duplicate) {
string name = "not overwritten";
ComponentSpec component_spec;
component_spec.set_name("duplicate");
EXPECT_THAT(SequencePredictor::Select(component_spec, &name),
test::IsErrorWithCodeAndSubstr(
tensorflow::error::INTERNAL,
"Multiple SequencePredictors support ComponentSpec"));
EXPECT_EQ(name, "not overwritten");
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment