Commit edea2b67 authored by Terry Koo's avatar Terry Koo
Browse files

Remove runtime because reasons.

parent a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <memory>
#include <string>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/component_transformation.h"
#include "dragnn/runtime/extensions.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Base class for test components.
class TestComponentBase : public Component {
public:
// Partially implements Component.
tensorflow::Status Initialize(const ComponentSpec &, VariableStore *,
NetworkStateManager *,
ExtensionManager *) override {
return tensorflow::Status::OK();
}
tensorflow::Status Evaluate(SessionState *, ComputeSession *,
ComponentTrace *) const override {
return tensorflow::Status::OK();
}
bool PreferredTo(const Component &) const override { return false; }
};
// Supports components whose builder name includes "Foo".
class ContainsFoo : public TestComponentBase {
public:
// Implements Component.
bool Supports(const ComponentSpec &,
const string &normalized_builder_name) const override {
return normalized_builder_name.find("Foo") != string::npos;
}
};
DRAGNN_RUNTIME_REGISTER_COMPONENT(ContainsFoo);
// Supports components whose builder name includes "Bar".
class ContainsBar : public TestComponentBase {
public:
// Implements Component.
bool Supports(const ComponentSpec &,
const string &normalized_builder_name) const override {
return normalized_builder_name.find("Bar") != string::npos;
}
};
DRAGNN_RUNTIME_REGISTER_COMPONENT(ContainsBar);
// Tests that a spec with an unknown builder name causes an error.
TEST(SelectBestComponentTransformerTest, Unknown) {
ComponentSpec component_spec;
component_spec.mutable_component_builder()->set_registered_name("unknown");
EXPECT_THAT(ComponentTransformer::ApplyAll(&component_spec),
test::IsErrorWithSubstr("Could not find a best"));
}
// Tests that a spec with builder "Foo" is changed to "ContainsFoo".
TEST(SelectBestComponentTransformerTest, ChangeToContainsFoo) {
ComponentSpec component_spec;
component_spec.mutable_component_builder()->set_registered_name("Foo");
ComponentSpec expected_spec = component_spec;
expected_spec.mutable_component_builder()->set_registered_name("ContainsFoo");
TF_ASSERT_OK(ComponentTransformer::ApplyAll(&component_spec));
EXPECT_THAT(component_spec, test::EqualsProto(expected_spec));
}
// Tests that a spec with builder "Bar" is changed to "ContainsBar".
TEST(SelectBestComponentTransformerTest, ChangeToContainsBar) {
ComponentSpec component_spec;
component_spec.mutable_component_builder()->set_registered_name("Bar");
ComponentSpec expected_spec = component_spec;
expected_spec.mutable_component_builder()->set_registered_name("ContainsBar");
TF_ASSERT_OK(ComponentTransformer::ApplyAll(&component_spec));
EXPECT_THAT(component_spec, test::EqualsProto(expected_spec));
}
// Tests that a spec with builder "FooBar" causes a conflict.
TEST(SelectBestComponentTransformerTest, Conflict) {
ComponentSpec component_spec;
component_spec.mutable_component_builder()->set_registered_name("FooBar");
EXPECT_THAT(
ComponentTransformer::ApplyAll(&component_spec),
test::IsErrorWithSubstr("both think they should be dis-preferred"));
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_backend.h"
#include "dragnn/core/component_registry.h"
#include "tensorflow/core/platform/logging.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
std::function<int(int, int, int)> SequenceBackend::GetStepLookupFunction(
const string &method) {
if (method == "reverse-char" || method == "reverse-token") {
// Reverses the |index| in the sequence. We are agnostic to whether the
// input is a sequence of tokens or chars.
return [this](int unused_batch_index, int unused_beam_index, int index) {
index = sequence_size_ - index - 1;
return index >= 0 && index < sequence_size_ ? index : -1;
};
}
LOG(FATAL) << "[" << name_ << "] Unknown step lookup function: " << method;
}
void SequenceBackend::InitializeComponent(const ComponentSpec &spec) {
name_ = spec.name();
}
void SequenceBackend::InitializeData(
const std::vector<std::vector<const TransitionState *>> &parent_states,
int max_beam_size, InputBatchCache *input_data) {
// Store the |parent_states| for forwarding to downstream components.
parent_states_ = parent_states;
}
std::vector<std::vector<const TransitionState *>> SequenceBackend::GetBeam() {
// Forward the states of the previous component.
return parent_states_;
}
int SequenceBackend::GetSourceBeamIndex(int current_index, int batch) const {
// Forward the |current_index| to the previous component.
return current_index;
}
int SequenceBackend::GetBeamIndexAtStep(int step, int current_index,
int batch) const {
// Always return 0 since there is only one beam.
return 0;
}
std::vector<std::vector<ComponentTrace>> SequenceBackend::GetTraceProtos()
const {
// Return a single trace, since the beam and batch sizes are fixed at 1.
return {{ComponentTrace()}};
}
string SequenceBackend::Name() const { return name_; }
int SequenceBackend::BeamSize() const { return 1; }
int SequenceBackend::BatchSize() const { return 1; }
bool SequenceBackend::IsReady() const { return true; }
bool SequenceBackend::IsTerminal() const { return true; }
void SequenceBackend::FinalizeData() {}
void SequenceBackend::ResetComponent() {}
void SequenceBackend::InitializeTracing() {}
void SequenceBackend::DisableTracing() {}
int SequenceBackend::StepsTaken(int batch_index) const {
LOG(FATAL) << "[" << name_ << "] Not supported";
}
bool SequenceBackend::AdvanceFromPrediction(const float *transition_matrix,
int num_items, int num_actions) {
LOG(FATAL) << "[" << name_ << "] Not supported";
}
void SequenceBackend::AdvanceFromOracle() {
LOG(FATAL) << "[" << name_ << "] Not supported";
}
std::vector<std::vector<std::vector<Label>>> SequenceBackend::GetOracleLabels()
const {
LOG(FATAL) << "[" << name_ << "] Not supported";
}
int SequenceBackend::GetFixedFeatures(
std::function<int32 *(int)> allocate_indices,
std::function<int64 *(int)> allocate_ids,
std::function<float *(int)> allocate_weights, int channel_id) const {
LOG(FATAL) << "[" << name_ << "] Not supported";
}
int SequenceBackend::BulkGetFixedFeatures(
const BulkFeatureExtractor &extractor) {
LOG(FATAL) << "[" << name_ << "] Not supported";
}
void SequenceBackend::BulkEmbedFixedFeatures(
int batch_size_padding, int num_steps_padding, int output_array_size,
const vector<const float *> &per_channel_embeddings,
float *embedding_output) {
LOG(FATAL) << "[" << name_ << "] Not supported";
}
void SequenceBackend::BulkEmbedDenseFixedFeatures(
const vector<const float *> &per_channel_embeddings,
float *embedding_output, int embedding_output_size,
int *offset_array_output, int offset_array_size) {
LOG(FATAL) << "[" << name_ << "] Not supported";
}
int SequenceBackend::BulkDenseFeatureSize() const {
LOG(FATAL) << "[" << name_ << "] Not supported";
}
std::vector<LinkFeatures> SequenceBackend::GetRawLinkFeatures(
int channel_id) const {
LOG(FATAL) << "[" << name_ << "] Not supported";
}
void SequenceBackend::AddTranslatedLinkFeaturesToTrace(
const std::vector<LinkFeatures> &features, int channel_id) {
LOG(FATAL) << "[" << name_ << "] Not supported";
}
REGISTER_DRAGNN_COMPONENT(SequenceBackend);
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_SEQUENCE_BACKEND_H_
#define DRAGNN_RUNTIME_SEQUENCE_BACKEND_H_
#include <functional>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/interfaces/component.h"
#include "dragnn/core/interfaces/transition_state.h"
#include "dragnn/core/util/label.h"
#include "dragnn/protos/data.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "syntaxnet/base.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Runtime-only component backend for sequence-based models. This is not used
// at training time, and provides trivial implementations of most methods. This
// is intended to be used with alternative feature extraction approaches, such
// as SequenceExtractor.
class SequenceBackend : public dragnn::Component {
public:
// Sets the size of the sequence in the current input.
void SetSequenceSize(int size) { sequence_size_ = size; }
// Implements dragnn::Component.
std::function<int(int, int, int)> GetStepLookupFunction(
const string &method) override;
void InitializeComponent(const ComponentSpec &spec) override;
void InitializeData(
const std::vector<std::vector<const TransitionState *>> &parent_states,
int max_beam_size, InputBatchCache *input_data) override;
std::vector<std::vector<const TransitionState *>> GetBeam() override;
int GetSourceBeamIndex(int current_index, int batch) const override;
int GetBeamIndexAtStep(int step, int current_index, int batch) const override;
std::vector<std::vector<ComponentTrace>> GetTraceProtos() const override;
string Name() const override;
int BeamSize() const override;
int BatchSize() const override;
bool IsReady() const override;
bool IsTerminal() const override;
void FinalizeData() override;
void ResetComponent() override;
void InitializeTracing() override;
void DisableTracing() override;
// Not implemented, crashes when called.
int StepsTaken(int batch_index) const override;
// Not implemented, crashes when called.
bool AdvanceFromPrediction(const float *transition_matrix, int num_items,
int num_actions) override;
// Not implemented, crashes when called.
void AdvanceFromOracle() override;
// Not implemented, crashes when called.
std::vector<std::vector<std::vector<Label>>> GetOracleLabels() const override;
// Not implemented, crashes when called.
int GetFixedFeatures(std::function<int32 *(int)> allocate_indices,
std::function<int64 *(int)> allocate_ids,
std::function<float *(int)> allocate_weights,
int channel_id) const override;
// Not implemented, crashes when called.
int BulkGetFixedFeatures(const BulkFeatureExtractor &extractor) override;
// Not implemented, crashes when called.
void BulkEmbedFixedFeatures(
int batch_size_padding, int num_steps_padding, int output_array_size,
const vector<const float *> &per_channel_embeddings,
float *embedding_output) override;
// Not implemented, crashes when called.
void BulkEmbedDenseFixedFeatures(
const vector<const float *> &per_channel_embeddings,
float *embedding_output, int embedding_output_size,
int *offset_array_output, int offset_array_size) override;
// Not implemented, crashes when called.
int BulkDenseFeatureSize() const override;
// Not implemented, crashes when called.
std::vector<LinkFeatures> GetRawLinkFeatures(int channel_id) const override;
// Not implemented, crashes when called.
void AddTranslatedLinkFeaturesToTrace(
const std::vector<LinkFeatures> &features, int channel_id) override;
private:
// Name of the component that this backend supports.
string name_;
// Size of the current input sequence.
int sequence_size_ = 0;
// Parent states passed to InitializeData(), and passed along in GetBeam().
std::vector<std::vector<const TransitionState *>> parent_states_;
};
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_SEQUENCE_BACKEND_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_backend.h"
#include "dragnn/components/util/bulk_feature_extractor.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/interfaces/transition_state.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Tests that the "reverse-*" step lookup functions ignore the batch and beam
// indices and return -1 if the sequence size was never set.
TEST(SequenceBackendTest, ReverseCharUninitialized) {
for (const string &reverse_method : {"reverse-char", "reverse-token"}) {
SequenceBackend backend;
const std::function<int(int, int, int)> reverse =
backend.GetStepLookupFunction(reverse_method);
EXPECT_EQ(reverse(0, 0, 0), -1);
EXPECT_EQ(reverse(1, 1, 1), -1);
EXPECT_EQ(reverse(-1, -1, -1), -1);
EXPECT_EQ(reverse(0, 0, 9999), -1);
EXPECT_EQ(reverse(0, 0, -9999), -1);
}
}
// Tests that the "reverse-*" step lookup functions ignore the batch and beam
// indices and return the reverse of the step index w.r.t. the most recent call
// to SetSequenceSize().
TEST(SequenceBackendTest, ReverseCharAfterSetSequenceSize) {
for (const string &reverse_method : {"reverse-char", "reverse-token"}) {
SequenceBackend backend;
const std::function<int(int, int, int)> reverse =
backend.GetStepLookupFunction(reverse_method);
backend.SetSequenceSize(10);
EXPECT_EQ(reverse(0, 0, -1), -1);
EXPECT_EQ(reverse(0, 0, 0), 9);
EXPECT_EQ(reverse(1, 1, 1), 8);
EXPECT_EQ(reverse(8, 8, 8), 1);
EXPECT_EQ(reverse(9, 9, 9), 0);
EXPECT_EQ(reverse(10, 10, 10), -1);
EXPECT_EQ(reverse(-1, -1, 5), 4);
EXPECT_EQ(reverse(0, 0, 9999), -1);
EXPECT_EQ(reverse(0, 0, -9999), -1);
backend.SetSequenceSize(11);
EXPECT_EQ(reverse(0, 0, -1), -1);
EXPECT_EQ(reverse(0, 0, 0), 10);
EXPECT_EQ(reverse(1, 1, 1), 9);
EXPECT_EQ(reverse(8, 8, 8), 2);
EXPECT_EQ(reverse(9, 9, 9), 1);
EXPECT_EQ(reverse(10, 10, 10), 0);
EXPECT_EQ(reverse(-1, -1, 5), 5);
EXPECT_EQ(reverse(0, 0, 9999), -1);
EXPECT_EQ(reverse(0, 0, -9999), -1);
}
}
// Tests that the input beam is forwarded.
TEST(SequenceBackendTest, BeamForwarding) {
SequenceBackend backend;
const TransitionState *parent_state = nullptr;
parent_state += 1234; // arbitrary non-null pointer
const std::vector<std::vector<const TransitionState *>> parent_states = {
{parent_state}};
const int ignored_max_beam_size = 999;
InputBatchCache *ignored_input = nullptr;
backend.InitializeData(parent_states, ignored_max_beam_size, ignored_input);
EXPECT_EQ(backend.GetBeam(), parent_states);
}
// Tests the accessors of the backend.
TEST(SequenceBackendTest, Accessors) {
SequenceBackend backend;
ComponentSpec spec;
spec.set_name("foo");
backend.InitializeComponent(spec);
EXPECT_EQ(backend.Name(), "foo");
EXPECT_EQ(backend.BeamSize(), 1);
EXPECT_EQ(backend.BatchSize(), 1);
EXPECT_TRUE(backend.IsReady());
EXPECT_TRUE(backend.IsTerminal());
}
// Tests the trivial mutators of the backend.
TEST(SequenceBackendTest, Mutators) {
SequenceBackend backend;
// These are NOPs and should not crash.
backend.FinalizeData();
backend.ResetComponent();
backend.InitializeTracing();
backend.DisableTracing();
}
// Tests the beam index accessors of the backend.
TEST(SequenceBackendTest, BeamIndex) {
SequenceBackend backend;
// This always returns the current_index (first arg).
EXPECT_EQ(backend.GetSourceBeamIndex(0, 0), 0);
EXPECT_EQ(backend.GetSourceBeamIndex(1, 2), 1);
EXPECT_EQ(backend.GetSourceBeamIndex(-1, -1), -1);
EXPECT_EQ(backend.GetSourceBeamIndex(10, 99), 10);
// This always returns 0.
EXPECT_EQ(backend.GetBeamIndexAtStep(0, 0, 0), 0);
EXPECT_EQ(backend.GetBeamIndexAtStep(1, 2, 3), 0);
EXPECT_EQ(backend.GetBeamIndexAtStep(-1, -1, -1), 0);
EXPECT_EQ(backend.GetBeamIndexAtStep(123, 456, 789), 0);
}
// Tests the that the backend produces a single empty trace.
TEST(SequenceBackendTest, Tracing) {
SequenceBackend backend;
const ComponentTrace empty_trace;
const auto actual_traces = backend.GetTraceProtos();
ASSERT_EQ(actual_traces.size(), 1);
ASSERT_EQ(actual_traces[0].size(), 1);
EXPECT_THAT(actual_traces[0][0], test::EqualsProto(empty_trace));
}
// Tests the unsupported methods of the backend.
TEST(SequenceBackendTest, UnsupportedMethods) {
SequenceBackend backend;
EXPECT_DEATH(backend.StepsTaken(0), "Not supported");
EXPECT_DEATH(backend.AdvanceFromPrediction(nullptr, 0, 0), "Not supported");
EXPECT_DEATH(backend.AdvanceFromOracle(), "Not supported");
EXPECT_DEATH(backend.GetOracleLabels(), "Not supported");
EXPECT_DEATH(backend.GetFixedFeatures(nullptr, nullptr, nullptr, 0),
"Not supported");
EXPECT_DEATH(backend.BulkGetFixedFeatures(
BulkFeatureExtractor(nullptr, nullptr, nullptr)),
"Not supported");
EXPECT_DEATH(backend.BulkEmbedFixedFeatures(0, 0, 0, {}, nullptr),
"Not supported");
EXPECT_DEATH(backend.GetRawLinkFeatures(0), "Not supported");
EXPECT_DEATH(backend.AddTranslatedLinkFeaturesToTrace({}, 0),
"Not supported");
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <string.h>
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/bulk_network_unit.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/fixed_embeddings.h"
#include "dragnn/runtime/linked_embeddings.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/sequence_model.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Sequence-based bulk version of DynamicComponent.
class SequenceBulkDynamicComponent : public Component {
public:
// Implements Component.
tensorflow::Status Initialize(const ComponentSpec &component_spec,
VariableStore *variable_store,
NetworkStateManager *network_state_manager,
ExtensionManager *extension_manager) override;
tensorflow::Status Evaluate(SessionState *session_state,
ComputeSession *compute_session,
ComponentTrace *component_trace) const override;
bool Supports(const ComponentSpec &component_spec,
const string &normalized_builder_name) const override;
bool PreferredTo(const Component &other) const override { return false; }
private:
// Evaluates all input features in the |state|, concatenates them into a
// matrix of inputs in the |network_states|, and returns the matrix.
Matrix<float> EvaluateInputs(const SequenceModel::EvaluateState &state,
const NetworkStates &network_states) const;
// Managers for input embeddings.
FixedEmbeddingManager fixed_embedding_manager_;
LinkedEmbeddingManager linked_embedding_manager_;
// Sequence-based model evaluator.
SequenceModel sequence_model_;
// Network unit for bulk inference.
std::unique_ptr<BulkNetworkUnit> bulk_network_unit_;
// Concatenated input matrix.
LocalMatrixHandle<float> inputs_handle_;
// Intermediate values used by sequence models.
SharedExtensionHandle<SequenceModel::EvaluateState> evaluate_state_handle_;
};
bool SequenceBulkDynamicComponent::Supports(
const ComponentSpec &component_spec,
const string &normalized_builder_name) const {
// Require embedded fixed features.
for (const FixedFeatureChannel &channel : component_spec.fixed_feature()) {
if (channel.embedding_dim() < 0) return false;
}
// Require non-transformed and non-recurrent linked features.
// TODO(googleuser): Make SequenceLinks support transformed linked features?
for (const LinkedFeatureChannel &channel : component_spec.linked_feature()) {
if (channel.embedding_dim() >= 0) return false;
if (channel.source_component() == component_spec.name()) return false;
}
return normalized_builder_name == "SequenceBulkDynamicComponent" &&
SequenceModel::Supports(component_spec);
}
// Returns the sum of the dimensions of all channels in the |manager|.
template <class EmbeddingManager>
size_t SumEmbeddingDimensions(const EmbeddingManager &manager) {
size_t sum = 0;
for (size_t i = 0; i < manager.num_channels(); ++i) {
sum += manager.embedding_dim(i);
}
return sum;
}
tensorflow::Status SequenceBulkDynamicComponent::Initialize(
const ComponentSpec &component_spec, VariableStore *variable_store,
NetworkStateManager *network_state_manager,
ExtensionManager *extension_manager) {
TF_RETURN_IF_ERROR(BulkNetworkUnit::CreateOrError(
BulkNetworkUnit::GetClassName(component_spec), &bulk_network_unit_));
TF_RETURN_IF_ERROR(
bulk_network_unit_->Initialize(component_spec, variable_store,
network_state_manager, extension_manager));
TF_RETURN_IF_ERROR(fixed_embedding_manager_.Reset(
component_spec, variable_store, network_state_manager));
TF_RETURN_IF_ERROR(linked_embedding_manager_.Reset(
component_spec, variable_store, network_state_manager));
const size_t concatenated_input_dim =
SumEmbeddingDimensions(fixed_embedding_manager_) +
SumEmbeddingDimensions(linked_embedding_manager_);
TF_RETURN_IF_ERROR(
bulk_network_unit_->ValidateInputDimension(concatenated_input_dim));
TF_RETURN_IF_ERROR(
network_state_manager->AddLocal(concatenated_input_dim, &inputs_handle_));
TF_RETURN_IF_ERROR(sequence_model_.Initialize(
component_spec, bulk_network_unit_->GetLogitsName(),
&fixed_embedding_manager_, &linked_embedding_manager_,
network_state_manager));
extension_manager->GetShared(&evaluate_state_handle_);
return tensorflow::Status::OK();
}
tensorflow::Status SequenceBulkDynamicComponent::Evaluate(
SessionState *session_state, ComputeSession *compute_session,
ComponentTrace *component_trace) const {
const NetworkStates &network_states = session_state->network_states;
SequenceModel::EvaluateState &state =
session_state->extensions.Get(evaluate_state_handle_);
TF_RETURN_IF_ERROR(
sequence_model_.Preprocess(session_state, compute_session, &state));
const Matrix<float> inputs = EvaluateInputs(state, network_states);
TF_RETURN_IF_ERROR(bulk_network_unit_->Evaluate(inputs, session_state));
return sequence_model_.Predict(network_states, &state);
}
Matrix<float> SequenceBulkDynamicComponent::EvaluateInputs(
const SequenceModel::EvaluateState &state,
const NetworkStates &network_states) const {
const MutableMatrix<float> inputs = network_states.GetLocal(inputs_handle_);
// Declared here for reuse in the loop below.
bool is_out_of_bounds = false;
Vector<float> embedding;
// Handle forward and reverse iteration via a start index and increment.
int target_index = sequence_model_.left_to_right() ? 0 : state.num_steps - 1;
const int target_increment = sequence_model_.left_to_right() ? 1 : -1;
for (size_t step_index = 0; step_index < state.num_steps;
++step_index, target_index += target_increment) {
const MutableVector<float> row = inputs.row(step_index);
float *output = row.data();
for (size_t channel_id = 0; channel_id < state.features.num_channels();
++channel_id) {
embedding = state.features.GetEmbedding(channel_id, target_index);
memcpy(output, embedding.data(), embedding.size() * sizeof(float));
output += embedding.size();
}
for (size_t channel_id = 0; channel_id < state.links.num_channels();
++channel_id) {
state.links.Get(channel_id, target_index, &embedding, &is_out_of_bounds);
memcpy(output, embedding.data(), embedding.size() * sizeof(float));
output += embedding.size();
}
DCHECK_EQ(output, row.end());
}
return Matrix<float>(inputs);
}
DRAGNN_RUNTIME_REGISTER_COMPONENT(SequenceBulkDynamicComponent);
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/bulk_network_unit.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/sequence_backend.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "dragnn/runtime/sequence_linker.h"
#include "dragnn/runtime/sequence_predictor.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include <gmock/gmock.h>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
using ::testing::Return;
constexpr size_t kNumSteps = 50;
constexpr size_t kFixedDim = 11;
constexpr size_t kFixedVocabularySize = 123;
constexpr float kFixedValue = 0.5;
constexpr size_t kLinkedDim = 13;
constexpr float kLinkedValue = 1.25;
constexpr char kPreviousComponentName[] = "previous_component";
constexpr char kPreviousLayerName[] = "previous_layer";
constexpr char kLogitsName[] = "logits";
constexpr size_t kLogitsDim = kFixedDim + kLinkedDim;
// Adds one to all inputs.
class BulkAddOne : public BulkNetworkUnit {
public:
// Implements BulkNetworkUnit.
tensorflow::Status Initialize(const ComponentSpec &component_spec,
VariableStore *variable_store,
NetworkStateManager *network_state_manager,
ExtensionManager *extension_manager) override {
return network_state_manager->AddLayer(kLogitsName, kLogitsDim,
&logits_handle_);
}
tensorflow::Status ValidateInputDimension(size_t dimension) const override {
return tensorflow::Status::OK();
}
string GetLogitsName() const override { return kLogitsName; }
tensorflow::Status Evaluate(Matrix<float> inputs,
SessionState *session_state) const override {
const MutableMatrix<float> logits =
session_state->network_states.GetLayer(logits_handle_);
for (size_t row = 0; row < inputs.num_rows(); ++row) {
for (size_t column = 0; column < inputs.num_columns(); ++column) {
logits.row(row)[column] = inputs.row(row)[column] + 1.0;
}
}
return tensorflow::Status::OK();
}
private:
// Output logits.
LayerHandle<float> logits_handle_;
};
DRAGNN_RUNTIME_REGISTER_BULK_NETWORK_UNIT(BulkAddOne);
// A component that also prefers other but is triggered on the presence of a
// resource. This can be used to cause a component selection conflict.
class ImTheWorst : public Component {
public:
// Implements Component.
tensorflow::Status Initialize(const ComponentSpec &component_spec,
VariableStore *variable_store,
NetworkStateManager *network_state_manager,
ExtensionManager *extension_manager) override {
return tensorflow::Status::OK();
}
tensorflow::Status Evaluate(SessionState *session_state,
ComputeSession *compute_session,
ComponentTrace *component_trace) const override {
return tensorflow::Status::OK();
}
bool Supports(const ComponentSpec &component_spec,
const string &normalized_builder_name) const override {
return component_spec.resource_size() > 0;
}
bool PreferredTo(const Component &other) const override { return false; }
};
DRAGNN_RUNTIME_REGISTER_COMPONENT(ImTheWorst);
// Extractor that produces a sequence of zeros.
class ExtractZeros : public SequenceExtractor {
public:
// Implements SequenceExtractor.
bool Supports(const FixedFeatureChannel &,
const ComponentSpec &) const override {
return true;
}
tensorflow::Status Initialize(const FixedFeatureChannel &,
const ComponentSpec &) override {
return tensorflow::Status::OK();
}
tensorflow::Status GetIds(InputBatchCache *,
std::vector<int32> *ids) const override {
ids->assign(kNumSteps, 0);
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR(ExtractZeros);
// Linker that produces a sequence of zeros.
class LinkZeros : public SequenceLinker {
public:
// Implements SequenceLinker.
bool Supports(const LinkedFeatureChannel &,
const ComponentSpec &) const override {
return true;
}
tensorflow::Status Initialize(const LinkedFeatureChannel &,
const ComponentSpec &) override {
return tensorflow::Status::OK();
}
tensorflow::Status GetLinks(size_t, InputBatchCache *,
std::vector<int32> *links) const override {
links->assign(kNumSteps, 0);
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(LinkZeros);
// Predictor that captures the logits.
class CaptureLogits : public SequencePredictor {
public:
// Implements SequencePredictor.
bool Supports(const ComponentSpec &) const override { return true; }
tensorflow::Status Initialize(const ComponentSpec &) override {
return tensorflow::Status::OK();
}
tensorflow::Status Predict(Matrix<float> logits,
InputBatchCache *) const override {
logits_ = logits;
return tensorflow::Status::OK();
}
// Returns the captured logits.
static Matrix<float> GetCapturedLogits() { return logits_; }
private:
// Logits from the most recent call to Predict().
static Matrix<float> logits_;
};
Matrix<float> CaptureLogits::logits_;
DRAGNN_RUNTIME_REGISTER_SEQUENCE_PREDICTOR(CaptureLogits);
class SequenceBulkDynamicComponentTest : public NetworkTestBase {
protected:
SequenceBulkDynamicComponentTest() {
EXPECT_CALL(compute_session_, GetInputBatchCache())
.WillRepeatedly(Return(&input_));
EXPECT_CALL(compute_session_, GetReadiedComponent(kTestComponentName))
.WillRepeatedly(Return(&backend_));
}
// Returns a spec that the network supports.
ComponentSpec GetSupportedSpec() {
ComponentSpec component_spec;
component_spec.set_name(kTestComponentName);
component_spec.set_num_actions(kLogitsDim);
component_spec.mutable_network_unit()->set_registered_name("AddOne");
component_spec.mutable_backend()->set_registered_name("SequenceBackend");
component_spec.mutable_component_builder()->set_registered_name(
"SequenceBulkDynamicComponent");
auto &component_parameters =
*component_spec.mutable_component_builder()->mutable_parameters();
component_parameters["sequence_extractors"] = "ExtractZeros";
component_parameters["sequence_linkers"] = "LinkZeros";
component_parameters["sequence_predictor"] = "CaptureLogits";
FixedFeatureChannel *fixed_feature = component_spec.add_fixed_feature();
fixed_feature->set_size(1);
fixed_feature->set_embedding_dim(kFixedDim);
fixed_feature->set_vocabulary_size(kFixedVocabularySize);
LinkedFeatureChannel *linked_feature = component_spec.add_linked_feature();
linked_feature->set_size(1);
linked_feature->set_embedding_dim(-1);
linked_feature->set_source_component(kPreviousComponentName);
linked_feature->set_source_layer(kPreviousLayerName);
return component_spec;
}
// Creates a network unit, initializes it based on the |component_spec_text|,
// and evaluates it. On error, returns non-OK.
tensorflow::Status Run(const ComponentSpec &component_spec) {
AddComponent(kPreviousComponentName);
AddLayer(kPreviousLayerName, kLinkedDim);
AddComponent(kTestComponentName);
AddFixedEmbeddingMatrix(0, kFixedVocabularySize, kFixedDim, kFixedValue);
std::unique_ptr<Component> component;
TF_RETURN_IF_ERROR(
Component::CreateOrError("SequenceBulkDynamicComponent", &component));
TF_RETURN_IF_ERROR(component->Initialize(component_spec, &variable_store_,
&network_state_manager_,
&extension_manager_));
// Allocates network states for a few steps.
network_states_.Reset(&network_state_manager_);
StartComponent(kNumSteps);
FillLayer(kPreviousComponentName, kPreviousLayerName, kLinkedValue);
StartComponent(0);
session_state_.extensions.Reset(&extension_manager_);
return component->Evaluate(&session_state_, &compute_session_, nullptr);
}
// Input batch injected into Evaluate() by default.
InputBatchCache input_;
// Backend injected into Evaluate().
SequenceBackend backend_;
};
// Tests that the supported spec is supported.
TEST_F(SequenceBulkDynamicComponentTest, Supported) {
const ComponentSpec component_spec = GetSupportedSpec();
string component_type;
TF_ASSERT_OK(Component::Select(component_spec, &component_type));
EXPECT_EQ(component_type, "SequenceBulkDynamicComponent");
TF_ASSERT_OK(Run(component_spec));
const Matrix<float> logits = CaptureLogits::GetCapturedLogits();
ASSERT_EQ(logits.num_rows(), kNumSteps);
ASSERT_EQ(logits.num_columns(), kFixedDim + kLinkedDim);
for (size_t row = 0; row < kNumSteps; ++row) {
size_t column = 0;
for (; column < kFixedDim; ++column) {
EXPECT_EQ(logits.row(row)[column], kFixedValue + 1.0);
}
for (; column < kFixedDim + kLinkedDim; ++column) {
EXPECT_EQ(logits.row(row)[column], kLinkedValue + 1.0);
}
}
}
// Tests that links cannot be recurrent.
TEST_F(SequenceBulkDynamicComponentTest, ForbidRecurrences) {
ComponentSpec component_spec = GetSupportedSpec();
component_spec.mutable_linked_feature(0)->set_source_component(
kTestComponentName);
string component_type;
EXPECT_THAT(
Component::Select(component_spec, &component_type),
test::IsErrorWithSubstr("Could not find a best spec for component"));
}
// Tests that the component prefers others.
TEST_F(SequenceBulkDynamicComponentTest, PrefersOthers) {
ComponentSpec component_spec = GetSupportedSpec();
component_spec.add_resource();
// Adding a resource triggers the ImTheWorst component, which also prefers
// itself and leads to a selection conflict.
string component_type;
EXPECT_THAT(
Component::Select(component_spec, &component_type),
test::IsErrorWithSubstr("both think they should be dis-preferred"));
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <string>
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component_transformation.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "dragnn/runtime/sequence_linker.h"
#include "dragnn/runtime/sequence_predictor.h"
#include "dragnn/runtime/transition_system_traits.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Returns true if the |component_spec| has recurrent links.
bool IsRecurrent(const ComponentSpec &component_spec) {
for (const LinkedFeatureChannel &channel : component_spec.linked_feature()) {
if (channel.source_component() == component_spec.name()) return true;
}
return false;
}
// Returns the sequence-based version of the |component_type| with specification
// |component_spec|, or an empty string if there is no sequence-based version.
string GetSequenceComponentType(const string &component_type,
const ComponentSpec &component_spec) {
// TODO(googleuser): Implement a SequenceDynamicComponent that can handle
// recurrent links. This may require changes to the NetworkUnit API.
static const char *kSupportedComponentTypes[] = {
"BulkDynamicComponent", //
"BulkLstmComponent", //
"MyelinDynamicComponent", //
};
for (const char *supported_type : kSupportedComponentTypes) {
if (component_type == supported_type) {
return tensorflow::strings::StrCat("Sequence", supported_type);
}
}
// Also support non-recurrent DynamicComponents. The BulkDynamicComponent
// requires determinism, but the SequenceBulkDynamicComponent does not, so
// it's not sufficient to only upgrade from BulkDynamicComponent.
if (component_type == "DynamicComponent" && !IsRecurrent(component_spec)) {
return "SequenceBulkDynamicComponent";
}
return string();
}
// Returns the |status| but coerces NOT_FOUND to OK. Sets |found| to false iff
// the |status| was NOT_FOUND.
tensorflow::Status AllowNotFound(const tensorflow::Status &status,
bool *found) {
*found = status.code() != tensorflow::error::NOT_FOUND;
return *found ? status : tensorflow::Status::OK();
}
// Transformer that checks whether a sequence-based component implementation
// could be used and, if compatible, modifies the ComponentSpec accordingly.
class SequenceComponentTransformer : public ComponentTransformer {
public:
// Implements ComponentTransformer.
tensorflow::Status Transform(const string &component_type,
ComponentSpec *component_spec) override;
};
tensorflow::Status SequenceComponentTransformer::Transform(
const string &component_type, ComponentSpec *component_spec) {
const int num_features = component_spec->fixed_feature_size() +
component_spec->linked_feature_size();
if (num_features == 0) return tensorflow::Status::OK();
// Look for supporting SequenceExtractors.
bool found = false;
string extractor_types;
for (const FixedFeatureChannel &channel : component_spec->fixed_feature()) {
string type;
TF_RETURN_IF_ERROR(AllowNotFound(
SequenceExtractor::Select(channel, *component_spec, &type), &found));
if (!found) return tensorflow::Status::OK();
tensorflow::strings::StrAppend(&extractor_types, type, ",");
}
if (!extractor_types.empty()) extractor_types.pop_back(); // remove comma
// Look for supporting SequenceLinkers.
string linker_types;
for (const LinkedFeatureChannel &channel : component_spec->linked_feature()) {
string type;
TF_RETURN_IF_ERROR(AllowNotFound(
SequenceLinker::Select(channel, *component_spec, &type), &found));
if (!found) return tensorflow::Status::OK();
tensorflow::strings::StrAppend(&linker_types, type, ",");
}
if (!linker_types.empty()) linker_types.pop_back(); // remove comma
// Look for a supporting SequencePredictor, if predictions are necessary.
string predictor_type;
if (!TransitionSystemTraits(*component_spec).is_deterministic) {
TF_RETURN_IF_ERROR(AllowNotFound(
SequencePredictor::Select(*component_spec, &predictor_type), &found));
if (!found) return tensorflow::Status::OK();
}
// Look for a supporting sequence-based component type.
const string sequence_component_type =
GetSequenceComponentType(component_type, *component_spec);
if (sequence_component_type.empty()) return tensorflow::Status::OK();
// Success; make modifications.
component_spec->mutable_backend()->set_registered_name("SequenceBackend");
RegisteredModuleSpec *builder = component_spec->mutable_component_builder();
builder->set_registered_name(sequence_component_type);
(*builder->mutable_parameters())["sequence_extractors"] = extractor_types;
(*builder->mutable_parameters())["sequence_linkers"] = linker_types;
(*builder->mutable_parameters())["sequence_predictor"] = predictor_type;
return tensorflow::Status::OK();
}
DRAGNN_RUNTIME_REGISTER_COMPONENT_TRANSFORMER(SequenceComponentTransformer);
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <memory>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component_transformation.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "dragnn/runtime/sequence_linker.h"
#include "dragnn/runtime/sequence_predictor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Arbitrary supported component type.
constexpr char kSupportedComponentType[] = "MyelinDynamicComponent";
// Sequence-based version of the component type.
constexpr char kTransformedComponentType[] = "SequenceMyelinDynamicComponent";
// Trivial extractor that supports components named "supported".
class SupportIfNamedSupportedExtractor : public SequenceExtractor {
public:
// Implements SequenceExtractor.
bool Supports(const FixedFeatureChannel &,
const ComponentSpec &component_spec) const override {
return component_spec.name() == "supported";
}
tensorflow::Status Initialize(const FixedFeatureChannel &,
const ComponentSpec &) override {
return tensorflow::Status::OK();
}
tensorflow::Status GetIds(InputBatchCache *,
std::vector<int32> *) const override {
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR(SupportIfNamedSupportedExtractor);
// Trivial extractor that supports components if they have a resource. This is
// used to generate a "multiple supported extractors" conflict.
class SupportIfHasResourcesExtractor : public SequenceExtractor {
public:
// Implements SequenceExtractor.
bool Supports(const FixedFeatureChannel &,
const ComponentSpec &component_spec) const override {
return component_spec.resource_size() > 0;
}
tensorflow::Status Initialize(const FixedFeatureChannel &,
const ComponentSpec &) override {
return tensorflow::Status::OK();
}
tensorflow::Status GetIds(InputBatchCache *,
std::vector<int32> *) const override {
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR(SupportIfHasResourcesExtractor);
// Trivial linker that supports components named "supported".
class SupportIfNamedSupportedLinker : public SequenceLinker {
public:
// Implements SequenceLinker.
bool Supports(const LinkedFeatureChannel &,
const ComponentSpec &component_spec) const override {
return component_spec.name() == "supported";
}
tensorflow::Status Initialize(const LinkedFeatureChannel &,
const ComponentSpec &) override {
return tensorflow::Status::OK();
}
tensorflow::Status GetLinks(size_t, InputBatchCache *,
std::vector<int32> *) const override {
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(SupportIfNamedSupportedLinker);
// Trivial predictor that supports components named "supported".
class SupportIfNamedSupportedPredictor : public SequencePredictor {
public:
// Implements SequencePredictor.
bool Supports(const ComponentSpec &component_spec) const override {
return component_spec.name() == "supported";
}
tensorflow::Status Initialize(const ComponentSpec &) override {
return tensorflow::Status::OK();
}
tensorflow::Status Predict(Matrix<float>, InputBatchCache *) const override {
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_PREDICTOR(SupportIfNamedSupportedPredictor);
// Returns a ComponentSpec that is supported by the transformer.
ComponentSpec MakeSupportedSpec() {
ComponentSpec component_spec;
component_spec.set_name("supported");
component_spec.set_num_actions(10);
component_spec.add_fixed_feature();
component_spec.add_fixed_feature();
component_spec.add_linked_feature();
component_spec.add_linked_feature();
component_spec.mutable_component_builder()->set_registered_name(
kSupportedComponentType);
return component_spec;
}
// Tests that a compatible spec is modified to use a new backend and component
// builder with SequenceExtractors, SequenceLinkers, and SequencePredictor.
TEST(SequenceComponentTransformerTest, Compatible) {
ComponentSpec component_spec = MakeSupportedSpec();
ComponentSpec modified_spec = component_spec;
modified_spec.mutable_backend()->set_registered_name("SequenceBackend");
modified_spec.mutable_component_builder()->set_registered_name(
kTransformedComponentType);
modified_spec.mutable_component_builder()->mutable_parameters()->insert(
{"sequence_extractors",
"SupportIfNamedSupportedExtractor,SupportIfNamedSupportedExtractor"});
modified_spec.mutable_component_builder()->mutable_parameters()->insert(
{"sequence_linkers",
"SupportIfNamedSupportedLinker,SupportIfNamedSupportedLinker"});
modified_spec.mutable_component_builder()->mutable_parameters()->insert(
{"sequence_predictor", "SupportIfNamedSupportedPredictor"});
TF_ASSERT_OK(ComponentTransformer::ApplyAll(&component_spec));
EXPECT_THAT(component_spec, test::EqualsProto(modified_spec));
}
// Tests that a compatible deterministic spec is modified to use a new backend
// and component builder with SequenceExtractors and SequenceLinkers only.
TEST(SequenceComponentTransformerTest, CompatibleNoPredictor) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.set_num_actions(1);
ComponentSpec modified_spec = component_spec;
modified_spec.mutable_backend()->set_registered_name("SequenceBackend");
modified_spec.mutable_component_builder()->set_registered_name(
kTransformedComponentType);
modified_spec.mutable_component_builder()->mutable_parameters()->insert(
{"sequence_extractors",
"SupportIfNamedSupportedExtractor,SupportIfNamedSupportedExtractor"});
modified_spec.mutable_component_builder()->mutable_parameters()->insert(
{"sequence_linkers",
"SupportIfNamedSupportedLinker,SupportIfNamedSupportedLinker"});
modified_spec.mutable_component_builder()->mutable_parameters()->insert(
{"sequence_predictor", ""});
TF_ASSERT_OK(ComponentTransformer::ApplyAll(&component_spec));
EXPECT_THAT(component_spec, test::EqualsProto(modified_spec));
}
// Tests that a ComponentSpec with no features is incompatible.
TEST(SequenceComponentTransformerTest, IncompatibleNoFeatures) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.clear_fixed_feature();
component_spec.clear_linked_feature();
const ComponentSpec unchanged_spec = component_spec;
TF_ASSERT_OK(ComponentTransformer::ApplyAll(&component_spec));
EXPECT_THAT(component_spec, test::EqualsProto(unchanged_spec));
}
// Tests that a ComponentSpec with the wrong component builder is incompatible.
TEST(SequenceComponentTransformerTest, IncompatibleComponentBuilder) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.mutable_component_builder()->set_registered_name("bad");
const ComponentSpec unchanged_spec = component_spec;
TF_ASSERT_OK(ComponentTransformer::ApplyAll(&component_spec));
EXPECT_THAT(component_spec, test::EqualsProto(unchanged_spec));
}
// Tests that a ComponentSpec is incompatible if it is not supported by any
// SequenceExtractor.
TEST(SequenceComponentTransformerTest,
IncompatibleNoSupportingSequenceExtractor) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.set_name("bad");
const ComponentSpec unchanged_spec = component_spec;
TF_ASSERT_OK(ComponentTransformer::ApplyAll(&component_spec));
EXPECT_THAT(component_spec, test::EqualsProto(unchanged_spec));
}
// Tests that a ComponentSpec fails if multiple SequenceExtractors support it.
TEST(SequenceComponentTransformerTest,
FailIfMultipleSupportingSequenceExtractors) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.add_resource(); // triggers SupportIfHasResourcesExtractor
EXPECT_THAT(
ComponentTransformer::ApplyAll(&component_spec),
test::IsErrorWithSubstr("Multiple SequenceExtractors support channel"));
}
// Tests that a DynamicComponent is not upgraded if it is recurrent.
TEST(SequenceComponentTransformerTest, RecurrentDynamicComponent) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.mutable_component_builder()->set_registered_name(
"DynamicComponent");
component_spec.mutable_linked_feature(0)->set_source_component(
component_spec.name());
const ComponentSpec unchanged_spec = component_spec;
TF_ASSERT_OK(ComponentTransformer::ApplyAll(&component_spec));
EXPECT_THAT(component_spec, test::EqualsProto(unchanged_spec));
}
// Tests that a DynamicComponent is upgraded to SequenceBulkDynamicComponent if
// it is non-recurrent.
TEST(SequenceComponentTransformerTest, NonRecurrentDynamicComponent) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.mutable_component_builder()->set_registered_name(
"DynamicComponent");
ComponentSpec modified_spec = component_spec;
modified_spec.mutable_backend()->set_registered_name("SequenceBackend");
modified_spec.mutable_component_builder()->set_registered_name(
"SequenceBulkDynamicComponent");
modified_spec.mutable_component_builder()->mutable_parameters()->insert(
{"sequence_extractors",
"SupportIfNamedSupportedExtractor,SupportIfNamedSupportedExtractor"});
modified_spec.mutable_component_builder()->mutable_parameters()->insert(
{"sequence_linkers",
"SupportIfNamedSupportedLinker,SupportIfNamedSupportedLinker"});
modified_spec.mutable_component_builder()->mutable_parameters()->insert(
{"sequence_predictor", "SupportIfNamedSupportedPredictor"});
TF_ASSERT_OK(ComponentTransformer::ApplyAll(&component_spec));
EXPECT_THAT(component_spec, test::EqualsProto(modified_spec));
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_extractor.h"
#include "tensorflow/core/lib/core/errors.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
tensorflow::Status SequenceExtractor::Select(
const FixedFeatureChannel &channel, const ComponentSpec &component_spec,
string *name) {
string supporting_name;
for (const Registry::Registrar *registrar = registry()->components;
registrar != nullptr; registrar = registrar->next()) {
Factory *factory_function = registrar->object();
std::unique_ptr<SequenceExtractor> current_extractor(factory_function());
if (!current_extractor->Supports(channel, component_spec)) continue;
if (!supporting_name.empty()) {
return tensorflow::errors::Internal(
"Multiple SequenceExtractors support channel ",
channel.ShortDebugString(), " of ComponentSpec (", supporting_name,
" and ", registrar->name(), "): ", component_spec.ShortDebugString());
}
supporting_name = registrar->name();
}
if (supporting_name.empty()) {
return tensorflow::errors::NotFound(
"No SequenceExtractor supports channel ", channel.ShortDebugString(),
" of ComponentSpec: ", component_spec.ShortDebugString());
}
// Success; make modifications.
*name = supporting_name;
return tensorflow::Status::OK();
}
tensorflow::Status SequenceExtractor::New(
const string &name, const FixedFeatureChannel &channel,
const ComponentSpec &component_spec,
std::unique_ptr<SequenceExtractor> *extractor) {
std::unique_ptr<SequenceExtractor> matching_extractor;
TF_RETURN_IF_ERROR(
SequenceExtractor::CreateOrError(name, &matching_extractor));
TF_RETURN_IF_ERROR(matching_extractor->Initialize(channel, component_spec));
// Success; make modifications.
*extractor = std::move(matching_extractor);
return tensorflow::Status::OK();
}
} // namespace runtime
} // namespace dragnn
REGISTER_SYNTAXNET_CLASS_REGISTRY("DRAGNN Runtime Sequence Extractor",
dragnn::runtime::SequenceExtractor);
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_SEQUENCE_EXTRACTOR_H_
#define DRAGNN_RUNTIME_SEQUENCE_EXTRACTOR_H_
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/protos/spec.pb.h"
#include "syntaxnet/base.h"
#include "syntaxnet/registry.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Interface for feature extraction for sequence inputs.
//
// This extractor can be used to avoid ComputeSession overhead in simple cases;
// for example, extracting a sequence of character or word IDs for an LSTM.
class SequenceExtractor : public RegisterableClass<SequenceExtractor> {
public:
// Sets |extractor| to an instance of the subclass named |name| initialized
// from the |channel| of the |component_spec|. On error, returns non-OK and
// modifies nothing.
static tensorflow::Status New(const string &name,
const FixedFeatureChannel &channel,
const ComponentSpec &component_spec,
std::unique_ptr<SequenceExtractor> *extractor);
SequenceExtractor(const SequenceExtractor &) = delete;
SequenceExtractor &operator=(const SequenceExtractor &) = delete;
virtual ~SequenceExtractor() = default;
// Sets |name| to the registered name of the SequenceExtractor that supports
// the |channel| of the |component_spec|. On error, returns non-OK and
// modifies nothing. The returned statuses include:
// * OK: If a supporting SequenceExtractor was found.
// * INTERNAL: If an error occurred while searching for a compatible match.
// * NOT_FOUND: If the search was error-free, but no compatible match was
// found.
static tensorflow::Status Select(const FixedFeatureChannel &channel,
const ComponentSpec &component_spec,
string *name);
// Overwrites |ids| with the sequence of features extracted from the |input|.
// On error, returns non-OK.
virtual tensorflow::Status GetIds(InputBatchCache *input,
std::vector<int32> *ids) const = 0;
protected:
SequenceExtractor() = default;
private:
// Helps prevent use of the Create() method; use New() instead.
using RegisterableClass<SequenceExtractor>::Create;
// Returns true if this supports the |channel| of the |component_spec|.
// Implementations must coordinate to ensure that at most one supports any
// given |component_spec|.
virtual bool Supports(const FixedFeatureChannel &channel,
const ComponentSpec &component_spec) const = 0;
// Initializes this from the |channel| of the |component_spec|. On error,
// returns non-OK.
virtual tensorflow::Status Initialize(
const FixedFeatureChannel &channel,
const ComponentSpec &component_spec) = 0;
};
} // namespace runtime
} // namespace dragnn
DECLARE_SYNTAXNET_CLASS_REGISTRY("DRAGNN Runtime Sequence Extractor",
dragnn::runtime::SequenceExtractor);
} // namespace syntaxnet
#define DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR(subclass) \
REGISTER_SYNTAXNET_CLASS_COMPONENT( \
::syntaxnet::dragnn::runtime::SequenceExtractor, #subclass, subclass)
#endif // DRAGNN_RUNTIME_SEQUENCE_EXTRACTOR_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_extractor.h"
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Supports components named "success" and initializes successfully.
class Success : public SequenceExtractor {
public:
// Implements SequenceExtractor.
bool Supports(const FixedFeatureChannel &,
const ComponentSpec &component_spec) const override {
return component_spec.name() == "success";
}
tensorflow::Status Initialize(const FixedFeatureChannel &,
const ComponentSpec &) override {
return tensorflow::Status::OK();
}
tensorflow::Status GetIds(InputBatchCache *,
std::vector<int32> *) const override {
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR(Success);
// Supports components named "failure" and fails to initialize.
class Failure : public SequenceExtractor {
public:
// Implements SequenceExtractor.
bool Supports(const FixedFeatureChannel &,
const ComponentSpec &component_spec) const override {
return component_spec.name() == "failure";
}
tensorflow::Status Initialize(const FixedFeatureChannel &,
const ComponentSpec &) override {
return tensorflow::errors::Internal("Boom!");
}
tensorflow::Status GetIds(InputBatchCache *,
std::vector<int32> *) const override {
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR(Failure);
// Supports components named "duplicate" and initializes successfully.
class Duplicate : public SequenceExtractor {
public:
// Implements SequenceExtractor.
bool Supports(const FixedFeatureChannel &,
const ComponentSpec &component_spec) const override {
return component_spec.name() == "duplicate";
}
tensorflow::Status Initialize(const FixedFeatureChannel &,
const ComponentSpec &) override {
return tensorflow::Status::OK();
}
tensorflow::Status GetIds(InputBatchCache *,
std::vector<int32> *) const override {
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR(Duplicate);
// Duplicate of the above.
using Duplicate2 = Duplicate;
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR(Duplicate2);
// Tests that a component can be successfully created.
TEST(SequenceExtractorTest, Success) {
string name;
std::unique_ptr<SequenceExtractor> extractor;
ComponentSpec component_spec;
component_spec.set_name("success");
TF_ASSERT_OK(SequenceExtractor::Select({}, component_spec, &name));
ASSERT_EQ(name, "Success");
TF_EXPECT_OK(SequenceExtractor::New(name, {}, component_spec, &extractor));
EXPECT_NE(extractor, nullptr);
}
// Tests that errors in Initialize() are reported.
TEST(SequenceExtractorTest, FailToInitialize) {
string name;
std::unique_ptr<SequenceExtractor> extractor;
ComponentSpec component_spec;
component_spec.set_name("failure");
TF_ASSERT_OK(SequenceExtractor::Select({}, component_spec, &name));
EXPECT_EQ(name, "Failure");
EXPECT_THAT(SequenceExtractor::New(name, {}, component_spec, &extractor),
test::IsErrorWithSubstr("Boom!"));
EXPECT_EQ(extractor, nullptr);
}
// Tests that unsupported specs are reported as NOT_FOUND errors.
TEST(SequenceExtractorTest, UnsupportedSpec) {
string name = "not overwritten";
ComponentSpec component_spec;
component_spec.set_name("unsupported");
EXPECT_THAT(SequenceExtractor::Select({}, component_spec, &name),
test::IsErrorWithCodeAndSubstr(
tensorflow::error::NOT_FOUND,
"No SequenceExtractor supports channel"));
EXPECT_EQ(name, "not overwritten");
}
// Tests that unsupported subclass names are reported as errors.
TEST(SequenceExtractorTest, UnsupportedSubclass) {
std::unique_ptr<SequenceExtractor> extractor;
ComponentSpec component_spec;
EXPECT_THAT(
SequenceExtractor::New("Unsupported", {}, component_spec, &extractor),
test::IsErrorWithSubstr("Unknown DRAGNN Runtime Sequence Extractor"));
EXPECT_EQ(extractor, nullptr);
}
// Tests that multiple supporting extractors are reported as INTERNAL errors.
TEST(SequenceExtractorTest, Duplicate) {
string name = "not overwritten";
ComponentSpec component_spec;
component_spec.set_name("duplicate");
EXPECT_THAT(SequenceExtractor::Select({}, component_spec, &name),
test::IsErrorWithCodeAndSubstr(
tensorflow::error::INTERNAL,
"Multiple SequenceExtractors support channel"));
EXPECT_EQ(name, "not overwritten");
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_features.h"
#include <utility>
#include "tensorflow/core/lib/core/errors.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
tensorflow::Status SequenceFeatureManager::Reset(
const FixedEmbeddingManager *fixed_embedding_manager,
const ComponentSpec &component_spec,
const std::vector<string> &sequence_extractor_types) {
const size_t num_channels = fixed_embedding_manager->channel_configs_.size();
if (component_spec.fixed_feature_size() != num_channels) {
return tensorflow::errors::InvalidArgument(
"Channel mismatch between FixedEmbeddingManager (", num_channels,
") and ComponentSpec (", component_spec.fixed_feature_size(), ")");
}
if (sequence_extractor_types.size() != num_channels) {
return tensorflow::errors::InvalidArgument(
"Channel mismatch between FixedEmbeddingManager (", num_channels,
") and SequenceExtractors (", sequence_extractor_types.size(), ")");
}
for (const FixedFeatureChannel &channel : component_spec.fixed_feature()) {
if (channel.size() > 1) {
return tensorflow::errors::InvalidArgument(
"Multi-embedding fixed features are not supported for channel: ",
channel.ShortDebugString());
}
}
std::vector<ChannelConfig> local_configs; // avoid modification on error
for (size_t channel_id = 0; channel_id < num_channels; ++channel_id) {
local_configs.emplace_back();
ChannelConfig &channel_config = local_configs.back();
const FixedEmbeddingManager::ChannelConfig &wrapped_config =
fixed_embedding_manager->channel_configs_[channel_id];
channel_config.is_embedded = wrapped_config.is_embedded;
channel_config.embedding_matrix = wrapped_config.embedding_matrix;
TF_RETURN_IF_ERROR(
SequenceExtractor::New(sequence_extractor_types[channel_id],
component_spec.fixed_feature(channel_id),
component_spec, &channel_config.extractor));
}
// Success; make modifications.
zeros_ = fixed_embedding_manager->zeros_.view();
channel_configs_ = std::move(local_configs);
return tensorflow::Status::OK();
}
tensorflow::Status SequenceFeatures::Reset(
const SequenceFeatureManager *manager, InputBatchCache *input) {
manager_ = manager;
zeros_ = manager->zeros_;
num_channels_ = manager->channel_configs_.size();
num_steps_ = 0;
// Make sure |channels_| is big enough. Note that |channels_| never shrinks,
// so the Channel.ids sub-vector is never deallocated.
if (num_channels_ > channels_.size()) channels_.resize(num_channels_);
for (int channel_id = 0; channel_id < num_channels_; ++channel_id) {
Channel &channel = channels_[channel_id];
const SequenceFeatureManager::ChannelConfig &channel_config =
manager->channel_configs_[channel_id];
channel.embedding_matrix = channel_config.embedding_matrix;
TF_RETURN_IF_ERROR(channel_config.extractor->GetIds(input, &channel.ids));
if (channel_id == 0) {
num_steps_ = channel.ids.size();
} else if (channel.ids.size() != num_steps_) {
return tensorflow::errors::FailedPrecondition(
"Inconsistent feature sequence lengths at channel ID ", channel_id,
": got ", channel.ids.size(), " but expected ", num_steps_);
}
}
return tensorflow::Status::OK();
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Utils for configuring and extracting fixed embeddings for sequence-based
// models. Analogous to FixedEmbeddingManager and FixedEmbeddings, but uses
// SequenceExtractor instead of ComputeSession.
#ifndef DRAGNN_RUNTIME_SEQUENCE_FEATURES_H_
#define DRAGNN_RUNTIME_SEQUENCE_FEATURES_H_
#include <stddef.h>
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/fixed_embeddings.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Manager for fixed embeddings for sequence-based models. This is a wrapper
// around the FixedEmbeddingManager.
class SequenceFeatureManager {
public:
// Creates an empty manager.
SequenceFeatureManager() = default;
// Resets this to wrap the |fixed_embedding_manager|, which must outlive this.
// The |sequence_extractor_types| should name one SequenceExtractor subclass
// per channel; e.g., "SyntaxNetCharacterSequenceExtractor". This initializes
// each SequenceExtractor from the |component_spec|. On error, returns non-OK
// and does not modify this.
tensorflow::Status Reset(
const FixedEmbeddingManager *fixed_embedding_manager,
const ComponentSpec &component_spec,
const std::vector<string> &sequence_extractor_types);
// Accessors.
size_t num_channels() const { return channel_configs_.size(); }
private:
friend class SequenceFeatures;
// Configuration for a single fixed embedding channel.
struct ChannelConfig {
// Whether this channel is embedded.
bool is_embedded = true;
// Embedding matrix of this channel. Only used if |is_embedded| is true.
Matrix<float> embedding_matrix;
// Extractor for sequences of feature IDs.
std::unique_ptr<SequenceExtractor> extractor;
};
// Array of zeros that can be substituted for missing feature IDs. This is a
// reference to the corresponding array in the FixedEmbeddingManager.
AlignedView zeros_;
// Ordered list of configurations for each channel.
std::vector<ChannelConfig> channel_configs_;
};
// A set of fixed embeddings for a sequence-based model. Configured by a
// SequenceFeatureManager.
class SequenceFeatures {
public:
// Creates an empty set of embeddings.
SequenceFeatures() = default;
// Resets this to the sequences of fixed features managed by the |manager| on
// the |input|. The |manager| must live until this is destroyed or Reset(),
// and should not be modified during that time. On error, returns non-OK.
tensorflow::Status Reset(const SequenceFeatureManager *manager,
InputBatchCache *input);
// Returns the feature ID or embedding for the |target_index|'th element of
// the |channel_id|'th channel. Each method is only valid for a non-embedded
// or embedded channel, respectively.
int32 GetId(size_t channel_id, size_t target_index) const;
Vector<float> GetEmbedding(size_t channel_id, size_t target_index) const;
// Accessors.
size_t num_channels() const { return num_channels_; }
size_t num_steps() const { return num_steps_; }
private:
// Data associated with a single fixed embedding channel.
struct Channel {
// Embedding matrix of this channel. Only used for embedded channels.
Matrix<float> embedding_matrix;
// Feature IDs for each step.
std::vector<int32> ids;
};
// Manager from the most recent Reset().
const SequenceFeatureManager *manager_ = nullptr;
// Zero vector from the most recent Reset().
AlignedView zeros_;
// Number of channels and steps from the most recent Reset().
size_t num_channels_ = 0;
size_t num_steps_ = 0;
// Ordered list of fixed embedding channels. This may contain more than
// |num_channels_| entries, to avoid deallocation/reallocation cycles, but
// only the first |num_channels_| entries are valid.
std::vector<Channel> channels_;
};
// Implementation details below.
inline int32 SequenceFeatures::GetId(size_t channel_id,
size_t target_index) const {
DCHECK_LT(channel_id, num_channels());
DCHECK_LT(target_index, num_steps());
DCHECK(!manager_->channel_configs_[channel_id].is_embedded);
const Channel &channel = channels_[channel_id];
return channel.ids[target_index];
}
inline Vector<float> SequenceFeatures::GetEmbedding(size_t channel_id,
size_t target_index) const {
DCHECK_LT(channel_id, num_channels());
DCHECK_LT(target_index, num_steps());
DCHECK(manager_->channel_configs_[channel_id].is_embedded);
const Channel &channel = channels_[channel_id];
const int32 id = channel.ids[target_index];
return id < 0 ? Vector<float>(zeros_, channel.embedding_matrix.num_columns())
: channel.embedding_matrix.row(id);
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_SEQUENCE_FEATURES_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_features.h"
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/fixed_embeddings.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Number of transition steps to take in each component in the network.
const size_t kNumSteps = 10;
// A working one-channel ComponentSpec. This is intentionally identical to the
// first channel of |kMultiSpec|, so they can use the same embedding matrix.
const char kSingleSpec[] = R"(fixed_feature {
vocabulary_size: 13
embedding_dim: 11
size: 1
})";
const size_t kSingleRows = 13;
const size_t kSingleColumns = 11;
constexpr float kSingleValue = 1.25;
// A working multi-channel ComponentSpec.
const char kMultiSpec[] = R"(fixed_feature {
vocabulary_size: 13
embedding_dim: 11
size: 1
}
fixed_feature {
embedding_dim: -1
size: 1
}
fixed_feature {
embedding_dim: -1
size: 1
})";
// Fails to initialize.
class FailToInitialize : public SequenceExtractor {
public:
// Implements SequenceExtractor.
bool Supports(const FixedFeatureChannel &,
const ComponentSpec &component_spec) const override {
LOG(FATAL) << "Should never be called.";
}
tensorflow::Status Initialize(const FixedFeatureChannel &,
const ComponentSpec &) override {
return tensorflow::errors::Internal("No initialization for you!");
}
tensorflow::Status GetIds(InputBatchCache *,
std::vector<int32> *) const override {
LOG(FATAL) << "Should never be called.";
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR(FailToInitialize);
// Initializes OK, then fails to extract features.
class FailToGetIds : public FailToInitialize {
public:
// Implements SequenceExtractor.
tensorflow::Status Initialize(const FixedFeatureChannel &,
const ComponentSpec &) override {
return tensorflow::Status::OK();
}
tensorflow::Status GetIds(InputBatchCache *,
std::vector<int32> *) const override {
return tensorflow::errors::Internal("No features for you!");
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR(FailToGetIds);
// Initializes OK and extracts the previous step.
class ExtractPrevious : public FailToGetIds {
public:
// Implements SequenceExtractor.
tensorflow::Status GetIds(InputBatchCache *,
std::vector<int32> *ids) const override {
ids->resize(kNumSteps);
for (int i = 0; i < kNumSteps; ++i) (*ids)[i] = i - 1;
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR(ExtractPrevious);
// Initializes OK but produces the wrong number of features.
class WrongNumberOfIds : public FailToGetIds {
public:
// Implements SequenceExtractor.
tensorflow::Status GetIds(InputBatchCache *input,
std::vector<int32> *ids) const override {
ids->resize(kNumSteps + 1);
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR(WrongNumberOfIds);
class SequenceFeatureManagerTest : public NetworkTestBase {
protected:
// Creates a SequenceFeatureManager and returns the result of Reset()-ing it
// using the |component_spec_text|.
tensorflow::Status ResetManager(
const string &component_spec_text,
const std::vector<string> &sequence_extractor_types) {
ComponentSpec component_spec;
CHECK(TextFormat::ParseFromString(component_spec_text, &component_spec));
component_spec.set_name(kTestComponentName);
AddFixedEmbeddingMatrix(0, kSingleRows, kSingleColumns, kSingleValue);
AddComponent(kTestComponentName);
TF_RETURN_IF_ERROR(fixed_embedding_manager_.Reset(
component_spec, &variable_store_, &network_state_manager_));
return manager_.Reset(&fixed_embedding_manager_, component_spec,
sequence_extractor_types);
}
FixedEmbeddingManager fixed_embedding_manager_;
SequenceFeatureManager manager_;
};
// Tests that SequenceFeatureManager is empty by default.
TEST_F(SequenceFeatureManagerTest, EmptyByDefault) {
EXPECT_EQ(manager_.num_channels(), 0);
}
// Tests that SequenceFeatureManager is empty when reset to an empty spec.
TEST_F(SequenceFeatureManagerTest, EmptySpec) {
TF_EXPECT_OK(ResetManager("", {}));
EXPECT_EQ(manager_.num_channels(), 0);
}
// Tests that SequenceFeatureManager works with a single channel.
TEST_F(SequenceFeatureManagerTest, OneChannel) {
TF_EXPECT_OK(ResetManager(kSingleSpec, {"ExtractPrevious"}));
EXPECT_EQ(manager_.num_channels(), 1);
}
// Tests that SequenceFeatureManager works with multiple channels.
TEST_F(SequenceFeatureManagerTest, MultipleChannels) {
TF_EXPECT_OK(ResetManager(
kMultiSpec, {"ExtractPrevious", "ExtractPrevious", "ExtractPrevious"}));
EXPECT_EQ(manager_.num_channels(), 3);
}
// Tests that SequenceFeatureManager fails if the FixedEmbeddingManager and
// ComponentSpec are mismatched.
TEST_F(SequenceFeatureManagerTest, MismatchedFixedManagerAndComponentSpec) {
ComponentSpec component_spec;
CHECK(TextFormat::ParseFromString(kMultiSpec, &component_spec));
component_spec.set_name(kTestComponentName);
AddFixedEmbeddingMatrix(0, kSingleRows, kSingleColumns, kSingleValue);
AddComponent(kTestComponentName);
TF_ASSERT_OK(fixed_embedding_manager_.Reset(component_spec, &variable_store_,
&network_state_manager_));
// Remove one fixed feature, resulting in a mismatch.
component_spec.mutable_fixed_feature()->RemoveLast();
EXPECT_THAT(
manager_.Reset(&fixed_embedding_manager_, component_spec,
{"ExtractPrevious", "ExtractPrevious", "ExtractPrevious"}),
test::IsErrorWithSubstr("Channel mismatch between FixedEmbeddingManager "
"(3) and ComponentSpec (2)"));
}
// Tests that SequenceFeatureManager fails if the FixedEmbeddingManager and
// SequenceExtractors are mismatched.
TEST_F(SequenceFeatureManagerTest,
MismatchedFixedManagerAndSequenceExtractors) {
EXPECT_THAT(
ResetManager(kMultiSpec, {"ExtractPrevious", "ExtractPrevious"}),
test::IsErrorWithSubstr("Channel mismatch between FixedEmbeddingManager "
"(3) and SequenceExtractors (2)"));
}
// Tests that SequenceFeatureManager fails if a channel has multiple embeddings.
TEST_F(SequenceFeatureManagerTest, UnsupportedMultiEmbeddingChannel) {
const string kBadSpec = R"(fixed_feature {
vocabulary_size: 13
embedding_dim: 11
size: 2 # bad
})";
EXPECT_THAT(ResetManager(kBadSpec, {"ExtractPrevious"}),
test::IsErrorWithSubstr(
"Multi-embedding fixed features are not supported"));
}
// Tests that SequenceFeatureManager fails if one of the SequenceExtractors
// fails to initialize.
TEST_F(SequenceFeatureManagerTest, FailToInitializeSequenceExtractor) {
EXPECT_THAT(ResetManager(kMultiSpec, {"ExtractPrevious", "FailToInitialize",
"ExtractPrevious"}),
test::IsErrorWithSubstr("No initialization for you!"));
}
// Tests that SequenceFeatureManager is OK even if the SequenceExtractors would
// fail in GetIds().
TEST_F(SequenceFeatureManagerTest, ManagerDoesntCareAboutGetIds) {
TF_EXPECT_OK(ResetManager(
kMultiSpec, {"FailToGetIds", "FailToGetIds", "FailToGetIds"}));
}
class SequenceFeaturesTest : public SequenceFeatureManagerTest {
protected:
// Resets the |sequence_features_| on the |manager_| and |input_batch_cache_|
// and returns the resulting status.
tensorflow::Status ResetFeatures() {
return sequence_features_.Reset(&manager_, &input_batch_cache_);
}
InputBatchCache input_batch_cache_;
SequenceFeatures sequence_features_;
};
// Tests that SequenceFeatures is empty by default.
TEST_F(SequenceFeaturesTest, EmptyByDefault) {
EXPECT_EQ(sequence_features_.num_channels(), 0);
EXPECT_EQ(sequence_features_.num_steps(), 0);
}
// Tests that SequenceFeatures is empty when reset by an empty manager.
TEST_F(SequenceFeaturesTest, EmptyManager) {
TF_ASSERT_OK(ResetManager("", {}));
TF_EXPECT_OK(ResetFeatures());
EXPECT_EQ(sequence_features_.num_channels(), 0);
EXPECT_EQ(sequence_features_.num_steps(), 0);
}
// Tests that SequenceFeatures fails when one of the SequenceExtractors fails.
TEST_F(SequenceFeaturesTest, FailToGetIds) {
TF_ASSERT_OK(ResetManager(
kMultiSpec, {"ExtractPrevious", "ExtractPrevious", "FailToGetIds"}));
EXPECT_THAT(ResetFeatures(), test::IsErrorWithSubstr("No features for you!"));
}
// Tests that SequenceFeatures fails when the SequenceExtractors produce
// different numbers of features.
TEST_F(SequenceFeaturesTest, MismatchedNumbersOfFeatures) {
TF_ASSERT_OK(ResetManager(
kMultiSpec, {"ExtractPrevious", "ExtractPrevious", "WrongNumberOfIds"}));
EXPECT_THAT(ResetFeatures(), test::IsErrorWithSubstr(
"Inconsistent feature sequence lengths at "
"channel ID 2: got 11 but expected 10"));
}
// Tests that SequenceFeatures works as expected on one channel.
TEST_F(SequenceFeaturesTest, SingleChannel) {
TF_ASSERT_OK(ResetManager(kSingleSpec, {"ExtractPrevious"}));
TF_ASSERT_OK(ResetFeatures());
ASSERT_EQ(sequence_features_.num_channels(), 1);
ASSERT_EQ(sequence_features_.num_steps(), kNumSteps);
// ExtractPrevious extracts -1 for the 0'th target index, which indicates a
// missing ID and should be mapped to a zero vector.
ExpectVector(sequence_features_.GetEmbedding(0, 0), kSingleColumns, 0.0);
EXPECT_DEBUG_DEATH(sequence_features_.GetId(0, 0), "is_embedded");
// The remaining feature IDs map to valid embedding rows.
for (int i = 1; i < kNumSteps; ++i) {
ExpectVector(sequence_features_.GetEmbedding(0, i), kSingleColumns,
kSingleValue);
EXPECT_DEBUG_DEATH(sequence_features_.GetId(0, i), "is_embedded");
}
}
// Tests that SequenceFeatures works as expected on multiple channels.
TEST_F(SequenceFeaturesTest, ManyChannels) {
TF_ASSERT_OK(ResetManager(
kMultiSpec, {"ExtractPrevious", "ExtractPrevious", "ExtractPrevious"}));
TF_ASSERT_OK(ResetFeatures());
ASSERT_EQ(sequence_features_.num_channels(), 3);
ASSERT_EQ(sequence_features_.num_steps(), kNumSteps);
// ExtractPrevious extracts -1 for the 0'th target index, which indicates a
// missing ID and should be mapped to a zero vector.
ExpectVector(sequence_features_.GetEmbedding(0, 0), kSingleColumns, 0.0);
EXPECT_EQ(sequence_features_.GetId(1, 0), -1);
EXPECT_EQ(sequence_features_.GetId(2, 0), -1);
EXPECT_DEBUG_DEATH(sequence_features_.GetId(0, 0), "is_embedded");
EXPECT_DEBUG_DEATH(sequence_features_.GetEmbedding(1, 0), "is_embedded");
EXPECT_DEBUG_DEATH(sequence_features_.GetEmbedding(2, 0), "is_embedded");
// The remaining features point to the previous item.
for (int i = 1; i < kNumSteps; ++i) {
ExpectVector(sequence_features_.GetEmbedding(0, i), kSingleColumns,
kSingleValue);
EXPECT_EQ(sequence_features_.GetId(1, i), i - 1);
EXPECT_EQ(sequence_features_.GetId(2, i), i - 1);
EXPECT_DEBUG_DEATH(sequence_features_.GetId(0, i), "is_embedded");
EXPECT_DEBUG_DEATH(sequence_features_.GetEmbedding(1, i), "is_embedded");
EXPECT_DEBUG_DEATH(sequence_features_.GetEmbedding(2, i), "is_embedded");
}
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_linker.h"
#include "tensorflow/core/lib/core/errors.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
tensorflow::Status SequenceLinker::Select(const LinkedFeatureChannel &channel,
const ComponentSpec &component_spec,
string *name) {
string supporting_name;
for (const Registry::Registrar *registrar = registry()->components;
registrar != nullptr; registrar = registrar->next()) {
Factory *factory_function = registrar->object();
std::unique_ptr<SequenceLinker> current_linker(factory_function());
if (!current_linker->Supports(channel, component_spec)) continue;
if (!supporting_name.empty()) {
return tensorflow::errors::Internal(
"Multiple SequenceLinkers support channel ",
channel.ShortDebugString(), " of ComponentSpec (", supporting_name,
" and ", registrar->name(), "): ", component_spec.ShortDebugString());
}
supporting_name = registrar->name();
}
if (supporting_name.empty()) {
return tensorflow::errors::NotFound(
"No SequenceLinker supports channel ", channel.ShortDebugString(),
" of ComponentSpec: ", component_spec.ShortDebugString());
}
// Success; make modifications.
*name = supporting_name;
return tensorflow::Status::OK();
}
tensorflow::Status SequenceLinker::New(
const string &name, const LinkedFeatureChannel &channel,
const ComponentSpec &component_spec,
std::unique_ptr<SequenceLinker> *linker) {
std::unique_ptr<SequenceLinker> matching_linker;
TF_RETURN_IF_ERROR(SequenceLinker::CreateOrError(name, &matching_linker));
TF_RETURN_IF_ERROR(matching_linker->Initialize(channel, component_spec));
// Success; make modifications.
*linker = std::move(matching_linker);
return tensorflow::Status::OK();
}
} // namespace runtime
} // namespace dragnn
REGISTER_SYNTAXNET_CLASS_REGISTRY("DRAGNN Runtime Sequence Linker",
dragnn::runtime::SequenceLinker);
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_SEQUENCE_LINKER_H_
#define DRAGNN_RUNTIME_SEQUENCE_LINKER_H_
#include <stddef.h>
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/protos/spec.pb.h"
#include "syntaxnet/base.h"
#include "syntaxnet/registry.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Interface for link extraction for sequence inputs.
//
// This can be used to avoid ComputeSession overhead in simple cases; for
// example, extracting a sequence of identity or reverse-identity links.
class SequenceLinker : public RegisterableClass<SequenceLinker> {
public:
// Sets |linker| to an instance of the subclass named |name| initialized from
// the |channel| of the |component_spec|. On error, returns non-OK and
// modifies nothing.
static tensorflow::Status New(const string &name,
const LinkedFeatureChannel &channel,
const ComponentSpec &component_spec,
std::unique_ptr<SequenceLinker> *linker);
SequenceLinker(const SequenceLinker &) = delete;
SequenceLinker &operator=(const SequenceLinker &) = delete;
virtual ~SequenceLinker() = default;
// Sets |name| to the registered name of the SequenceLinker that supports the
// |channel| of the |component_spec|. On error, returns non-OK and modifies
// nothing. The returned statuses include:
// * OK: If a supporting SequenceLinker was found.
// * INTERNAL: If an error occurred while searching for a compatible match.
// * NOT_FOUND: If the search was error-free, but no compatible match was
// found.
static tensorflow::Status Select(const LinkedFeatureChannel &channel,
const ComponentSpec &component_spec,
string *name);
// Overwrites |links| with the sequence of translated link step indices for
// the |input|. Specifically, sets links[i] to the (possibly out-of-bounds)
// step index to fetch from the source component for the i'th element of the
// target sequence. Assumes that |source_num_steps| is the number of steps
// taken by the source component. On error, returns non-OK.
virtual tensorflow::Status GetLinks(size_t source_num_steps,
InputBatchCache *input,
std::vector<int32> *links) const = 0;
protected:
SequenceLinker() = default;
private:
// Helps prevent use of the Create() method; use New() instead.
using RegisterableClass<SequenceLinker>::Create;
// Returns true if this supports the |channel| of the |component_spec|.
// Implementations must coordinate to ensure that at most one supports any
// given |component_spec|.
virtual bool Supports(const LinkedFeatureChannel &channel,
const ComponentSpec &component_spec) const = 0;
// Initializes this from the |channel| of the |component_spec|. On error,
// returns non-OK.
virtual tensorflow::Status Initialize(
const LinkedFeatureChannel &channel,
const ComponentSpec &component_spec) = 0;
};
} // namespace runtime
} // namespace dragnn
DECLARE_SYNTAXNET_CLASS_REGISTRY("DRAGNN Runtime Sequence Linker",
dragnn::runtime::SequenceLinker);
} // namespace syntaxnet
#define DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(subclass) \
REGISTER_SYNTAXNET_CLASS_COMPONENT( \
::syntaxnet::dragnn::runtime::SequenceLinker, #subclass, subclass)
#endif // DRAGNN_RUNTIME_SEQUENCE_LINKER_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_linker.h"
#include <stddef.h>
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Supports components named "success" and initializes successfully.
class Success : public SequenceLinker {
public:
// Implements SequenceLinker.
bool Supports(const LinkedFeatureChannel &,
const ComponentSpec &component_spec) const override {
return component_spec.name() == "success";
}
tensorflow::Status Initialize(const LinkedFeatureChannel &,
const ComponentSpec &) override {
return tensorflow::Status::OK();
}
tensorflow::Status GetLinks(size_t, InputBatchCache *,
std::vector<int32> *) const override {
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(Success);
// Supports components named "failure" and fails to initialize.
class Failure : public SequenceLinker {
public:
// Implements SequenceLinker.
bool Supports(const LinkedFeatureChannel &,
const ComponentSpec &component_spec) const override {
return component_spec.name() == "failure";
}
tensorflow::Status Initialize(const LinkedFeatureChannel &,
const ComponentSpec &) override {
return tensorflow::errors::Internal("Boom!");
}
tensorflow::Status GetLinks(size_t, InputBatchCache *,
std::vector<int32> *) const override {
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(Failure);
// Supports components named "duplicate" and initializes successfully.
class Duplicate : public SequenceLinker {
public:
// Implements SequenceLinker.
bool Supports(const LinkedFeatureChannel &,
const ComponentSpec &component_spec) const override {
return component_spec.name() == "duplicate";
}
tensorflow::Status Initialize(const LinkedFeatureChannel &,
const ComponentSpec &) override {
return tensorflow::Status::OK();
}
tensorflow::Status GetLinks(size_t, InputBatchCache *,
std::vector<int32> *) const override {
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(Duplicate);
// Duplicate of the above.
using Duplicate2 = Duplicate;
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(Duplicate2);
// Tests that a component can be successfully created.
TEST(SequenceLinkerTest, Success) {
string name;
std::unique_ptr<SequenceLinker> linker;
ComponentSpec component_spec;
component_spec.set_name("success");
TF_ASSERT_OK(SequenceLinker::Select({}, component_spec, &name));
ASSERT_EQ(name, "Success");
TF_EXPECT_OK(SequenceLinker::New(name, {}, component_spec, &linker));
EXPECT_NE(linker, nullptr);
}
// Tests that errors in Initialize() are reported.
TEST(SequenceLinkerTest, FailToInitialize) {
string name;
std::unique_ptr<SequenceLinker> linker;
ComponentSpec component_spec;
component_spec.set_name("failure");
TF_ASSERT_OK(SequenceLinker::Select({}, component_spec, &name));
EXPECT_EQ(name, "Failure");
EXPECT_THAT(SequenceLinker::New(name, {}, component_spec, &linker),
test::IsErrorWithSubstr("Boom!"));
EXPECT_EQ(linker, nullptr);
}
// Tests that unsupported specs are reported as NOT_FOUND errors.
TEST(SequenceLinkerTest, UnsupportedSpec) {
string name = "not overwritten";
ComponentSpec component_spec;
component_spec.set_name("unsupported");
EXPECT_THAT(
SequenceLinker::Select({}, component_spec, &name),
test::IsErrorWithCodeAndSubstr(tensorflow::error::NOT_FOUND,
"No SequenceLinker supports channel"));
EXPECT_EQ(name, "not overwritten");
}
// Tests that unsupported subclass names are reported as errors.
TEST(SequenceLinkerTest, UnsupportedSubclass) {
std::unique_ptr<SequenceLinker> linker;
ComponentSpec component_spec;
EXPECT_THAT(
SequenceLinker::New("Unsupported", {}, component_spec, &linker),
test::IsErrorWithSubstr("Unknown DRAGNN Runtime Sequence Linker"));
EXPECT_EQ(linker, nullptr);
}
// Tests that multiple supporting linkers are reported as INTERNAL errors.
TEST(SequenceLinkerTest, Duplicate) {
string name = "not overwritten";
ComponentSpec component_spec;
component_spec.set_name("duplicate");
EXPECT_THAT(SequenceLinker::Select({}, component_spec, &name),
test::IsErrorWithCodeAndSubstr(
tensorflow::error::INTERNAL,
"Multiple SequenceLinkers support channel"));
EXPECT_EQ(name, "not overwritten");
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_links.h"
#include <utility>
#include "tensorflow/core/lib/core/errors.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
tensorflow::Status SequenceLinkManager::Reset(
const LinkedEmbeddingManager *linked_embedding_manager,
const ComponentSpec &component_spec,
const std::vector<string> &sequence_linker_types) {
const size_t num_channels = linked_embedding_manager->channel_configs_.size();
if (component_spec.linked_feature_size() != num_channels) {
return tensorflow::errors::InvalidArgument(
"Channel mismatch between LinkedEmbeddingManager (", num_channels,
") and ComponentSpec (", component_spec.linked_feature_size(), ")");
}
if (sequence_linker_types.size() != num_channels) {
return tensorflow::errors::InvalidArgument(
"Channel mismatch between LinkedEmbeddingManager (", num_channels,
") and SequenceLinkers (", sequence_linker_types.size(), ")");
}
for (const LinkedFeatureChannel &channel : component_spec.linked_feature()) {
if (channel.embedding_dim() >= 0) {
return tensorflow::errors::Unimplemented(
"Transformed linked features are not supported for channel: ",
channel.ShortDebugString());
}
}
std::vector<ChannelConfig> local_configs; // avoid modification on error
for (size_t channel_id = 0; channel_id < num_channels; ++channel_id) {
const LinkedFeatureChannel &channel =
component_spec.linked_feature(channel_id);
local_configs.emplace_back();
ChannelConfig &channel_config = local_configs.back();
channel_config.is_recurrent =
channel.source_component() == component_spec.name();
channel_config.handle =
linked_embedding_manager->channel_configs_[channel_id].source_handle;
TF_RETURN_IF_ERROR(
SequenceLinker::New(sequence_linker_types[channel_id],
component_spec.linked_feature(channel_id),
component_spec, &channel_config.linker));
}
// Success; make modifications.
zeros_ = linked_embedding_manager->zeros_.view();
channel_configs_ = std::move(local_configs);
return tensorflow::Status::OK();
}
tensorflow::Status SequenceLinks::Reset(bool add_steps,
const SequenceLinkManager *manager,
NetworkStates *network_states,
InputBatchCache *input) {
zeros_ = manager->zeros_;
num_channels_ = manager->channel_configs_.size();
num_steps_ = 0;
bool have_num_steps = false; // true if |num_steps_| was assigned
// Make sure |channels_| is big enough. Note that |channels_| never shrinks,
// so the Channel.links sub-vector is never deallocated.
if (num_channels_ > channels_.size()) channels_.resize(num_channels_);
// Process non-recurrent links first.
for (int channel_id = 0; channel_id < num_channels_; ++channel_id) {
const SequenceLinkManager::ChannelConfig &channel_config =
manager->channel_configs_[channel_id];
if (channel_config.is_recurrent) continue;
Channel &channel = channels_[channel_id];
channel.layer = network_states->GetLayer(channel_config.handle);
TF_RETURN_IF_ERROR(channel_config.linker->GetLinks(channel.layer.num_rows(),
input, &channel.links));
if (!have_num_steps) {
num_steps_ = channel.links.size();
have_num_steps = true;
} else if (channel.links.size() != num_steps_) {
return tensorflow::errors::FailedPrecondition(
"Inconsistent link sequence lengths at channel ID ", channel_id,
": got ", channel.links.size(), " but expected ", num_steps_);
}
}
// Add steps to the |network_states|, if requested.
if (add_steps) {
if (!have_num_steps) {
return tensorflow::errors::FailedPrecondition(
"Cannot infer the number of steps to add because there are no "
"non-recurrent links");
}
network_states->AddSteps(num_steps_);
}
// Process recurrent links. These require that the current component in the
// |network_states| has been sized to the proper number of steps.
for (int channel_id = 0; channel_id < num_channels_; ++channel_id) {
const SequenceLinkManager::ChannelConfig &channel_config =
manager->channel_configs_[channel_id];
if (!channel_config.is_recurrent) continue;
Channel &channel = channels_[channel_id];
channel.layer = network_states->GetLayer(channel_config.handle);
TF_RETURN_IF_ERROR(channel_config.linker->GetLinks(channel.layer.num_rows(),
input, &channel.links));
if (!have_num_steps) {
num_steps_ = channel.links.size();
have_num_steps = true;
} else if (channel.links.size() != num_steps_) {
return tensorflow::errors::FailedPrecondition(
"Inconsistent link sequence lengths at channel ID ", channel_id,
": got ", channel.links.size(), " but expected ", num_steps_);
}
}
return tensorflow::Status::OK();
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Utils for configuring and extracting linked embeddings for sequence-based
// models. Analogous to LinkedEmbeddingManager and LinkedEmbeddings, but uses
// SequenceLinker instead of ComputeSession.
#ifndef DRAGNN_RUNTIME_SEQUENCE_LINKS_H_
#define DRAGNN_RUNTIME_SEQUENCE_LINKS_H_
#include <stddef.h>
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/linked_embeddings.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/sequence_linker.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Manager for linked embeddings for sequence-based models. This is a wrapper
// around the LinkedEmbeddingManager.
class SequenceLinkManager {
public:
// Creates an empty manager.
SequenceLinkManager() = default;
// Resets this to wrap the |linked_embedding_manager|, which must outlive
// this. The |sequence_linker_types| should name one SequenceLinker subclass
// per channel; e.g., {"IdentitySequenceLinker", "ReversedSequenceLinker"}.
// This initializes each SequenceLinker from the |component_spec|. On error,
// returns non-OK and does not modify this.
tensorflow::Status Reset(
const LinkedEmbeddingManager *linked_embedding_manager,
const ComponentSpec &component_spec,
const std::vector<string> &sequence_linker_types);
// Accessors.
size_t num_channels() const { return channel_configs_.size(); }
private:
friend class SequenceLinks;
// Configuration for a single linked embedding channel.
struct ChannelConfig {
// Whether this link is recurrent.
bool is_recurrent = false;
// Handle to the source layer in the relevant NetworkStates.
LayerHandle<float> handle;
// Extractor for sequences of translated link indices.
std::unique_ptr<SequenceLinker> linker;
};
// Array of zeros that can be substituted for out-of-bounds embeddings. This
// is a reference to the corresponding array in the LinkedEmbeddingManager.
// See the large comment in linked_embeddings.cc for reference.
AlignedView zeros_;
// Ordered list of configurations for each channel.
std::vector<ChannelConfig> channel_configs_;
};
// A set of linked embeddings for a sequence-based model. Configured by a
// SequenceLinkManager.
class SequenceLinks {
public:
// Creates an empty set of embeddings.
SequenceLinks() = default;
// Resets this to the sequences of linked embeddings managed by the |manager|
// on the |input|. Retrieves layers from the |network_states|. The |manager|
// must live until this is destroyed or Reset(), and should not be modified
// during that time. If |add_steps| is true, then infers the number of steps
// from the non-recurrent links and adds steps to the |network_states| before
// processing the recurrent links. On error, returns non-OK.
//
// NB: Recurrent links are tricky, because the |network_states| must be filled
// with steps before processing recurrent links. There are two approaches:
// 1. Add steps to the |network_states| before calling Reset(). This only
// works if the component also has fixed features, which can be used to
// infer the number of steps.
// 2. Set |add_steps| to true, so steps are added during Reset(). This only
// works if the component also has non-recurrent links, which can be used
// to infer the number of steps.
// If a component only has recurrent links then neither of the above works,
// but such a component would be nonsensical: it recurses on itself with no
// external input.
tensorflow::Status Reset(bool add_steps, const SequenceLinkManager *manager,
NetworkStates *network_states,
InputBatchCache *input);
// Retrieves the linked embedding for the |target_index|'th element of the
// |channel_id|'th channel. Sets |embedding| to the linked embedding vector
// and sets |is_out_of_bounds| to true if the link is out of bounds.
void Get(size_t channel_id, size_t target_index, Vector<float> *embedding,
bool *is_out_of_bounds) const;
// Accessors.
size_t num_channels() const { return num_channels_; }
size_t num_steps() const { return num_steps_; }
private:
// Data associated with a single linked embedding channel.
struct Channel {
// Source layer activations.
Matrix<float> layer;
// Translated link indices for each step.
std::vector<int32> links;
};
// Zero vector from the most recent Reset().
AlignedView zeros_;
// Number of channels and steps from the most recent Reset().
size_t num_channels_ = 0;
size_t num_steps_ = 0;
// Ordered list of linked embedding channels. This may contain more than
// |num_channels_| entries, to avoid deallocation/reallocation cycles, but
// only the first |num_channels_| entries are valid.
std::vector<Channel> channels_;
};
// Implementation details below.
inline void SequenceLinks::Get(size_t channel_id, size_t target_index,
Vector<float> *embedding,
bool *is_out_of_bounds) const {
DCHECK_LT(channel_id, num_channels());
DCHECK_LT(target_index, num_steps());
const Channel &channel = channels_[channel_id];
const int32 link = channel.links[target_index];
*is_out_of_bounds = (link < 0 || link >= channel.layer.num_rows());
if (*is_out_of_bounds) {
*embedding = Vector<float>(zeros_, channel.layer.num_columns());
} else {
*embedding = channel.layer.row(link);
}
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_SEQUENCE_LINKS_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_links.h"
#include <stddef.h>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/linked_embeddings.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/sequence_linker.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Dimensions of the layers in the network (see ResetManager() below).
const size_t kPrevious1LayerDim = 16;
const size_t kPrevious2LayerDim = 32;
const size_t kRecurrentLayerDim = 48;
// Number of transition steps to take in each component in the network.
const size_t kNumSteps = 10;
// A working one-channel ComponentSpec.
const char kSingleSpec[] = R"(linked_feature {
embedding_dim: -1
source_component: 'source_component_1'
source_layer: 'previous_1'
size: 1
})";
// A working multi-channel ComponentSpec.
const char kMultiSpec[] = R"(linked_feature {
embedding_dim: -1
source_component: 'source_component_1'
source_layer: 'previous_1'
size: 1
}
linked_feature {
embedding_dim: -1
source_component: 'source_component_2'
source_layer: 'previous_2'
size: 1
}
linked_feature {
embedding_dim: -1
source_component: 'test_component'
source_layer: 'recurrent'
size: 1
})";
// A recurrent-only ComponentSpec.
const char kRecurrentSpec[] = R"(linked_feature {
embedding_dim: -1
source_component: 'test_component'
source_layer: 'recurrent'
size: 1
})";
// Fails to initialize.
class FailToInitialize : public SequenceLinker {
public:
// Implements SequenceLinker.
bool Supports(const LinkedFeatureChannel &,
const ComponentSpec &component_spec) const override {
LOG(FATAL) << "Should never be called.";
}
tensorflow::Status Initialize(const LinkedFeatureChannel &,
const ComponentSpec &) override {
return tensorflow::errors::Internal("No initialization for you!");
}
tensorflow::Status GetLinks(size_t, InputBatchCache *,
std::vector<int32> *) const override {
LOG(FATAL) << "Should never be called.";
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(FailToInitialize);
// Initializes OK, then fails to extract links.
class FailToGetLinks : public FailToInitialize {
public:
// Implements SequenceLinker.
tensorflow::Status Initialize(const LinkedFeatureChannel &,
const ComponentSpec &) override {
return tensorflow::Status::OK();
}
tensorflow::Status GetLinks(size_t, InputBatchCache *,
std::vector<int32> *) const override {
return tensorflow::errors::Internal("No links for you!");
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(FailToGetLinks);
// Initializes OK and links to the previous step.
class LinkToPrevious : public FailToGetLinks {
public:
// Implements SequenceLinker.
tensorflow::Status GetLinks(size_t source_num_steps, InputBatchCache *,
std::vector<int32> *links) const override {
links->resize(source_num_steps);
for (int i = 0; i < links->size(); ++i) (*links)[i] = i - 1;
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(LinkToPrevious);
// Initializes OK but produces the wrong number of links.
class WrongNumberOfLinks : public FailToGetLinks {
public:
// Implements SequenceLinker.
tensorflow::Status GetLinks(size_t, InputBatchCache *,
std::vector<int32> *links) const override {
links->resize(kNumSteps + 1);
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(WrongNumberOfLinks);
class SequenceLinkManagerTest : public NetworkTestBase {
protected:
// Sets up previous components and layers.
void AddComponentsAndLayers() {
AddComponent("source_component_0");
AddComponent("source_component_1");
AddLayer("previous_1", kPrevious1LayerDim);
AddComponent("source_component_2");
AddLayer("previous_2", kPrevious2LayerDim);
AddComponent(kTestComponentName);
AddLayer("recurrent", kRecurrentLayerDim);
}
// Creates a SequenceLinkManager and returns the result of Reset()-ing it
// using the |component_spec_text|.
tensorflow::Status ResetManager(
const string &component_spec_text,
const std::vector<string> &sequence_linker_types) {
ComponentSpec component_spec;
CHECK(TextFormat::ParseFromString(component_spec_text, &component_spec));
component_spec.set_name(kTestComponentName);
AddComponentsAndLayers();
TF_RETURN_IF_ERROR(linked_embedding_manager_.Reset(
component_spec, &variable_store_, &network_state_manager_));
return manager_.Reset(&linked_embedding_manager_, component_spec,
sequence_linker_types);
}
LinkedEmbeddingManager linked_embedding_manager_;
SequenceLinkManager manager_;
};
// Tests that SequenceLinkManager is empty by default.
TEST_F(SequenceLinkManagerTest, EmptyByDefault) {
EXPECT_EQ(manager_.num_channels(), 0);
}
// Tests that SequenceLinkManager is empty when reset to an empty spec.
TEST_F(SequenceLinkManagerTest, EmptySpec) {
TF_EXPECT_OK(ResetManager("", {}));
EXPECT_EQ(manager_.num_channels(), 0);
}
// Tests that SequenceLinkManager works with a single channel.
TEST_F(SequenceLinkManagerTest, OneChannel) {
TF_EXPECT_OK(ResetManager(kSingleSpec, {"LinkToPrevious"}));
EXPECT_EQ(manager_.num_channels(), 1);
}
// Tests that SequenceLinkManager works with multiple channels.
TEST_F(SequenceLinkManagerTest, MultipleChannels) {
TF_EXPECT_OK(ResetManager(
kMultiSpec, {"LinkToPrevious", "LinkToPrevious", "LinkToPrevious"}));
EXPECT_EQ(manager_.num_channels(), 3);
}
// Tests that SequenceLinkManager fails if the LinkedEmbeddingManager and
// ComponentSpec are mismatched.
TEST_F(SequenceLinkManagerTest, MismatchedLinkedManagerAndComponentSpec) {
ComponentSpec component_spec;
CHECK(TextFormat::ParseFromString(kMultiSpec, &component_spec));
component_spec.set_name(kTestComponentName);
AddComponentsAndLayers();
TF_ASSERT_OK(linked_embedding_manager_.Reset(component_spec, &variable_store_,
&network_state_manager_));
// Remove one linked feature, resulting in a mismatch.
component_spec.mutable_linked_feature()->RemoveLast();
EXPECT_THAT(
manager_.Reset(&linked_embedding_manager_, component_spec,
{"LinkToPrevious", "LinkToPrevious", "LinkToPrevious"}),
test::IsErrorWithSubstr("Channel mismatch between LinkedEmbeddingManager "
"(3) and ComponentSpec (2)"));
}
// Tests that SequenceLinkManager fails if the LinkedEmbeddingManager and
// SequenceLinkers are mismatched.
TEST_F(SequenceLinkManagerTest, MismatchedLinkedManagerAndSequenceLinkers) {
EXPECT_THAT(
ResetManager(kMultiSpec, {"LinkToPrevious", "LinkToPrevious"}),
test::IsErrorWithSubstr("Channel mismatch between LinkedEmbeddingManager "
"(3) and SequenceLinkers (2)"));
}
// Tests that SequenceLinkManager fails when the link is transformed.
TEST_F(SequenceLinkManagerTest, UnsupportedTransformedLink) {
const string kBadSpec = R"(linked_feature {
embedding_dim: 16 # bad
source_component: 'source_component_1'
source_layer: 'previous_1'
size: 1
})";
AddLinkedWeightMatrix(0, kPrevious1LayerDim, 16, 0.0);
AddLinkedOutOfBoundsVector(0, 16, 0.0);
EXPECT_THAT(
ResetManager(kBadSpec, {"LinkToPrevious"}),
test::IsErrorWithSubstr("Transformed linked features are not supported"));
}
// Tests that SequenceLinkManager fails if one of the SequenceLinkers fails to
// initialize.
TEST_F(SequenceLinkManagerTest, FailToInitializeSequenceLinker) {
EXPECT_THAT(ResetManager(kMultiSpec, {"LinkToPrevious", "FailToInitialize",
"LinkToPrevious"}),
test::IsErrorWithSubstr("No initialization for you!"));
}
// Tests that SequenceLinkManager is OK even if the SequenceLinkers would fail
// in GetLinks().
TEST_F(SequenceLinkManagerTest, ManagerDoesntCareAboutGetLinks) {
TF_EXPECT_OK(ResetManager(
kMultiSpec, {"FailToGetLinks", "FailToGetLinks", "FailToGetLinks"}));
}
// Values to fill each layer with.
const float kPrevious1LayerValue = 1.0;
const float kPrevious2LayerValue = 2.0;
const float kRecurrentLayerValue = 3.0;
class SequenceLinksTest : public SequenceLinkManagerTest {
protected:
// Resets the |sequence_links_| using the |manager_|, |network_states_|, and
// |input_batch_cache_|, and returns the resulting status. Passes |add_steps|
// to Reset() and advances the current component by |num_steps|.
tensorflow::Status ResetLinks(bool add_steps = false,
size_t num_steps = kNumSteps) {
network_states_.Reset(&network_state_manager_);
// Fill components with steps.
StartComponent(kNumSteps); // source_component_0
StartComponent(kNumSteps); // source_component_1
StartComponent(kNumSteps); // source_component_2
StartComponent(num_steps); // current component
// Fill layers with values.
FillLayer("source_component_1", "previous_1", kPrevious1LayerValue);
FillLayer("source_component_2", "previous_2", kPrevious2LayerValue);
FillLayer(kTestComponentName, "recurrent", kRecurrentLayerValue);
return sequence_links_.Reset(add_steps, &manager_, &network_states_,
&input_batch_cache_);
}
InputBatchCache input_batch_cache_;
SequenceLinks sequence_links_;
};
// Tests that SequenceLinks is empty by default.
TEST_F(SequenceLinksTest, EmptyByDefault) {
EXPECT_EQ(sequence_links_.num_channels(), 0);
EXPECT_EQ(sequence_links_.num_steps(), 0);
}
// Tests that SequenceLinks is empty when reset by an empty manager.
TEST_F(SequenceLinksTest, EmptyManager) {
TF_ASSERT_OK(ResetManager("", {}));
TF_EXPECT_OK(ResetLinks());
EXPECT_EQ(sequence_links_.num_channels(), 0);
EXPECT_EQ(sequence_links_.num_steps(), 0);
}
// Tests that SequenceLinks fails when one of the non-recurrent SequenceLinkers
// fails.
TEST_F(SequenceLinksTest, FailToGetNonRecurrentLinks) {
TF_ASSERT_OK(ResetManager(
kMultiSpec, {"LinkToPrevious", "FailToGetLinks", "LinkToPrevious"}));
EXPECT_THAT(ResetLinks(), test::IsErrorWithSubstr("No links for you!"));
}
// Tests that SequenceLinks fails when one of the recurrent SequenceLinkers
// fails.
TEST_F(SequenceLinksTest, FailToGetRecurrentLinks) {
TF_ASSERT_OK(ResetManager(
kMultiSpec, {"LinkToPrevious", "LinkToPrevious", "FailToGetLinks"}));
EXPECT_THAT(ResetLinks(), test::IsErrorWithSubstr("No links for you!"));
}
// Tests that SequenceLinks fails when the non-recurrent SequenceLinkers produce
// different numbers of links.
TEST_F(SequenceLinksTest, MismatchedNumbersOfNonRecurrentLinks) {
TF_ASSERT_OK(ResetManager(
kMultiSpec, {"LinkToPrevious", "WrongNumberOfLinks", "LinkToPrevious"}));
EXPECT_THAT(ResetLinks(),
test::IsErrorWithSubstr("Inconsistent link sequence lengths at "
"channel ID 1: got 11 but expected 10"));
}
// Tests that SequenceLinks fails when the recurrent SequenceLinkers produce
// different numbers of links.
TEST_F(SequenceLinksTest, MismatchedNumbersOfRecurrentLinks) {
TF_ASSERT_OK(ResetManager(
kMultiSpec, {"LinkToPrevious", "LinkToPrevious", "WrongNumberOfLinks"}));
EXPECT_THAT(ResetLinks(),
test::IsErrorWithSubstr("Inconsistent link sequence lengths at "
"channel ID 2: got 11 but expected 10"));
}
// Tests that SequenceLinks works as expected on one channel.
TEST_F(SequenceLinksTest, SingleChannel) {
TF_ASSERT_OK(ResetManager(kSingleSpec, {"LinkToPrevious"}));
TF_ASSERT_OK(ResetLinks());
ASSERT_EQ(sequence_links_.num_channels(), 1);
ASSERT_EQ(sequence_links_.num_steps(), kNumSteps);
const Matrix<float> previous1(GetLayer("source_component_1", "previous_1"));
Vector<float> embedding;
bool is_out_of_bounds = false;
// LinkToPrevious links the 0'th index to -1, which is out of bounds.
sequence_links_.Get(0, 0, &embedding, &is_out_of_bounds);
EXPECT_TRUE(is_out_of_bounds);
ExpectVector(embedding, kPrevious1LayerDim, 0.0);
// The remaining links point to the previous item.
for (int i = 1; i < kNumSteps; ++i) {
sequence_links_.Get(0, i, &embedding, &is_out_of_bounds);
EXPECT_FALSE(is_out_of_bounds);
ExpectVector(embedding, kPrevious1LayerDim, kPrevious1LayerValue);
EXPECT_EQ(embedding.data(), previous1.row(i - 1).data());
}
}
// Tests that SequenceLinks works as expected on multiple channels.
TEST_F(SequenceLinksTest, ManyChannels) {
TF_ASSERT_OK(ResetManager(
kMultiSpec, {"LinkToPrevious", "LinkToPrevious", "LinkToPrevious"}));
TF_ASSERT_OK(ResetLinks());
ASSERT_EQ(sequence_links_.num_channels(), 3);
ASSERT_EQ(sequence_links_.num_steps(), kNumSteps);
const Matrix<float> previous1(GetLayer("source_component_1", "previous_1"));
const Matrix<float> previous2(GetLayer("source_component_2", "previous_2"));
const Matrix<float> recurrent(GetLayer(kTestComponentName, "recurrent"));
Vector<float> embedding;
bool is_out_of_bounds = false;
// LinkToPrevious links the 0'th index to -1, which is out of bounds.
sequence_links_.Get(0, 0, &embedding, &is_out_of_bounds);
EXPECT_TRUE(is_out_of_bounds);
ExpectVector(embedding, kPrevious1LayerDim, 0.0);
sequence_links_.Get(1, 0, &embedding, &is_out_of_bounds);
EXPECT_TRUE(is_out_of_bounds);
ExpectVector(embedding, kPrevious2LayerDim, 0.0);
sequence_links_.Get(2, 0, &embedding, &is_out_of_bounds);
EXPECT_TRUE(is_out_of_bounds);
ExpectVector(embedding, kRecurrentLayerDim, 0.0);
// The remaining links point to the previous item.
for (int i = 1; i < kNumSteps; ++i) {
sequence_links_.Get(0, i, &embedding, &is_out_of_bounds);
EXPECT_FALSE(is_out_of_bounds);
ExpectVector(embedding, kPrevious1LayerDim, kPrevious1LayerValue);
EXPECT_EQ(embedding.data(), previous1.row(i - 1).data());
sequence_links_.Get(1, i, &embedding, &is_out_of_bounds);
EXPECT_FALSE(is_out_of_bounds);
ExpectVector(embedding, kPrevious2LayerDim, kPrevious2LayerValue);
EXPECT_EQ(embedding.data(), previous2.row(i - 1).data());
sequence_links_.Get(2, i, &embedding, &is_out_of_bounds);
EXPECT_FALSE(is_out_of_bounds);
ExpectVector(embedding, kRecurrentLayerDim, kRecurrentLayerValue);
EXPECT_EQ(embedding.data(), recurrent.row(i - 1).data());
}
}
// Tests that SequenceLinks is emptied when resetting to an empty manager after
// being reset to a non-empty manager.
TEST_F(SequenceLinksTest, ResetToEmptyAfterNonEmpty) {
TF_ASSERT_OK(ResetManager(kSingleSpec, {"LinkToPrevious"}));
TF_ASSERT_OK(ResetLinks());
ASSERT_EQ(sequence_links_.num_channels(), 1);
ASSERT_EQ(sequence_links_.num_steps(), kNumSteps);
SequenceLinkManager manager;
TF_ASSERT_OK(sequence_links_.Reset(/*add_steps=*/false, &manager,
&network_states_, &input_batch_cache_));
ASSERT_EQ(sequence_links_.num_channels(), 0);
ASSERT_EQ(sequence_links_.num_steps(), 0);
}
// Tests that SequenceLinks fails when adding steps to a component with no
// non-recurrent links.
TEST_F(SequenceLinksTest, AddStepsWithNoNonRecurrentLinks) {
TF_ASSERT_OK(ResetManager(kRecurrentSpec, {"LinkToPrevious"}));
EXPECT_THAT(
ResetLinks(/*add_steps=*/true),
test::IsErrorWithSubstr("Cannot infer the number of steps to add because "
"there are no non-recurrent links"));
}
// Tests that SequenceLinks produces no links when processing a component with
// only recurrent links, and when the NetworkStates has no steps.
TEST_F(SequenceLinksTest, RecurrentLinksWithNoSteps) {
TF_ASSERT_OK(ResetManager(kRecurrentSpec, {"LinkToPrevious"}));
TF_ASSERT_OK(ResetLinks(/*add_steps=*/false, /*num_steps=*/0));
ASSERT_EQ(sequence_links_.num_channels(), 1);
ASSERT_EQ(sequence_links_.num_steps(), 0);
}
// Tests that SequenceLinks properly infers the number of steps and adds them
// when processing a component with both non-recurrent and recurrent links.
TEST_F(SequenceLinksTest, AddStepsWithNonRecurrentAndRecurrentLinks) {
TF_ASSERT_OK(ResetManager(
kMultiSpec, {"LinkToPrevious", "LinkToPrevious", "LinkToPrevious"}));
TF_ASSERT_OK(ResetLinks(/*add_steps=*/true, /*num_steps=*/0));
ASSERT_EQ(sequence_links_.num_channels(), 3);
ASSERT_EQ(sequence_links_.num_steps(), kNumSteps);
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment