Commit edea2b67 authored by Terry Koo's avatar Terry Koo
Browse files

Remove runtime because reasons.

parent a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_model.h"
#include <vector>
#include "dragnn/runtime/attributes.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/sequence_backend.h"
#include "dragnn/runtime/transition_system_traits.h"
#include "tensorflow/core/lib/core/errors.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Proper backend for sequence-based models.
constexpr char kSupportedBackend[] = "SequenceBackend";
// Attributes for sequence-based comopnents, attached to the component builder.
// See SequenceComponentTransformer.
struct ComponentBuilderAttributes : public Attributes {
// Registered names of the sequence extractors to use.
Mandatory<std::vector<string>> sequence_extractors{"sequence_extractors",
this};
// Registered names of the sequence linkers to use per channel, in order.
Mandatory<std::vector<string>> sequence_linkers{"sequence_linkers", this};
// Registered name of the sequence predictor to use.
Mandatory<string> sequence_predictor{"sequence_predictor", this};
};
} // namespace
bool SequenceModel::Supports(const ComponentSpec &component_spec) {
// Require single-embedding fixed and linked features.
for (const FixedFeatureChannel &channel : component_spec.fixed_feature()) {
if (channel.size() != 1) return false;
}
for (const LinkedFeatureChannel &channel : component_spec.linked_feature()) {
if (channel.size() != 1) return false;
}
const bool has_fixed_feature = component_spec.fixed_feature_size() > 0;
bool has_recurrent_link = false;
bool has_non_recurrent_link = false;
for (const LinkedFeatureChannel &channel : component_spec.linked_feature()) {
if (channel.source_component() == component_spec.name()) {
has_recurrent_link = true;
} else {
has_non_recurrent_link = true;
}
}
// Recurrent links must be accompanied by fixed features or non-recurrent
// links, so the number of recurrent steps can be pre-computed.
if (has_recurrent_link && !has_fixed_feature && !has_non_recurrent_link) {
return false;
}
const int num_features = component_spec.fixed_feature_size() +
component_spec.linked_feature_size();
return component_spec.backend().registered_name() == kSupportedBackend &&
num_features > 0;
}
tensorflow::Status SequenceModel::Initialize(
const ComponentSpec &component_spec, const string &logits_name,
const FixedEmbeddingManager *fixed_embedding_manager,
const LinkedEmbeddingManager *linked_embedding_manager,
NetworkStateManager *network_state_manager) {
component_name_ = component_spec.name();
if (component_spec.backend().registered_name() != kSupportedBackend) {
return tensorflow::errors::InvalidArgument(
"Invalid component backend: ",
component_spec.backend().registered_name());
}
TransitionSystemTraits traits(component_spec);
deterministic_ = traits.is_deterministic;
left_to_right_ = traits.is_left_to_right;
ComponentBuilderAttributes component_builder_attributes;
TF_RETURN_IF_ERROR(component_builder_attributes.Reset(
component_spec.component_builder().parameters()));
TF_RETURN_IF_ERROR(sequence_feature_manager_.Reset(
fixed_embedding_manager, component_spec,
component_builder_attributes.sequence_extractors()));
TF_RETURN_IF_ERROR(sequence_link_manager_.Reset(
linked_embedding_manager, component_spec,
component_builder_attributes.sequence_linkers()));
have_fixed_features_ = sequence_feature_manager_.num_channels() > 0;
have_linked_features_ = sequence_link_manager_.num_channels() > 0;
if (!have_fixed_features_ && !have_linked_features_) {
return tensorflow::errors::InvalidArgument("No fixed or linked features");
}
if (!deterministic_) {
size_t dimension = 0;
TF_RETURN_IF_ERROR(network_state_manager->LookupLayer(
component_name_, logits_name, &dimension, &logits_handle_));
if (dimension != component_spec.num_actions()) {
return tensorflow::errors::InvalidArgument(
"Logits dimension mismatch between NetworkStates (", dimension,
") and ComponentSpec (", component_spec.num_actions(), ")");
}
TF_RETURN_IF_ERROR(SequencePredictor::New(
component_builder_attributes.sequence_predictor(), component_spec,
&sequence_predictor_));
}
return tensorflow::Status::OK();
}
tensorflow::Status SequenceModel::Preprocess(
SessionState *session_state, ComputeSession *compute_session,
EvaluateState *evaluate_state) const {
InputBatchCache *input_batch_cache = compute_session->GetInputBatchCache();
if (input_batch_cache == nullptr) {
return tensorflow::errors::InvalidArgument("Null input batch");
}
// The feature handling below is complicated by the need to support recurrent
// links. See the comment on SequenceLinks::Reset().
NetworkStates &network_states = session_state->network_states;
TF_RETURN_IF_ERROR(evaluate_state->features.Reset(&sequence_feature_manager_,
input_batch_cache));
if (have_fixed_features_) {
network_states.AddSteps(evaluate_state->features.num_steps());
}
TF_RETURN_IF_ERROR(evaluate_state->links.Reset(
/*add_steps=*/!have_fixed_features_, &sequence_link_manager_,
&network_states, input_batch_cache));
// Initialize() ensures that there is at least one fixed or linked feature;
// use it to determine the number of steps.
size_t num_steps = 0;
if (have_fixed_features_ && have_linked_features_) {
num_steps = evaluate_state->features.num_steps();
if (num_steps != evaluate_state->links.num_steps()) {
return tensorflow::errors::FailedPrecondition(
"Sequence length mismatch between fixed features (", num_steps,
") and linked features (", evaluate_state->links.num_steps(), ")");
}
} else if (have_fixed_features_) {
num_steps = evaluate_state->features.num_steps();
} else {
num_steps = evaluate_state->links.num_steps();
}
// Tell the backend the current input size, so it can handle requests for
// linked features from downstream components.
static_cast<SequenceBackend *>(
compute_session->GetReadiedComponent(component_name_))
->SetSequenceSize(num_steps);
evaluate_state->num_steps = num_steps;
evaluate_state->input = input_batch_cache;
return tensorflow::Status::OK();
}
tensorflow::Status SequenceModel::Predict(const NetworkStates &network_states,
EvaluateState *evaluate_state) const {
if (!deterministic_) {
const Matrix<float> logits(network_states.GetLayer(logits_handle_));
TF_RETURN_IF_ERROR(
sequence_predictor_->Predict(logits, evaluate_state->input));
}
return tensorflow::Status::OK();
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_SEQUENCE_MODEL_H_
#define DRAGNN_RUNTIME_SEQUENCE_MODEL_H_
#include <stddef.h>
#include <memory>
#include <string>
#include "dragnn/core/compute_session.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/fixed_embeddings.h"
#include "dragnn/runtime/linked_embeddings.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/sequence_features.h"
#include "dragnn/runtime/sequence_links.h"
#include "dragnn/runtime/sequence_predictor.h"
#include "dragnn/runtime/session_state.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// A class that configures and helps evaluate a sequence-based model.
//
// This class requires the SequenceBackend component backend and elides most of
// the ComputeSession feature extraction and transition system overhead.
class SequenceModel {
public:
// State associated with a single evaluation of the model.
struct EvaluateState {
// Number of transition steps in the current sequence.
size_t num_steps = 0;
// Current input batch.
InputBatchCache *input = nullptr;
// Sequence-based fixed features.
SequenceFeatures features;
// Sequence-based linked embeddings.
SequenceLinks links;
};
// Creates an uninitialized model. Call Initialize() before use.
SequenceModel() = default;
// Returns true if the |component_spec| is compatible with a sequence model.
static bool Supports(const ComponentSpec &component_spec);
// Initalizes this from the configuration in the |component_spec|. Wraps the
// |fixed_embedding_manager| and |linked_embedding_manager| in sequence-based
// versions, and requests layers from the |network_state_manager|. All of the
// managers must outlive this. If the transition system is non-deterministic,
// uses the layer named |logits_name| to make predictions later in Predict();
// otherwise, |logits_name| is ignored and Predict() does nothing. On error,
// returns non-OK.
tensorflow::Status Initialize(
const ComponentSpec &component_spec, const string &logits_name,
const FixedEmbeddingManager *fixed_embedding_manager,
const LinkedEmbeddingManager *linked_embedding_manager,
NetworkStateManager *network_state_manager);
// Resets the |evaluate_state| to values derived from the |session_state| and
// |compute_session|. Also updates the NetworkStates in the |session_state|
// and the current component of the |compute_session| with the length of the
// current sequence. Call this before producing output layers. On error,
// returns non-OK.
tensorflow::Status Preprocess(SessionState *session_state,
ComputeSession *compute_session,
EvaluateState *evaluate_state) const;
// If applicable, makes predictions based on the logits in |network_states|
// and applies them to the input in the |evaluate_state|. Call this after
// producing output layers. On error, returns non-OK.
tensorflow::Status Predict(const NetworkStates &network_states,
EvaluateState *evaluate_state) const;
// Accessors.
bool deterministic() const { return deterministic_; }
bool left_to_right() const { return left_to_right_; }
const SequenceLinkManager &sequence_link_manager() const;
const SequenceFeatureManager &sequence_feature_manager() const;
private:
// Name of the component that this model is a part of.
string component_name_;
// Whether the underlying transition system is deterministic.
bool deterministic_ = false;
// Whether to process sequences from left to right.
bool left_to_right_ = true;
// Whether fixed or linked features are present.
bool have_fixed_features_ = false;
bool have_linked_features_ = false;
// Handle to the logits layer. Only used if |deterministic_| is false.
LayerHandle<float> logits_handle_;
// Manager for sequence-based feature extractors.
SequenceFeatureManager sequence_feature_manager_;
// Manager for sequence-based linked embeddings.
SequenceLinkManager sequence_link_manager_;
// Sequence-based predictor, if |deterministic_| is false.
std::unique_ptr<SequencePredictor> sequence_predictor_;
};
// Implementation details below.
inline const SequenceLinkManager &SequenceModel::sequence_link_manager() const {
return sequence_link_manager_;
}
inline const SequenceFeatureManager &SequenceModel::sequence_feature_manager()
const {
return sequence_feature_manager_;
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_SEQUENCE_MODEL_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_model.h"
#include <string>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/fixed_embeddings.h"
#include "dragnn/runtime/linked_embeddings.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/sequence_backend.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "dragnn/runtime/sequence_linker.h"
#include "dragnn/runtime/sequence_predictor.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "syntaxnet/base.h"
#include <gmock/gmock.h>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
using ::testing::Return;
constexpr int kNumSteps = 50;
constexpr int kVocabularySize = 123;
constexpr int kLinkedDim = 11;
constexpr int kLogitsDim = 17;
constexpr char kLogitsName[] = "oddly_named_logits";
constexpr char kPreviousComponentName[] = "previous_component";
constexpr char kPreviousLayerName[] = "previous_layer";
constexpr float kPreviousLayerValue = -1.0;
// Sequence extractor that extracts [0, 2, 4, ...].
class EvenNumbers : public SequenceExtractor {
public:
// Implements SequenceExtractor.
bool Supports(const FixedFeatureChannel &,
const ComponentSpec &) const override {
return true;
}
tensorflow::Status Initialize(const FixedFeatureChannel &,
const ComponentSpec &) override {
return tensorflow::Status::OK();
}
tensorflow::Status GetIds(InputBatchCache *,
std::vector<int32> *ids) const override {
ids->clear();
for (int i = 0; i < num_steps_; ++i) ids->push_back(2 * i);
return tensorflow::Status::OK();
}
// Sets the number of steps to emit.
static void SetNumSteps(int num_steps) { num_steps_ = num_steps; }
private:
// The number of steps to produce.
static int num_steps_;
};
int EvenNumbers::num_steps_ = kNumSteps;
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR(EvenNumbers);
// Trivial linker that links each index to the previous one.
class LinkToPrevious : public SequenceLinker {
public:
// Implements SequenceLinker.
bool Supports(const LinkedFeatureChannel &,
const ComponentSpec &) const override {
return true;
}
tensorflow::Status Initialize(const LinkedFeatureChannel &,
const ComponentSpec &) override {
return tensorflow::Status::OK();
}
tensorflow::Status GetLinks(size_t, InputBatchCache *,
std::vector<int32> *links) const override {
links->clear();
for (int i = 0; i < num_steps_; ++i) links->push_back(i - 1);
return tensorflow::Status::OK();
}
// Sets the number of steps to emit.
static void SetNumSteps(int num_steps) { num_steps_ = num_steps; }
private:
// The number of steps to produce.
static int num_steps_;
};
int LinkToPrevious::num_steps_ = kNumSteps;
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(LinkToPrevious);
// Trivial predictor that captures the prediction logits.
class CaptureLogits : public SequencePredictor {
public:
// Implements SequenceLinker.
bool Supports(const ComponentSpec &) const override { return true; }
tensorflow::Status Initialize(const ComponentSpec &) override {
return tensorflow::Status::OK();
}
tensorflow::Status Predict(Matrix<float> logits,
InputBatchCache *) const override {
GetLogits() = logits;
return tensorflow::Status::OK();
}
// Returns the captured logits.
static Matrix<float> &GetLogits() {
static auto *logits = new Matrix<float>();
return *logits;
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_PREDICTOR(CaptureLogits);
class SequenceModelTest : public NetworkTestBase {
protected:
// Adds default call expectations. Since these are added first, they can be
// overridden by call expectations in individual tests.
SequenceModelTest() {
EXPECT_CALL(compute_session_, GetInputBatchCache())
.WillRepeatedly(Return(&input_));
EXPECT_CALL(compute_session_, GetReadiedComponent(kTestComponentName))
.WillRepeatedly(Return(&backend_));
// Some tests overwrite these; ensure that they are restored to the normal
// values at the start of each test.
EvenNumbers::SetNumSteps(kNumSteps);
LinkToPrevious::SetNumSteps(kNumSteps);
CaptureLogits::GetLogits() = Matrix<float>();
}
// Initializes the |model_| and its underlying feature managers from the
// |component_spec|, then uses the |model_| to preprocess and predict the
// |input_|. Also sets each row of the logits to twice its row index. On
// error, returns non-OK.
tensorflow::Status Run(ComponentSpec component_spec) {
component_spec.set_name(kTestComponentName);
AddComponent(kPreviousComponentName);
AddLayer(kPreviousLayerName, kLinkedDim);
AddComponent(kTestComponentName);
AddLayer(kLogitsName, kLogitsDim);
TF_RETURN_IF_ERROR(fixed_embedding_manager_.Reset(
component_spec, &variable_store_, &network_state_manager_));
TF_RETURN_IF_ERROR(linked_embedding_manager_.Reset(
component_spec, &variable_store_, &network_state_manager_));
TF_RETURN_IF_ERROR(model_.Initialize(
component_spec, kLogitsName, &fixed_embedding_manager_,
&linked_embedding_manager_, &network_state_manager_));
network_states_.Reset(&network_state_manager_);
StartComponent(kNumSteps);
FillLayer(kPreviousComponentName, kPreviousLayerName, kPreviousLayerValue);
StartComponent(0);
TF_RETURN_IF_ERROR(model_.Preprocess(&session_state_, &compute_session_,
&evaluate_state_));
MutableMatrix<float> logits = GetLayer(kTestComponentName, kLogitsName);
for (int row = 0; row < logits.num_rows(); ++row) {
for (int column = 0; column < logits.num_columns(); ++column) {
logits.row(row)[column] = 2.0 * row;
}
}
return model_.Predict(network_states_, &evaluate_state_);
}
// Returns the sequence size passed to the |backend_|.
int GetBackendSequenceSize() {
// The sequence size is not directly exposed, but can be inferred using one
// of the reverse step translators.
return backend_.GetStepLookupFunction("reverse-token")(0, 0, 0) + 1;
}
// Fixed and linked embedding managers.
FixedEmbeddingManager fixed_embedding_manager_;
LinkedEmbeddingManager linked_embedding_manager_;
// Input batch injected into Preprocess() by default.
InputBatchCache input_;
// Backend injected into Preprocess().
SequenceBackend backend_;
// Sequence-based model.
SequenceModel model_;
// Per-evaluation state.
SequenceModel::EvaluateState evaluate_state_;
};
// Returns a ComponentSpec that is supported.
ComponentSpec MakeSupportedSpec() {
ComponentSpec component_spec;
component_spec.set_num_actions(kLogitsDim);
component_spec.mutable_component_builder()->mutable_parameters()->insert(
{"sequence_extractors", "EvenNumbers"});
component_spec.mutable_component_builder()->mutable_parameters()->insert(
{"sequence_linkers", "LinkToPrevious"});
component_spec.mutable_component_builder()->mutable_parameters()->insert(
{"sequence_predictor", "CaptureLogits"});
component_spec.mutable_backend()->set_registered_name("SequenceBackend");
FixedFeatureChannel *fixed_feature = component_spec.add_fixed_feature();
fixed_feature->set_size(1);
fixed_feature->set_embedding_dim(-1);
LinkedFeatureChannel *linked_feature = component_spec.add_linked_feature();
linked_feature->set_source_component(kPreviousComponentName);
linked_feature->set_source_layer(kPreviousLayerName);
linked_feature->set_size(1);
linked_feature->set_embedding_dim(-1);
return component_spec;
}
// Tests that the model supports a supported spec.
TEST_F(SequenceModelTest, Supported) {
const ComponentSpec component_spec = MakeSupportedSpec();
EXPECT_TRUE(SequenceModel::Supports(component_spec));
}
// Tests that the model rejects a spec with the wrong backend.
TEST_F(SequenceModelTest, UnsupportedBackend) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.mutable_backend()->set_registered_name("bad");
EXPECT_FALSE(SequenceModel::Supports(component_spec));
}
// Tests that the model rejects a spec with no features.
TEST_F(SequenceModelTest, UnsupportedNoFeatures) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.clear_fixed_feature();
component_spec.clear_linked_feature();
EXPECT_FALSE(SequenceModel::Supports(component_spec));
}
// Tests that the model rejects a spec with a multi-embedding fixed feature.
TEST_F(SequenceModelTest, UnsupportedMultiEmbeddingFixedFeature) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.mutable_fixed_feature(0)->set_size(2);
EXPECT_FALSE(SequenceModel::Supports(component_spec));
}
// Tests that the model rejects a spec with a multi-embedding linked feature.
TEST_F(SequenceModelTest, UnsupportedMultiEmbeddingLinkedFeature) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.mutable_linked_feature(0)->set_size(2);
EXPECT_FALSE(SequenceModel::Supports(component_spec));
}
// Tests that the model rejects a spec with only recurrent links.
TEST_F(SequenceModelTest, UnsupportedOnlyRecurrentLinks) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.set_name("foo");
component_spec.clear_fixed_feature();
component_spec.mutable_linked_feature(0)->set_source_component("foo");
EXPECT_FALSE(SequenceModel::Supports(component_spec));
}
// Tests that Initialize() succeeds on a supported spec.
TEST_F(SequenceModelTest, InitializeSupported) {
const ComponentSpec component_spec = MakeSupportedSpec();
TF_ASSERT_OK(Run(component_spec));
EXPECT_FALSE(model_.deterministic());
EXPECT_TRUE(model_.left_to_right());
EXPECT_EQ(model_.sequence_feature_manager().num_channels(), 1);
EXPECT_EQ(model_.sequence_link_manager().num_channels(), 1);
}
// Tests that Initialize() detects deterministic components.
TEST_F(SequenceModelTest, InitializeDeterministic) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.set_num_actions(1);
TF_ASSERT_OK(Run(component_spec));
EXPECT_TRUE(model_.deterministic());
EXPECT_TRUE(model_.left_to_right());
EXPECT_EQ(model_.sequence_feature_manager().num_channels(), 1);
EXPECT_EQ(model_.sequence_link_manager().num_channels(), 1);
}
// Tests that Initialize() detects right-to-left components.
TEST_F(SequenceModelTest, InitializeLeftToRight) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.mutable_transition_system()->mutable_parameters()->insert(
{"left_to_right", "false"});
TF_ASSERT_OK(Run(component_spec));
EXPECT_FALSE(model_.deterministic());
EXPECT_FALSE(model_.left_to_right());
EXPECT_EQ(model_.sequence_feature_manager().num_channels(), 1);
EXPECT_EQ(model_.sequence_link_manager().num_channels(), 1);
}
// Tests that Initialize() fails if the backend is wrong.
TEST_F(SequenceModelTest, WrongBackend) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.mutable_backend()->set_registered_name("bad");
EXPECT_THAT(Run(component_spec),
test::IsErrorWithSubstr("Invalid component backend"));
}
// Tests that Initialize() fails if the number of actions in the ComponentSpec
// does not match the logits.
TEST_F(SequenceModelTest, WrongNumActions) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.set_num_actions(kLogitsDim + 1);
EXPECT_THAT(Run(component_spec),
test::IsErrorWithSubstr("Logits dimension mismatch"));
}
// Tests that Initialize() fails if an unknown sequence extractor is specified.
TEST_F(SequenceModelTest, UnknownSequenceExtractor) {
ComponentSpec component_spec = MakeSupportedSpec();
(*component_spec.mutable_component_builder()
->mutable_parameters())["sequence_extractors"] = "bad";
EXPECT_THAT(
Run(component_spec),
test::IsErrorWithSubstr("Unknown DRAGNN Runtime Sequence Extractor"));
}
// Tests that Initialize() fails if an unknown sequence linker is specified.
TEST_F(SequenceModelTest, UnknownSequenceLinker) {
ComponentSpec component_spec = MakeSupportedSpec();
(*component_spec.mutable_component_builder()
->mutable_parameters())["sequence_linkers"] = "bad";
EXPECT_THAT(
Run(component_spec),
test::IsErrorWithSubstr("Unknown DRAGNN Runtime Sequence Linker"));
}
// Tests that Initialize() fails if an unknown sequence predictor is specified.
TEST_F(SequenceModelTest, UnknownSequencePredictor) {
ComponentSpec component_spec = MakeSupportedSpec();
(*component_spec.mutable_component_builder()
->mutable_parameters())["sequence_predictor"] = "bad";
EXPECT_THAT(
Run(component_spec),
test::IsErrorWithSubstr("Unknown DRAGNN Runtime Sequence Predictor"));
}
// Tests that Initialize() fails on an unknown component builder parameter.
TEST_F(SequenceModelTest, UnknownComponentBuilderParameter) {
ComponentSpec component_spec = MakeSupportedSpec();
(*component_spec.mutable_component_builder()->mutable_parameters())["bad"] =
"bad";
EXPECT_THAT(Run(component_spec),
test::IsErrorWithSubstr("Unknown attribute"));
}
// Tests that Initialize() fails if there are no fixed or linked features.
TEST_F(SequenceModelTest, InitializeRequiresFeatures) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.clear_fixed_feature();
component_spec.clear_linked_feature();
(*component_spec.mutable_component_builder()
->mutable_parameters())["sequence_extractors"] = "";
(*component_spec.mutable_component_builder()
->mutable_parameters())["sequence_linkers"] = "";
EXPECT_THAT(Run(component_spec),
test::IsErrorWithSubstr("No fixed or linked features"));
}
// Tests that the model fails if a null batch is returned.
TEST_F(SequenceModelTest, NullBatch) {
EXPECT_CALL(compute_session_, GetInputBatchCache()).WillOnce(Return(nullptr));
EXPECT_THAT(Run(MakeSupportedSpec()),
test::IsErrorWithSubstr("Null input batch"));
}
// Tests that the model properly sets up the EvaluateState and logits.
TEST_F(SequenceModelTest, Success) {
TF_ASSERT_OK(Run(MakeSupportedSpec()));
EXPECT_EQ(GetBackendSequenceSize(), kNumSteps);
EXPECT_EQ(evaluate_state_.num_steps, kNumSteps);
EXPECT_EQ(evaluate_state_.input, &input_);
EXPECT_EQ(evaluate_state_.features.num_channels(), 1);
EXPECT_EQ(evaluate_state_.features.num_steps(), kNumSteps);
EXPECT_EQ(evaluate_state_.features.GetId(0, 0), 0);
EXPECT_EQ(evaluate_state_.features.GetId(0, 1), 2);
EXPECT_EQ(evaluate_state_.features.GetId(0, 2), 4);
EXPECT_EQ(evaluate_state_.links.num_channels(), 1);
EXPECT_EQ(evaluate_state_.links.num_steps(), kNumSteps);
Vector<float> embedding;
bool is_out_of_bounds = false;
evaluate_state_.links.Get(0, 0, &embedding, &is_out_of_bounds);
ExpectVector(embedding, kLinkedDim, 0.0);
EXPECT_TRUE(is_out_of_bounds);
evaluate_state_.links.Get(0, 1, &embedding, &is_out_of_bounds);
ExpectVector(embedding, kLinkedDim, kPreviousLayerValue);
EXPECT_FALSE(is_out_of_bounds);
const Matrix<float> logits = CaptureLogits::GetLogits();
ASSERT_EQ(logits.num_rows(), kNumSteps);
ASSERT_EQ(logits.num_columns(), kLogitsDim);
for (int i = 0; i < kNumSteps; ++i) {
ExpectVector(logits.row(i), kLogitsDim, 2.0 * i);
}
}
// Tests that the model works with only fixed features.
TEST_F(SequenceModelTest, FixedFeaturesOnly) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.clear_linked_feature();
(*component_spec.mutable_component_builder()
->mutable_parameters())["sequence_linkers"] = "";
TF_ASSERT_OK(Run(component_spec));
EXPECT_EQ(GetBackendSequenceSize(), kNumSteps);
EXPECT_EQ(evaluate_state_.num_steps, kNumSteps);
EXPECT_EQ(evaluate_state_.input, &input_);
EXPECT_EQ(evaluate_state_.features.num_channels(), 1);
EXPECT_EQ(evaluate_state_.features.num_steps(), kNumSteps);
EXPECT_EQ(evaluate_state_.features.GetId(0, 0), 0);
EXPECT_EQ(evaluate_state_.features.GetId(0, 1), 2);
EXPECT_EQ(evaluate_state_.features.GetId(0, 2), 4);
EXPECT_EQ(evaluate_state_.links.num_channels(), 0);
EXPECT_EQ(evaluate_state_.links.num_steps(), 0);
const Matrix<float> logits = CaptureLogits::GetLogits();
ASSERT_EQ(logits.num_rows(), kNumSteps);
ASSERT_EQ(logits.num_columns(), kLogitsDim);
for (int i = 0; i < kNumSteps; ++i) {
ExpectVector(logits.row(i), kLogitsDim, 2.0 * i);
}
}
// Tests that the model works with only linked features.
TEST_F(SequenceModelTest, LinkedFeaturesOnly) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.clear_fixed_feature();
(*component_spec.mutable_component_builder()
->mutable_parameters())["sequence_extractors"] = "";
TF_ASSERT_OK(Run(component_spec));
EXPECT_EQ(GetBackendSequenceSize(), kNumSteps);
EXPECT_EQ(evaluate_state_.num_steps, kNumSteps);
EXPECT_EQ(evaluate_state_.input, &input_);
EXPECT_EQ(evaluate_state_.features.num_channels(), 0);
EXPECT_EQ(evaluate_state_.features.num_steps(), 0);
EXPECT_EQ(evaluate_state_.links.num_channels(), 1);
EXPECT_EQ(evaluate_state_.links.num_steps(), kNumSteps);
Vector<float> embedding;
bool is_out_of_bounds = false;
evaluate_state_.links.Get(0, 0, &embedding, &is_out_of_bounds);
ExpectVector(embedding, kLinkedDim, 0.0);
EXPECT_TRUE(is_out_of_bounds);
evaluate_state_.links.Get(0, 1, &embedding, &is_out_of_bounds);
ExpectVector(embedding, kLinkedDim, kPreviousLayerValue);
EXPECT_FALSE(is_out_of_bounds);
const Matrix<float> logits = CaptureLogits::GetLogits();
ASSERT_EQ(logits.num_rows(), kNumSteps);
ASSERT_EQ(logits.num_columns(), kLogitsDim);
for (int i = 0; i < kNumSteps; ++i) {
ExpectVector(logits.row(i), kLogitsDim, 2.0 * i);
}
}
// Tests that the model fails if the fixed and linked features disagree on the
// number of steps.
TEST_F(SequenceModelTest, FixedAndLinkedDisagree) {
EvenNumbers::SetNumSteps(5);
LinkToPrevious::SetNumSteps(6);
EXPECT_THAT(Run(MakeSupportedSpec()),
test::IsErrorWithSubstr("Sequence length mismatch between fixed "
"features (5) and linked features (6)"));
}
// Tests that the model can handle an empty sequence.
TEST_F(SequenceModelTest, EmptySequence) {
EvenNumbers::SetNumSteps(0);
LinkToPrevious::SetNumSteps(0);
TF_ASSERT_OK(Run(MakeSupportedSpec()));
EXPECT_EQ(GetBackendSequenceSize(), 0);
const Matrix<float> logits = CaptureLogits::GetLogits();
ASSERT_EQ(logits.num_rows(), 0);
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_predictor.h"
#include "tensorflow/core/lib/core/errors.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
tensorflow::Status SequencePredictor::Select(
const ComponentSpec &component_spec, string *name) {
string supporting_name;
for (const Registry::Registrar *registrar = registry()->components;
registrar != nullptr; registrar = registrar->next()) {
Factory *factory_function = registrar->object();
std::unique_ptr<SequencePredictor> current_predictor(factory_function());
if (!current_predictor->Supports(component_spec)) continue;
if (!supporting_name.empty()) {
return tensorflow::errors::Internal(
"Multiple SequencePredictors support ComponentSpec (",
supporting_name, " and ", registrar->name(),
"): ", component_spec.ShortDebugString());
}
supporting_name = registrar->name();
}
if (supporting_name.empty()) {
return tensorflow::errors::NotFound(
"No SequencePredictor supports ComponentSpec: ",
component_spec.ShortDebugString());
}
// Success; make modifications.
*name = supporting_name;
return tensorflow::Status::OK();
}
tensorflow::Status SequencePredictor::New(
const string &name, const ComponentSpec &component_spec,
std::unique_ptr<SequencePredictor> *predictor) {
std::unique_ptr<SequencePredictor> matching_predictor;
TF_RETURN_IF_ERROR(
SequencePredictor::CreateOrError(name, &matching_predictor));
TF_RETURN_IF_ERROR(matching_predictor->Initialize(component_spec));
// Success; make modifications.
*predictor = std::move(matching_predictor);
return tensorflow::Status::OK();
}
} // namespace runtime
} // namespace dragnn
REGISTER_SYNTAXNET_CLASS_REGISTRY("DRAGNN Runtime Sequence Predictor",
dragnn::runtime::SequencePredictor);
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_SEQUENCE_PREDICTOR_H_
#define DRAGNN_RUNTIME_SEQUENCE_PREDICTOR_H_
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/math/types.h"
#include "syntaxnet/base.h"
#include "syntaxnet/registry.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Interface for making predictions on sequences.
//
// This predictor can be used to avoid ComputeSession overhead in simple cases;
// for example, predicting sequences of POS tags.
class SequencePredictor : public RegisterableClass<SequencePredictor> {
public:
// Sets |predictor| to an instance of the subclass named |name| initialized
// from the |component_spec|. On error, returns non-OK and modifies nothing.
static tensorflow::Status New(const string &name,
const ComponentSpec &component_spec,
std::unique_ptr<SequencePredictor> *predictor);
SequencePredictor(const SequencePredictor &) = delete;
SequencePredictor &operator=(const SequencePredictor &) = delete;
virtual ~SequencePredictor() = default;
// Sets |name| to the registered name of the SequencePredictor that supports
// the |component_spec|. On error, returns non-OK and modifies nothing. The
// returned statuses include:
// * OK: If a supporting SequencePredictor was found.
// * INTERNAL: If an error occurred while searching for a compatible match.
// * NOT_FOUND: If the search was error-free, but no compatible match was
// found.
static tensorflow::Status Select(const ComponentSpec &component_spec,
string *name);
// Makes a sequence of predictions using the per-step |logits| and writes
// annotations to the |input|.
virtual tensorflow::Status Predict(Matrix<float> logits,
InputBatchCache *input) const = 0;
protected:
SequencePredictor() = default;
private:
// Helps prevent use of the Create() method; use New() instead.
using RegisterableClass<SequencePredictor>::Create;
// Returns true if this supports the |component_spec|. Implementations must
// coordinate to ensure that at most one supports any given |component_spec|.
virtual bool Supports(const ComponentSpec &component_spec) const = 0;
// Initializes this from the |component_spec|. On error, returns non-OK.
virtual tensorflow::Status Initialize(
const ComponentSpec &component_spec) = 0;
};
} // namespace runtime
} // namespace dragnn
DECLARE_SYNTAXNET_CLASS_REGISTRY("DRAGNN Runtime Sequence Predictor",
dragnn::runtime::SequencePredictor);
} // namespace syntaxnet
#define DRAGNN_RUNTIME_REGISTER_SEQUENCE_PREDICTOR(subclass) \
REGISTER_SYNTAXNET_CLASS_COMPONENT( \
::syntaxnet::dragnn::runtime::SequencePredictor, #subclass, subclass)
#endif // DRAGNN_RUNTIME_SEQUENCE_PREDICTOR_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_predictor.h"
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/math/types.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Supports components named "success" and initializes successfully.
class Success : public SequencePredictor {
public:
// Implements SequencePredictor.
bool Supports(const ComponentSpec &component_spec) const override {
return component_spec.name() == "success";
}
tensorflow::Status Initialize(const ComponentSpec &) override {
return tensorflow::Status::OK();
}
tensorflow::Status Predict(Matrix<float>, InputBatchCache *) const override {
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_PREDICTOR(Success);
// Supports components named "failure" and fails to initialize.
class Failure : public SequencePredictor {
public:
// Implements SequencePredictor.
bool Supports(const ComponentSpec &component_spec) const override {
return component_spec.name() == "failure";
}
tensorflow::Status Initialize(const ComponentSpec &) override {
return tensorflow::errors::Internal("Boom!");
}
tensorflow::Status Predict(Matrix<float>, InputBatchCache *) const override {
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_PREDICTOR(Failure);
// Supports components named "duplicate" and initializes successfully.
class Duplicate : public SequencePredictor {
public:
// Implements SequencePredictor.
bool Supports(const ComponentSpec &component_spec) const override {
return component_spec.name() == "duplicate";
}
tensorflow::Status Initialize(const ComponentSpec &) override {
return tensorflow::Status::OK();
}
tensorflow::Status Predict(Matrix<float>, InputBatchCache *) const override {
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_PREDICTOR(Duplicate);
// Duplicate of the above.
using Duplicate2 = Duplicate;
DRAGNN_RUNTIME_REGISTER_SEQUENCE_PREDICTOR(Duplicate2);
// Tests that a component can be successfully created.
TEST(SequencePredictorTest, Success) {
string name;
std::unique_ptr<SequencePredictor> predictor;
ComponentSpec component_spec;
component_spec.set_name("success");
TF_ASSERT_OK(SequencePredictor::Select(component_spec, &name));
ASSERT_EQ(name, "Success");
TF_EXPECT_OK(SequencePredictor::New(name, component_spec, &predictor));
EXPECT_NE(predictor, nullptr);
}
// Tests that errors in Initialize() are reported.
TEST(SequencePredictorTest, FailToInitialize) {
string name;
std::unique_ptr<SequencePredictor> predictor;
ComponentSpec component_spec;
component_spec.set_name("failure");
TF_ASSERT_OK(SequencePredictor::Select(component_spec, &name));
EXPECT_EQ(name, "Failure");
EXPECT_THAT(SequencePredictor::New(name, component_spec, &predictor),
test::IsErrorWithSubstr("Boom!"));
EXPECT_EQ(predictor, nullptr);
}
// Tests that unsupported specs are reported as NOT_FOUND errors.
TEST(SequencePredictorTest, UnsupportedSpec) {
string name = "not overwritten";
ComponentSpec component_spec;
component_spec.set_name("unsupported");
EXPECT_THAT(SequencePredictor::Select(component_spec, &name),
test::IsErrorWithCodeAndSubstr(
tensorflow::error::NOT_FOUND,
"No SequencePredictor supports ComponentSpec"));
EXPECT_EQ(name, "not overwritten");
}
// Tests that unsupported subclass names are reported as errors.
TEST(SequencePredictorTest, UnsupportedSubclass) {
std::unique_ptr<SequencePredictor> predictor;
ComponentSpec component_spec;
EXPECT_THAT(
SequencePredictor::New("Unsupported", component_spec, &predictor),
test::IsErrorWithSubstr("Unknown DRAGNN Runtime Sequence Predictor"));
EXPECT_EQ(predictor, nullptr);
}
// Tests that multiple supporting predictors are reported as INTERNAL errors.
TEST(SequencePredictorTest, Duplicate) {
string name = "not overwritten";
ComponentSpec component_spec;
component_spec.set_name("duplicate");
EXPECT_THAT(SequencePredictor::Select(component_spec, &name),
test::IsErrorWithCodeAndSubstr(
tensorflow::error::INTERNAL,
"Multiple SequencePredictors support ComponentSpec"));
EXPECT_EQ(name, "not overwritten");
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_SESSION_STATE_H_
#define DRAGNN_RUNTIME_SESSION_STATE_H_
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/network_states.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// State associated with a ComputeSession being evaluated by a DRAGNN network,
// reusable across multiple evaluations. Unlike the ComputeSession, which is
// both the input and output of the network, this state is strictly internal to
// the network. Production code should allocate these via a SessionStatePool.
struct SessionState {
// The network states that connect the pipeline of components.
NetworkStates network_states;
// Generic set of typed extensions.
Extensions extensions;
};
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_SESSION_STATE_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/session_state_pool.h"
#include <algorithm>
namespace syntaxnet {
namespace dragnn {
namespace runtime {
SessionStatePool::SessionStatePool(size_t max_free_states)
: max_free_states_(max_free_states) {}
std::unique_ptr<SessionState> SessionStatePool::Acquire() {
{ // Exclude the slow path from the critical region.
tensorflow::mutex_lock lock(mutex_);
if (!free_list_.empty()) {
// Fast path: reuse a free state.
std::unique_ptr<SessionState> state = std::move(free_list_.back());
free_list_.pop_back();
return state;
}
}
// Slow path: allocate a new state.
return std::unique_ptr<SessionState>(new SessionState());
}
void SessionStatePool::Release(std::unique_ptr<SessionState> state) {
{ // Exclude the slow path from the critical region.
tensorflow::mutex_lock lock(mutex_);
if (free_list_.size() < max_free_states_) {
// Fast path: reclaim in the free list.
free_list_.emplace_back(std::move(state));
return;
}
}
// Slow path: discard the excess |state| when it goes out of scope.
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_SESSION_STATE_POOL_H_
#define DRAGNN_RUNTIME_SESSION_STATE_POOL_H_
#include <stddef.h>
#include <memory>
#include <utility>
#include "dragnn/runtime/session_state.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// A thread-safe pool of session states that maintains a free list. The free
// list is bounded, so a spike in usage does not permanently increase the size
// of the pool. Use ScopedSessionState to interact with the pool.
class SessionStatePool {
public:
// Creates a pool whose free list holds at most |max_free_states| states.
//
// If usage spikes are not a concern (e.g., during offline processing where
// the runtime is called from a fixed-size pool of threads), then specify a
// large value like SIZE_MAX. That eliminates unnecessary deallocations and
// reallocations, and eliminates the need to coordinate the thread pool size
// with this pool's size.
//
// If memory usage dominates CPU usage, then specify 0 to eliminate overhead
// from the free list.
//
// TODO(googleuser): An alternative is to set a target allocation
// rate (e.g., 2% of Acquire()s should create a new state), and let the pool
// adapt its free list size to achieve that rate.
explicit SessionStatePool(size_t max_free_states);
private:
friend class ScopedSessionState;
// Returns a state acquired from this pool. The caller is the exclusive user
// of the returned state until it is passed to Release().
std::unique_ptr<SessionState> Acquire();
// Releases the |state| back to this pool. The |state| must be the result of
// a previous Acquire(). The caller can no longer use the |state|.
void Release(std::unique_ptr<SessionState> state);
// Maximum number of states to keep in the |free_list_|.
const size_t max_free_states_;
// Mutex guarding the |free_list_|.
tensorflow::mutex mutex_;
// List of previously-Release()d states.
std::vector<std::unique_ptr<SessionState>> free_list_ GUARDED_BY(mutex_);
};
// RAII wrapper that manages a session state acquired from a pool. The wrapped
// state is usable during the lifetime of the wrapper.
class ScopedSessionState {
public:
// Implements RAII semantics.
explicit ScopedSessionState(SessionStatePool *pool)
: pool_(pool), state_(pool_->Acquire()) {}
~ScopedSessionState() { pool_->Release(std::move(state_)); }
// Prevents double-release.
ScopedSessionState(const ScopedSessionState &that) = delete;
ScopedSessionState &operator=(const ScopedSessionState &that) = delete;
// Provides std::unique_ptr-like access.
SessionState *get() const { return state_.get(); }
SessionState &operator*() const { return *get(); }
SessionState *operator->() const { return get(); }
private:
// Pool from which the |state_| was acquired.
SessionStatePool *const pool_;
// Wrapped session state.
std::unique_ptr<SessionState> state_;
};
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_SESSION_STATE_POOL_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/session_state_pool.h"
#include <stddef.h>
#include <set>
#include "dragnn/runtime/session_state.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Maximum number of free states.
static constexpr size_t kMaxFreeStates = 16;
class SessionStatePoolTest : public ::testing::Test {
protected:
SessionStatePool pool_{kMaxFreeStates};
};
// Tests that ScopedSessionState can be used to acquire a valid state.
TEST_F(SessionStatePoolTest, ScopedWrapper) {
const ScopedSessionState state(&pool_);
EXPECT_TRUE(state.get()); // non-null
}
// Tests that the active states claimed from the pool are unique.
TEST_F(SessionStatePoolTest, UniqueActiveStates) {
// NB: Don't use std::unique_ptr<ScopedSessionState> in real code. The test
// does this because it's otherwise difficult to acquire lots of states.
std::vector<std::unique_ptr<ScopedSessionState>> states;
for (size_t i = 0; i < 100; ++i) {
states.emplace_back(new ScopedSessionState(&pool_));
}
// Check that all of the states are unique.
std::set<const SessionState *> state_ptrs;
for (const auto &state : states) {
EXPECT_TRUE(state_ptrs.insert(state->get()).second);
}
EXPECT_TRUE(state_ptrs.find(nullptr) == state_ptrs.end());
}
// Tests that active states, when released, are reclaimed and reused.
TEST_F(SessionStatePoolTest, Reuse) {
std::set<const SessionState *> state_ptrs;
{ // Grab exactly as many states as the free list can hold.
std::vector<std::unique_ptr<ScopedSessionState>> states;
for (size_t i = 0; i < kMaxFreeStates; ++i) {
states.emplace_back(new ScopedSessionState(&pool_));
EXPECT_TRUE(state_ptrs.insert(states.back()->get()).second);
}
}
{ // Grab the same number of states again and check that they are the same
// objects we saw in the first loop.
std::vector<std::unique_ptr<ScopedSessionState>> states;
for (size_t i = 0; i < kMaxFreeStates; ++i) {
states.emplace_back(new ScopedSessionState(&pool_));
EXPECT_FALSE(state_ptrs.insert(states.back()->get()).second);
}
}
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <string>
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component_transformation.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Returns true if the |component_type| can be transformed by this.
bool ShouldTransform(const string &component_type) {
for (const char *supported_type : {
"SyntaxNetHeadSelectionComponent", //
"SyntaxNetMstSolverComponent", //
}) {
if (component_type == supported_type) return true;
}
return false;
}
// Changes the backend for some components to StatelessComponent.
class StatelessComponentTransformer : public ComponentTransformer {
public:
// Implements ComponentTransformer.
tensorflow::Status Transform(const string &component_type,
ComponentSpec *component_spec) override {
if (ShouldTransform(component_type)) {
component_spec->mutable_backend()->set_registered_name(
"StatelessComponent");
}
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_COMPONENT_TRANSFORMER(StatelessComponentTransformer);
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component_transformation.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Arbitrary supported component type.
constexpr char kSupportedComponentType[] = "SyntaxNetHeadSelectionComponent";
// Returns a ComponentSpec that is supported by the transformer.
ComponentSpec MakeSupportedSpec() {
ComponentSpec component_spec;
component_spec.mutable_component_builder()->set_registered_name(
kSupportedComponentType);
return component_spec;
}
// Tests that a compatible spec is modified to use StatelessComponent.
TEST(StatelessComponentTransformerTest, Compatible) {
ComponentSpec component_spec = MakeSupportedSpec();
ComponentSpec expected_spec = component_spec;
expected_spec.mutable_backend()->set_registered_name("StatelessComponent");
TF_ASSERT_OK(ComponentTransformer::ApplyAll(&component_spec));
EXPECT_THAT(component_spec, test::EqualsProto(expected_spec));
}
// Tests that other component specs are not modified.
TEST(StatelessComponentTransformerTest, Incompatible) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.mutable_component_builder()->set_registered_name("other");
const ComponentSpec expected_spec = component_spec;
TF_ASSERT_OK(ComponentTransformer::ApplyAll(&component_spec));
EXPECT_THAT(component_spec, test::EqualsProto(expected_spec));
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/io/sentence_input_batch.h"
#include "dragnn/io/syntaxnet_sentence.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "dragnn/runtime/term_map_sequence_extractor.h"
#include "dragnn/runtime/term_map_utils.h"
#include "dragnn/runtime/transition_system_traits.h"
#include "dragnn/runtime/unicode_dictionary.h"
#include "syntaxnet/base.h"
#include "syntaxnet/segmenter_utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "util/utf8/unicodetext.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Sequence extractor that extracts characters from a SyntaxNetComponent batch.
class SyntaxNetCharacterSequenceExtractor
: public TermMapSequenceExtractor<UnicodeDictionary> {
public:
SyntaxNetCharacterSequenceExtractor();
// Implements SequenceExtractor.
bool Supports(const FixedFeatureChannel &channel,
const ComponentSpec &component_spec) const override;
tensorflow::Status Initialize(const FixedFeatureChannel &channel,
const ComponentSpec &component_spec) override;
tensorflow::Status GetIds(InputBatchCache *input,
std::vector<int32> *ids) const override;
private:
// Parses |fml| and sets |min_frequency| and |max_num_terms| to the specified
// values. If the |fml| does not specify a supported feature, returns non-OK
// and modifies nothing.
static tensorflow::Status ParseFml(const string &fml, int *min_frequency,
int *max_num_terms);
// Feature IDs for break characters and unknown characters.
int32 break_id_ = -1;
int32 unknown_id_ = -1;
};
SyntaxNetCharacterSequenceExtractor::SyntaxNetCharacterSequenceExtractor()
: TermMapSequenceExtractor("char-map") {}
tensorflow::Status SyntaxNetCharacterSequenceExtractor::ParseFml(
const string &fml, int *min_frequency, int *max_num_terms) {
return ParseTermMapFml(fml, {"char-input", "text-char"}, min_frequency,
max_num_terms);
}
bool SyntaxNetCharacterSequenceExtractor::Supports(
const FixedFeatureChannel &channel,
const ComponentSpec &component_spec) const {
TransitionSystemTraits traits(component_spec);
int unused_min_frequency = 0;
int unused_max_num_terms = 0;
const tensorflow::Status parse_fml_status =
ParseFml(channel.fml(), &unused_min_frequency, &unused_max_num_terms);
return TermMapSequenceExtractor::SupportsTermMap(channel, component_spec) &&
parse_fml_status.ok() &&
component_spec.backend().registered_name() == "SyntaxNetComponent" &&
traits.is_sequential && traits.is_character_scale;
}
tensorflow::Status SyntaxNetCharacterSequenceExtractor::Initialize(
const FixedFeatureChannel &channel, const ComponentSpec &component_spec) {
int min_frequency = 0;
int max_num_terms = 0;
TF_RETURN_IF_ERROR(ParseFml(channel.fml(), &min_frequency, &max_num_terms));
TF_RETURN_IF_ERROR(TermMapSequenceExtractor::InitializeTermMap(
channel, component_spec, min_frequency, max_num_terms));
const int num_known = term_map().size();
break_id_ = num_known;
unknown_id_ = break_id_ + 1;
const int map_vocab_size = unknown_id_ + 1;
const int spec_vocab_size = channel.vocabulary_size();
if (map_vocab_size != spec_vocab_size) {
return tensorflow::errors::InvalidArgument(
"Character vocabulary size mismatch between term map (", map_vocab_size,
") and ComponentSpec (", spec_vocab_size, ")");
}
return tensorflow::Status::OK();
}
tensorflow::Status SyntaxNetCharacterSequenceExtractor::GetIds(
InputBatchCache *input, std::vector<int32> *ids) const {
ids->clear();
const std::vector<SyntaxNetSentence> &data =
*input->GetAs<SentenceInputBatch>()->data();
if (data.size() != 1) {
return tensorflow::errors::InvalidArgument("Non-singleton batch: got ",
data.size(), " elements");
}
const Sentence &sentence = *data[0].sentence();
if (sentence.token_size() == 0) return tensorflow::Status::OK();
const string &text = sentence.text();
const int start_byte = sentence.token(0).start();
const int end_byte = sentence.token(sentence.token_size() - 1).end();
const int num_bytes = end_byte - start_byte + 1;
string character;
UnicodeText unicode_text;
unicode_text.PointToUTF8(text.data() + start_byte, num_bytes);
const auto end = unicode_text.end();
for (auto it = unicode_text.begin(); it != end; ++it) {
character.assign(it.utf8_data(), it.utf8_length());
if (SegmenterUtils::IsBreakChar(character)) {
ids->push_back(break_id_);
} else {
ids->push_back(
term_map().Lookup(character.data(), character.size(), unknown_id_));
}
}
return tensorflow::Status::OK();
}
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR(SyntaxNetCharacterSequenceExtractor);
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "dragnn/runtime/test/term_map_helpers.h"
#include "syntaxnet/base.h"
#include "syntaxnet/sentence.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
constexpr char kResourceName[] = "char-map";
// Returns a ComponentSpec parsed from the |text| that contains a term map
// resource pointing at the |path|.
ComponentSpec MakeSpec(const string &text, const string &path) {
ComponentSpec component_spec;
CHECK(TextFormat::ParseFromString(text, &component_spec));
AddTermMapResource(kResourceName, path, &component_spec);
return component_spec;
}
// Returns a supported ComponentSpec that points at the term map in the |path|.
ComponentSpec MakeSupportedSpec(const string &path = "/dev/null") {
return MakeSpec(R"(transition_system { registered_name: 'char-shift-only' }
backend { registered_name: 'SyntaxNetComponent' }
fixed_feature {} # breaks hard-coded refs to channel 0
fixed_feature { size: 1 fml: 'char-input.text-char' })",
path);
}
// Returns a default sentence.
Sentence MakeSentence() {
Sentence sentence;
sentence.set_text("a bc def");
Token *token = sentence.add_token();
token->set_start(0);
token->set_end(sentence.text().size() - 1);
token->set_word(sentence.text());
return sentence;
}
// Tests that the extractor supports an appropriate spec.
TEST(SyntaxNetCharacterSequenceExtractorTest, Supported) {
const ComponentSpec component_spec = MakeSupportedSpec();
const FixedFeatureChannel &channel = component_spec.fixed_feature(1);
string name;
TF_ASSERT_OK(SequenceExtractor::Select(channel, component_spec, &name));
EXPECT_EQ(name, "SyntaxNetCharacterSequenceExtractor");
}
// Tests that the extractor requires the proper backend.
TEST(SyntaxNetCharacterSequenceExtractorTest, WrongBackend) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.mutable_backend()->set_registered_name("bad");
const FixedFeatureChannel &channel = component_spec.fixed_feature(1);
string name;
EXPECT_THAT(
SequenceExtractor::Select(channel, component_spec, &name),
test::IsErrorWithSubstr("No SequenceExtractor supports channel"));
}
// Tests that the extractor requires the proper transition system.
TEST(SyntaxNetCharacterSequenceExtractorTest, WrongTransitionSystem) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.mutable_transition_system()->set_registered_name("bad");
const FixedFeatureChannel &channel = component_spec.fixed_feature(1);
string name;
EXPECT_THAT(
SequenceExtractor::Select(channel, component_spec, &name),
test::IsErrorWithSubstr("No SequenceExtractor supports channel"));
}
// Tests that the extractor requires the proper FML.
TEST(SyntaxNetCharacterSequenceExtractorTest, WrongFml) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.mutable_fixed_feature(1)->set_fml("bad");
const FixedFeatureChannel &channel = component_spec.fixed_feature(1);
string name;
EXPECT_THAT(
SequenceExtractor::Select(channel, component_spec, &name),
test::IsErrorWithSubstr("No SequenceExtractor supports channel"));
}
// Tests that the extractor can be initialized and used to extract feature IDs.
TEST(SyntaxNetCharacterSequenceExtractorTest, InitializeAndGetIds) {
// Terms are sorted by descending frequency, so this ensures a=0, b=1, etc.
const string path =
WriteTermMap({{"a", 5}, {"b", 4}, {"c", 3}, {"d", 2}, {"e", 1}});
ComponentSpec component_spec = MakeSupportedSpec(path);
FixedFeatureChannel &channel = *component_spec.mutable_fixed_feature(1);
channel.set_vocabulary_size(7);
std::unique_ptr<SequenceExtractor> extractor;
TF_ASSERT_OK(SequenceExtractor::New("SyntaxNetCharacterSequenceExtractor",
channel, component_spec, &extractor));
const Sentence sentence = MakeSentence();
InputBatchCache input(sentence.SerializeAsString());
std::vector<int32> ids;
TF_ASSERT_OK(extractor->GetIds(&input, &ids));
// 0-4 = 'a' to 'e'
// 5 = break chars (whitespace)
// 6 = unknown chars (e.g., 'f')
const std::vector<int32> expected_ids = {0, 5, 1, 2, 5, 3, 4, 6};
EXPECT_EQ(ids, expected_ids);
}
// Tests that an empty term map works.
TEST(SyntaxNetCharacterSequenceExtractorTest, EmptyTermMap) {
const string path = WriteTermMap({});
ComponentSpec component_spec = MakeSupportedSpec(path);
FixedFeatureChannel &channel = *component_spec.mutable_fixed_feature(1);
channel.set_vocabulary_size(2);
std::unique_ptr<SequenceExtractor> extractor;
TF_ASSERT_OK(SequenceExtractor::New("SyntaxNetCharacterSequenceExtractor",
channel, component_spec, &extractor));
const Sentence sentence = MakeSentence();
InputBatchCache input(sentence.SerializeAsString());
std::vector<int32> ids = {1, 2, 3, 4}; // should be overwritten
TF_ASSERT_OK(extractor->GetIds(&input, &ids));
const std::vector<int32> expected_ids = {1, 0, 1, 1, 0, 1, 1, 1};
EXPECT_EQ(ids, expected_ids);
}
// Tests that GetIds() fails if the batch is the wrong size.
TEST(SyntaxNetCharacterSequenceExtractorTest, WrongBatchSize) {
const string path = WriteTermMap({});
ComponentSpec component_spec = MakeSupportedSpec(path);
FixedFeatureChannel &channel = *component_spec.mutable_fixed_feature(1);
channel.set_vocabulary_size(2);
std::unique_ptr<SequenceExtractor> extractor;
TF_ASSERT_OK(SequenceExtractor::New("SyntaxNetCharacterSequenceExtractor",
channel, component_spec, &extractor));
const Sentence sentence = MakeSentence();
const std::vector<string> data = {sentence.SerializeAsString(),
sentence.SerializeAsString()};
InputBatchCache input(data);
std::vector<int32> ids;
EXPECT_THAT(extractor->GetIds(&input, &ids),
test::IsErrorWithSubstr("Non-singleton batch: got 2 elements"));
}
// Tests that initialization fails if the vocabulary size does not match.
TEST(SyntaxNetCharacterSequenceExtractorTest, WrongVocabularySize) {
const string path = WriteTermMap({});
ComponentSpec component_spec = MakeSupportedSpec(path);
FixedFeatureChannel &channel = *component_spec.mutable_fixed_feature(1);
channel.set_vocabulary_size(1000);
std::unique_ptr<SequenceExtractor> extractor;
EXPECT_THAT(
SequenceExtractor::New("SyntaxNetCharacterSequenceExtractor",
channel, component_spec, &extractor),
test::IsErrorWithSubstr("Character vocabulary size mismatch between term "
"map (2) and ComponentSpec (1000)"));
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/io/sentence_input_batch.h"
#include "dragnn/io/syntaxnet_sentence.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/sequence_linker.h"
#include "dragnn/runtime/transition_system_traits.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "util/utf8/unilib_utf8_utils.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Focus character to link to in each token.
enum class Focus {
kFirst, // first character in token
kLast, // last character in token
};
// Translator to apply to the linked character index.
enum class Translator {
kIdentity, // direct identity link
kReversed, // reverse-order link
};
// Returns the LinkedFeatureChannel.fml for the |focus|.
string ChannelFml(Focus focus) {
switch (focus) {
case Focus::kFirst:
return "input.first-char-focus";
case Focus::kLast:
return "input.last-char-focus";
}
}
// Returns the LinkedFeatureChannel.source_translator for the |translator|.
string ChannelTranslator(Translator translator) {
switch (translator) {
case Translator::kIdentity:
return "identity";
case Translator::kReversed:
return "reverse-char";
}
}
// Returns the |focus| byte index for the |token|. The returned index must be
// within the span of the |token|.
int32 GetFocusByte(Focus focus, const Token &token) {
switch (focus) {
case Focus::kFirst:
return token.start();
case Focus::kLast:
return token.end();
}
}
// Applies the |translator| to the character |index| w.r.t. the |last_index| and
// returns the result.
int32 Translate(Translator translator, int32 last_index, int32 index) {
switch (translator) {
case Translator::kIdentity:
return index;
case Translator::kReversed:
return last_index - index;
}
}
// Translates links from tokens in the target layer to UTF-8 characters in the
// source layer. Templated on a |focus| and |translator| (see above).
template <Focus focus, Translator translator>
class SyntaxNetCharacterSequenceLinker : public SequenceLinker {
public:
// Implements SequenceLinker.
bool Supports(const LinkedFeatureChannel &channel,
const ComponentSpec &component_spec) const override;
tensorflow::Status Initialize(const LinkedFeatureChannel &channel,
const ComponentSpec &component_spec) override;
tensorflow::Status GetLinks(size_t source_num_steps, InputBatchCache *input,
std::vector<int32> *links) const override;
};
template <Focus focus, Translator translator>
bool SyntaxNetCharacterSequenceLinker<focus, translator>::Supports(
const LinkedFeatureChannel &channel,
const ComponentSpec &component_spec) const {
TransitionSystemTraits traits(component_spec);
return channel.fml() == ChannelFml(focus) &&
channel.source_translator() == ChannelTranslator(translator) &&
component_spec.backend().registered_name() == "SyntaxNetComponent" &&
traits.is_sequential && traits.is_token_scale;
}
template <Focus focus, Translator translator>
tensorflow::Status
SyntaxNetCharacterSequenceLinker<focus, translator>::Initialize(
const LinkedFeatureChannel &channel, const ComponentSpec &component_spec) {
return tensorflow::Status::OK();
}
template <Focus focus, Translator translator>
tensorflow::Status
SyntaxNetCharacterSequenceLinker<focus, translator>::GetLinks(
size_t source_num_steps, InputBatchCache *input,
std::vector<int32> *links) const {
const std::vector<SyntaxNetSentence> &batch =
*input->GetAs<SentenceInputBatch>()->data();
if (batch.size() != 1) {
return tensorflow::errors::InvalidArgument("Non-singleton batch: got ",
batch.size(), " elements");
}
const Sentence &sentence = *batch[0].sentence();
const int32 num_tokens = sentence.token_size();
links->resize(num_tokens);
if (num_tokens == 0) return tensorflow::Status::OK();
// Given the properties selected in Supports(), the number of source steps
// must match the number of UTF-8 characters. The last character index will
// be used in Translate().
const int32 last_char_index = static_cast<int32>(source_num_steps) - 1;
// [start,end) byte range of the text spanned by the sentence tokens.
const int32 start_byte = sentence.token(0).start();
const int32 end_byte = sentence.token(num_tokens - 1).end() + 1;
const char *const data = sentence.text().data();
if (UniLib::IsTrailByte(data[start_byte])) {
return tensorflow::errors::InvalidArgument(
"First token starts in the middle of a UTF-8 character: ",
sentence.token(0).ShortDebugString());
}
// Current character index and its past-the-end byte in the sentence.
int32 char_index = 0;
int32 char_end_byte = start_byte + UniLib::OneCharLen(data + start_byte);
// Current token index and its byte index.
int32 token_index = 0;
int32 token_byte = GetFocusByte(focus, sentence.token(0));
// Scan through the characters and tokens. For each token, we assign it the
// character whose byte range contains its focus byte.
while (true) {
// If the character ends after the token, then the token must lie within the
// character, or we would have consumed the token in a previous iteration.
if (char_end_byte > token_byte) {
(*links)[token_index] =
Translate(translator, last_char_index, char_index);
if (++token_index >= num_tokens) break;
token_byte = GetFocusByte(focus, sentence.token(token_index));
} else if (char_end_byte < end_byte) {
++char_index;
char_end_byte += UniLib::OneCharLen(data + char_end_byte);
} else {
break;
}
}
if (char_end_byte > end_byte) {
return tensorflow::errors::InvalidArgument(
"Last token ends in the middle of a UTF-8 character: ",
sentence.token(num_tokens - 1).ShortDebugString());
}
// Since GetFocusByte() always returns a byte index within the span of the
// token, the loop above must consume all tokens.
DCHECK_EQ(token_index, num_tokens);
return tensorflow::Status::OK();
}
using SyntaxNetFirstCharacterIdentitySequenceLinker =
SyntaxNetCharacterSequenceLinker<Focus::kFirst, Translator::kIdentity>;
using SyntaxNetFirstCharacterReversedSequenceLinker =
SyntaxNetCharacterSequenceLinker<Focus::kFirst, Translator::kReversed>;
using SyntaxNetLastCharacterIdentitySequenceLinker =
SyntaxNetCharacterSequenceLinker<Focus::kLast, Translator::kIdentity>;
using SyntaxNetLastCharacterReversedSequenceLinker =
SyntaxNetCharacterSequenceLinker<Focus::kLast, Translator::kReversed>;
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(
SyntaxNetFirstCharacterIdentitySequenceLinker);
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(
SyntaxNetFirstCharacterReversedSequenceLinker);
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(
SyntaxNetLastCharacterIdentitySequenceLinker);
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(
SyntaxNetLastCharacterReversedSequenceLinker);
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/sequence_linker.h"
#include "syntaxnet/base.h"
#include "syntaxnet/sentence.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
using ::testing::ElementsAre;
// Returns a ComponentSpec parsed from the |text|.
ComponentSpec ParseSpec(const string &text) {
ComponentSpec component_spec;
CHECK(TextFormat::ParseFromString(text, &component_spec));
return component_spec;
}
// Returns a ComponentSpec that some linker supports.
ComponentSpec MakeSupportedSpec() {
return ParseSpec(R"(
transition_system { registered_name:'shift-only' }
backend { registered_name:'SyntaxNetComponent' }
linked_feature { fml:'input.first-char-focus' source_translator:'identity' }
)");
}
// Returns a Sentence parsed from the |text|.
Sentence ParseSentence(const string &text) {
Sentence sentence;
CHECK(TextFormat::ParseFromString(text, &sentence));
return sentence;
}
// Returns a default sentence.
Sentence MakeSentence() {
return ParseSentence(R"(
text:'012345678901234567890123456789人1工神2经网¢络'
token { start:30 end:36 word:'人1工' }
token { start:37 end:43 word:'神2经' }
token { start:44 end:51 word:'网¢络' }
)");
}
// Number of UTF-8 characters in the default sentence.
constexpr int kNumChars = 9;
// Tests that the linker supports appropriate specs.
TEST(SyntaxNetCharacterSequenceLinkersTest, Supported) {
string name;
ComponentSpec component_spec = MakeSupportedSpec();
LinkedFeatureChannel &channel = *component_spec.mutable_linked_feature(0);
TF_ASSERT_OK(SequenceLinker::Select(channel, component_spec, &name));
EXPECT_EQ(name, "SyntaxNetFirstCharacterIdentitySequenceLinker");
channel.set_source_translator("reverse-char");
TF_ASSERT_OK(SequenceLinker::Select(channel, component_spec, &name));
EXPECT_EQ(name, "SyntaxNetFirstCharacterReversedSequenceLinker");
channel.set_fml("input.last-char-focus");
TF_ASSERT_OK(SequenceLinker::Select(channel, component_spec, &name));
EXPECT_EQ(name, "SyntaxNetLastCharacterReversedSequenceLinker");
channel.set_source_translator("identity");
TF_ASSERT_OK(SequenceLinker::Select(channel, component_spec, &name));
EXPECT_EQ(name, "SyntaxNetLastCharacterIdentitySequenceLinker");
}
// Tests that the linker requires the right transition system.
TEST(SyntaxNetCharacterSequenceLinkersTest, WrongTransitionSystem) {
string name;
ComponentSpec component_spec = MakeSupportedSpec();
const LinkedFeatureChannel &channel = component_spec.linked_feature(0);
component_spec.mutable_backend()->set_registered_name("bad");
EXPECT_THAT(SequenceLinker::Select(channel, component_spec, &name),
test::IsErrorWithSubstr("No SequenceLinker supports channel"));
}
// Tests that the linker requires the right FML.
TEST(SyntaxNetCharacterSequenceLinkersTest, WrongFml) {
string name;
ComponentSpec component_spec = MakeSupportedSpec();
LinkedFeatureChannel &channel = *component_spec.mutable_linked_feature(0);
channel.set_fml("bad");
EXPECT_THAT(SequenceLinker::Select(channel, component_spec, &name),
test::IsErrorWithSubstr("No SequenceLinker supports channel"));
}
// Tests that the linker requires the right translator.
TEST(SyntaxNetCharacterSequenceLinkersTest, WrongTranslator) {
string name;
ComponentSpec component_spec = MakeSupportedSpec();
LinkedFeatureChannel &channel = *component_spec.mutable_linked_feature(0);
channel.set_source_translator("bad");
EXPECT_THAT(SequenceLinker::Select(channel, component_spec, &name),
test::IsErrorWithSubstr("No SequenceLinker supports channel"));
}
// Tests that the linker requires the right backend.
TEST(SyntaxNetCharacterSequenceLinkersTest, WrongBackend) {
string name;
ComponentSpec component_spec = MakeSupportedSpec();
const LinkedFeatureChannel &channel = component_spec.linked_feature(0);
component_spec.mutable_backend()->set_registered_name("bad");
EXPECT_THAT(SequenceLinker::Select(channel, component_spec, &name),
test::IsErrorWithSubstr("No SequenceLinker supports channel"));
}
// Rig for testing GetLinks().
class SyntaxNetCharacterSequenceLinkersGetLinksTest : public ::testing::Test {
protected:
void SetUp() override {
// Initialize() doesn't look at the channel or spec, so use empty protos.
const ComponentSpec component_spec;
const LinkedFeatureChannel channel;
TF_ASSERT_OK(
SequenceLinker::New("SyntaxNetFirstCharacterIdentitySequenceLinker",
channel, component_spec, &first_identity_));
TF_ASSERT_OK(
SequenceLinker::New("SyntaxNetFirstCharacterReversedSequenceLinker",
channel, component_spec, &first_reversed_));
TF_ASSERT_OK(
SequenceLinker::New("SyntaxNetLastCharacterIdentitySequenceLinker",
channel, component_spec, &last_identity_));
TF_ASSERT_OK(
SequenceLinker::New("SyntaxNetLastCharacterReversedSequenceLinker",
channel, component_spec, &last_reversed_));
}
// Linkers in all four configurations.
std::unique_ptr<SequenceLinker> first_identity_;
std::unique_ptr<SequenceLinker> first_reversed_;
std::unique_ptr<SequenceLinker> last_identity_;
std::unique_ptr<SequenceLinker> last_reversed_;
};
// Tests that the linkers can extract links from the default sentence.
TEST_F(SyntaxNetCharacterSequenceLinkersGetLinksTest, DefaultSentence) {
const Sentence sentence = MakeSentence();
InputBatchCache input(sentence.SerializeAsString());
std::vector<int32> links = {123, 456, 789}; // gets overwritten
TF_ASSERT_OK(first_identity_->GetLinks(kNumChars, &input, &links));
EXPECT_THAT(links, ElementsAre(0, 3, 6));
TF_ASSERT_OK(first_reversed_->GetLinks(kNumChars, &input, &links));
EXPECT_THAT(links, ElementsAre(8, 5, 2));
TF_ASSERT_OK(last_identity_->GetLinks(kNumChars, &input, &links));
EXPECT_THAT(links, ElementsAre(2, 5, 8));
TF_ASSERT_OK(last_reversed_->GetLinks(kNumChars, &input, &links));
EXPECT_THAT(links, ElementsAre(6, 3, 0));
}
// Tests that the linkers can handle an empty sentence.
TEST_F(SyntaxNetCharacterSequenceLinkersGetLinksTest, EmptySentence) {
const Sentence sentence;
InputBatchCache input(sentence.SerializeAsString());
std::vector<int32> links;
TF_ASSERT_OK(first_identity_->GetLinks(kNumChars, &input, &links));
TF_ASSERT_OK(first_reversed_->GetLinks(kNumChars, &input, &links));
TF_ASSERT_OK(last_identity_->GetLinks(kNumChars, &input, &links));
TF_ASSERT_OK(last_reversed_->GetLinks(kNumChars, &input, &links));
}
// Tests that the linkers fail if the batch is not a singleton.
TEST_F(SyntaxNetCharacterSequenceLinkersGetLinksTest, NonSingleton) {
const Sentence sentence = MakeSentence();
const std::vector<string> data = {sentence.SerializeAsString(),
sentence.SerializeAsString()};
InputBatchCache input(data);
std::vector<int32> links;
EXPECT_THAT(first_identity_->GetLinks(kNumChars, &input, &links),
test::IsErrorWithSubstr("Non-singleton batch: got 2 elements"));
EXPECT_THAT(first_reversed_->GetLinks(kNumChars, &input, &links),
test::IsErrorWithSubstr("Non-singleton batch: got 2 elements"));
EXPECT_THAT(last_identity_->GetLinks(kNumChars, &input, &links),
test::IsErrorWithSubstr("Non-singleton batch: got 2 elements"));
EXPECT_THAT(last_reversed_->GetLinks(kNumChars, &input, &links),
test::IsErrorWithSubstr("Non-singleton batch: got 2 elements"));
}
// Tests that the linkers fail if the first token starts in the middle of a
// UTF-8 character.
TEST_F(SyntaxNetCharacterSequenceLinkersGetLinksTest, FirstTokenStartsWrong) {
Sentence sentence = MakeSentence();
sentence.mutable_token(0)->set_start(31);
InputBatchCache input(sentence.SerializeAsString());
std::vector<int32> links;
EXPECT_THAT(first_identity_->GetLinks(kNumChars, &input, &links),
test::IsErrorWithSubstr(
"First token starts in the middle of a UTF-8 character"));
EXPECT_THAT(first_reversed_->GetLinks(kNumChars, &input, &links),
test::IsErrorWithSubstr(
"First token starts in the middle of a UTF-8 character"));
EXPECT_THAT(last_identity_->GetLinks(kNumChars, &input, &links),
test::IsErrorWithSubstr(
"First token starts in the middle of a UTF-8 character"));
EXPECT_THAT(last_reversed_->GetLinks(kNumChars, &input, &links),
test::IsErrorWithSubstr(
"First token starts in the middle of a UTF-8 character"));
}
// Tests that the linkers fail if the last token ends in the middle of a UTF-8
// character.
TEST_F(SyntaxNetCharacterSequenceLinkersGetLinksTest, LastTokenEndsWrong) {
Sentence sentence = MakeSentence();
sentence.mutable_token(2)->set_end(45);
InputBatchCache input(sentence.SerializeAsString());
std::vector<int32> links;
EXPECT_THAT(first_identity_->GetLinks(kNumChars, &input, &links),
test::IsErrorWithSubstr(
"Last token ends in the middle of a UTF-8 character"));
EXPECT_THAT(first_reversed_->GetLinks(kNumChars, &input, &links),
test::IsErrorWithSubstr(
"Last token ends in the middle of a UTF-8 character"));
EXPECT_THAT(last_identity_->GetLinks(kNumChars, &input, &links),
test::IsErrorWithSubstr(
"Last token ends in the middle of a UTF-8 character"));
EXPECT_THAT(last_reversed_->GetLinks(kNumChars, &input, &links),
test::IsErrorWithSubstr(
"Last token ends in the middle of a UTF-8 character"));
}
// Tests that the linkers can tolerate a sentence where the interior token byte
// offsets are wrong.
TEST_F(SyntaxNetCharacterSequenceLinkersGetLinksTest,
InteriorTokenBoundariesSlightlyWrong) {
Sentence sentence = MakeSentence();
sentence.mutable_token(0)->set_end(35);
sentence.mutable_token(1)->set_start(38);
sentence.mutable_token(1)->set_end(42);
sentence.mutable_token(2)->set_start(45);
InputBatchCache input(sentence.SerializeAsString());
std::vector<int32> links;
// The results should be the same as in the default sentence.
TF_ASSERT_OK(first_identity_->GetLinks(kNumChars, &input, &links));
EXPECT_THAT(links, ElementsAre(0, 3, 6));
TF_ASSERT_OK(first_reversed_->GetLinks(kNumChars, &input, &links));
EXPECT_THAT(links, ElementsAre(8, 5, 2));
TF_ASSERT_OK(last_identity_->GetLinks(kNumChars, &input, &links));
EXPECT_THAT(links, ElementsAre(2, 5, 8));
TF_ASSERT_OK(last_reversed_->GetLinks(kNumChars, &input, &links));
EXPECT_THAT(links, ElementsAre(6, 3, 0));
}
// As above, but places the token boundaries even further off.
TEST_F(SyntaxNetCharacterSequenceLinkersGetLinksTest,
InteriorTokenBoundariesMostlyWrong) {
Sentence sentence = MakeSentence();
sentence.mutable_token(0)->set_end(34);
sentence.mutable_token(1)->set_start(39);
sentence.mutable_token(1)->set_end(41);
sentence.mutable_token(2)->set_start(46);
InputBatchCache input(sentence.SerializeAsString());
std::vector<int32> links;
// The results should be the same as in the default sentence.
TF_ASSERT_OK(first_identity_->GetLinks(kNumChars, &input, &links));
EXPECT_THAT(links, ElementsAre(0, 3, 6));
TF_ASSERT_OK(first_reversed_->GetLinks(kNumChars, &input, &links));
EXPECT_THAT(links, ElementsAre(8, 5, 2));
TF_ASSERT_OK(last_identity_->GetLinks(kNumChars, &input, &links));
EXPECT_THAT(links, ElementsAre(2, 5, 8));
TF_ASSERT_OK(last_reversed_->GetLinks(kNumChars, &input, &links));
EXPECT_THAT(links, ElementsAre(6, 3, 0));
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <string>
#include <vector>
#include "dragnn/core/compute_session.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/io/sentence_input_batch.h"
#include "dragnn/io/syntaxnet_sentence.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/head_selection_component_base.h"
#include "dragnn/runtime/session_state.h"
#include "syntaxnet/base.h"
#include "syntaxnet/sentence.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Selects heads for SyntaxNetComponent batches.
class SyntaxNetHeadSelectionComponent : public HeadSelectionComponentBase {
public:
SyntaxNetHeadSelectionComponent()
: HeadSelectionComponentBase("SyntaxNetHeadSelectionComponent",
"SyntaxNetComponent") {}
// Implements Component.
tensorflow::Status Evaluate(SessionState *session_state,
ComputeSession *compute_session,
ComponentTrace *component_trace) const override;
};
tensorflow::Status SyntaxNetHeadSelectionComponent::Evaluate(
SessionState *session_state, ComputeSession *compute_session,
ComponentTrace *component_trace) const {
InputBatchCache *input = compute_session->GetInputBatchCache();
if (input == nullptr) {
return tensorflow::errors::InvalidArgument("Null input batch");
}
const std::vector<SyntaxNetSentence> &data =
*input->GetAs<SentenceInputBatch>()->data();
if (data.size() != 1) {
return tensorflow::errors::InvalidArgument("Non-singleton batch: got ",
data.size(), " elements");
}
const std::vector<int> &heads = ComputeHeads(session_state);
Sentence *sentence = data[0].sentence();
if (heads.size() != sentence->token_size()) {
return tensorflow::errors::InvalidArgument(
"Sentence size mismatch: expected ", heads.size(), " tokens but got ",
sentence->token_size());
}
int token_index = 0;
for (const int head : heads) {
Token *token = sentence->mutable_token(token_index++);
if (head == -1) {
token->clear_head();
} else {
token->set_head(head);
}
}
return tensorflow::Status::OK();
}
DRAGNN_RUNTIME_REGISTER_COMPONENT(SyntaxNetHeadSelectionComponent);
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/compute_session.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/io/sentence_input_batch.h"
#include "dragnn/io/syntaxnet_sentence.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/sentence.pb.h"
#include <gmock/gmock.h>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
using ::testing::Return;
constexpr char kPreviousComponentName[] = "previous_component";
constexpr char kAdjacencyLayerName[] = "adjacency_layer";
// Returns a ComponentSpec that works with the head selection component.
ComponentSpec MakeGoodSpec() {
ComponentSpec component_spec;
component_spec.mutable_component_builder()->set_registered_name(
"SyntaxNetHeadSelectionComponent");
component_spec.mutable_backend()->set_registered_name("SyntaxNetComponent");
component_spec.mutable_transition_system()->set_registered_name("heads");
component_spec.mutable_network_unit()->set_registered_name("IdentityNetwork");
LinkedFeatureChannel *link = component_spec.add_linked_feature();
link->set_source_component(kPreviousComponentName);
link->set_source_layer(kAdjacencyLayerName);
return component_spec;
}
// Returns a sentence containing |num_tokens| tokens. All heads are set to
// self-loops, which are normally invalid, to ensure that the head selector
// touches all tokens.
Sentence MakeSentence(int num_tokens) {
Sentence sentence;
for (int i = 0; i < num_tokens; ++i) {
Token *token = sentence.add_token();
token->set_start(0); // never used; set because required field
token->set_end(0); // never used; set because required field
token->set_word("foo"); // never used; set because required field
token->set_head(i);
}
return sentence;
}
class SyntaxNetHeadSelectionComponentTest : public NetworkTestBase {
protected:
// Initializes a parser head selection component from the |component_spec|,
// feeds it the |adjacency| matrix, and applies the resulting heads to the
// |sentence|. Returs non-OK on error.
tensorflow::Status Run(const ComponentSpec &component_spec,
const std::vector<std::vector<float>> &adjacency,
Sentence *sentence) {
AddComponent(kPreviousComponentName);
AddPairwiseLayer(kAdjacencyLayerName, 1);
std::unique_ptr<Component> component;
TF_RETURN_IF_ERROR(Component::CreateOrError(
"SyntaxNetHeadSelectionComponent", &component));
TF_RETURN_IF_ERROR(component->Initialize(component_spec, &variable_store_,
&network_state_manager_,
&extension_manager_));
network_states_.Reset(&network_state_manager_);
const int num_steps = adjacency.size();
StartComponent(num_steps);
MutableMatrix<float> adjacency_layer =
GetPairwiseLayer(kPreviousComponentName, kAdjacencyLayerName);
for (size_t target = 0; target < num_steps; ++target) {
for (size_t source = 0; source < num_steps; ++source) {
adjacency_layer.row(target)[source] = adjacency[target][source];
}
}
string data;
CHECK(sentence->SerializeToString(&data));
InputBatchCache input(data);
EXPECT_CALL(compute_session_, GetInputBatchCache())
.WillRepeatedly(Return(&input));
session_state_.extensions.Reset(&extension_manager_);
TF_RETURN_IF_ERROR(
component->Evaluate(&session_state_, &compute_session_, nullptr));
CHECK(sentence->ParseFromString(input.SerializedData()[0]));
return tensorflow::Status::OK();
}
};
// Tests the head selector on a single-token input.
TEST_F(SyntaxNetHeadSelectionComponentTest, ParseOneToken) {
const std::vector<std::vector<float>> adjacency = {{0.0}};
Sentence sentence = MakeSentence(1);
TF_ASSERT_OK(Run(MakeGoodSpec(), adjacency, &sentence));
EXPECT_FALSE(sentence.token(0).has_head());
}
// Tests the head selector on a two-token input.
TEST_F(SyntaxNetHeadSelectionComponentTest, ParseTwoTokens) {
// This adjacency matrix forms a cycle, not a tree, but it doesn't matter
// since the head selector is unstructured.
const std::vector<std::vector<float>> adjacency = {{0.0, 1.0}, //
{1.0, 0.0}};
Sentence sentence = MakeSentence(2);
TF_ASSERT_OK(Run(MakeGoodSpec(), adjacency, &sentence));
EXPECT_EQ(sentence.token(0).head(), 1);
EXPECT_EQ(sentence.token(1).head(), 0);
}
// Tests the head selector on a three-token input.
TEST_F(SyntaxNetHeadSelectionComponentTest, ParseThreeTokens) {
// This adjacency matrix forms a left-headed chain.
const std::vector<std::vector<float>> adjacency = {{1.0, 0.0, 0.0}, //
{1.0, 0.0, 0.0}, //
{0.0, 1.0, 0.0}};
Sentence sentence = MakeSentence(3);
TF_ASSERT_OK(Run(MakeGoodSpec(), adjacency, &sentence));
EXPECT_FALSE(sentence.token(0).has_head());
EXPECT_EQ(sentence.token(1).head(), 0);
EXPECT_EQ(sentence.token(2).head(), 1);
}
// Tests the head selector on a four-token input.
TEST_F(SyntaxNetHeadSelectionComponentTest, ParseFourTokens) {
// This adjacency matrix forms a right-headed chain.
const std::vector<std::vector<float>> adjacency = {{0.0, 1.0, 0.0, 0.0}, //
{0.0, 0.0, 1.0, 0.0}, //
{0.0, 0.0, 0.0, 1.0}, //
{0.0, 0.0, 0.0, 1.0}};
Sentence sentence = MakeSentence(4);
TF_ASSERT_OK(Run(MakeGoodSpec(), adjacency, &sentence));
EXPECT_EQ(sentence.token(0).head(), 1);
EXPECT_EQ(sentence.token(1).head(), 2);
EXPECT_EQ(sentence.token(2).head(), 3);
EXPECT_FALSE(sentence.token(3).has_head());
}
// Tests that the component supports the good spec.
TEST_F(SyntaxNetHeadSelectionComponentTest, Supported) {
const ComponentSpec component_spec = MakeGoodSpec();
string name;
TF_ASSERT_OK(Component::Select(component_spec, &name));
EXPECT_EQ(name, "SyntaxNetHeadSelectionComponent");
}
// Tests that the component requires the proper backend.
TEST_F(SyntaxNetHeadSelectionComponentTest, WrongComponentBuilder) {
ComponentSpec component_spec = MakeGoodSpec();
component_spec.mutable_component_builder()->set_registered_name("bad");
string name;
EXPECT_THAT(
Component::Select(component_spec, &name),
test::IsErrorWithSubstr("Could not find a best spec for component"));
}
// Tests that the component requires the proper backend.
TEST_F(SyntaxNetHeadSelectionComponentTest, WrongBackend) {
ComponentSpec component_spec = MakeGoodSpec();
component_spec.mutable_backend()->set_registered_name("bad");
string name;
EXPECT_THAT(
Component::Select(component_spec, &name),
test::IsErrorWithSubstr("Could not find a best spec for component"));
}
// Tests that Evaluate() fails if the batch is null.
TEST_F(SyntaxNetHeadSelectionComponentTest, NullBatch) {
std::unique_ptr<Component> component;
TF_ASSERT_OK(
Component::CreateOrError("SyntaxNetHeadSelectionComponent", &component));
EXPECT_CALL(compute_session_, GetInputBatchCache())
.WillRepeatedly(Return(nullptr));
EXPECT_THAT(component->Evaluate(&session_state_, &compute_session_, nullptr),
test::IsErrorWithSubstr("Null input batch"));
}
// Tests that Evaluate() fails if the batch is the wrong size.
TEST_F(SyntaxNetHeadSelectionComponentTest, WrongBatchSize) {
std::unique_ptr<Component> component;
TF_ASSERT_OK(
Component::CreateOrError("SyntaxNetHeadSelectionComponent", &component));
InputBatchCache input({MakeSentence(1).SerializeAsString(),
MakeSentence(2).SerializeAsString(),
MakeSentence(3).SerializeAsString(),
MakeSentence(4).SerializeAsString()});
EXPECT_CALL(compute_session_, GetInputBatchCache())
.WillRepeatedly(Return(&input));
EXPECT_THAT(component->Evaluate(&session_state_, &compute_session_, nullptr),
test::IsErrorWithSubstr("Non-singleton batch: got 4 elements"));
}
// Tests that Evaluate() fails if the adjacency matrix and sentence disagree on
// the number of tokens.
TEST_F(SyntaxNetHeadSelectionComponentTest, WrongNumTokens) {
const std::vector<std::vector<float>> adjacency = {{1.0, 0.0, 0.0, 0.0}, //
{0.0, 1.0, 0.0, 0.0}, //
{0.0, 0.0, 1.0, 0.0}, //
{0.0, 0.0, 0.0, 1.0}};
// 4-token adjacency matrix with 3-token sentence.
Sentence sentence = MakeSentence(3);
EXPECT_THAT(Run(MakeGoodSpec(), adjacency, &sentence),
test::IsErrorWithSubstr(
"Sentence size mismatch: expected 4 tokens but got 3"));
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <string>
#include <vector>
#include "dragnn/core/compute_session.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/io/sentence_input_batch.h"
#include "dragnn/io/syntaxnet_sentence.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/mst_solver_component_base.h"
#include "dragnn/runtime/session_state.h"
#include "syntaxnet/base.h"
#include "syntaxnet/sentence.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Selects heads for SyntaxNetComponent batches.
class SyntaxNetMstSolverComponent : public MstSolverComponentBase {
public:
SyntaxNetMstSolverComponent()
: MstSolverComponentBase("SyntaxNetMstSolverComponent",
"SyntaxNetComponent") {}
// Implements Component.
tensorflow::Status Evaluate(SessionState *session_state,
ComputeSession *compute_session,
ComponentTrace *component_trace) const override;
};
tensorflow::Status SyntaxNetMstSolverComponent::Evaluate(
SessionState *session_state, ComputeSession *compute_session,
ComponentTrace *component_trace) const {
InputBatchCache *input = compute_session->GetInputBatchCache();
if (input == nullptr) {
return tensorflow::errors::InvalidArgument("Null input batch");
}
const std::vector<SyntaxNetSentence> &data =
*input->GetAs<SentenceInputBatch>()->data();
if (data.size() != 1) {
return tensorflow::errors::InvalidArgument("Non-singleton batch: got ",
data.size(), " elements");
}
tensorflow::gtl::ArraySlice<Index> heads;
TF_RETURN_IF_ERROR(ComputeHeads(session_state, &heads));
Sentence *sentence = data[0].sentence();
if (heads.size() != sentence->token_size()) {
return tensorflow::errors::InvalidArgument(
"Sentence size mismatch: expected ", heads.size(), " tokens but got ",
sentence->token_size());
}
const int num_tokens = heads.size();
for (int modifier = 0; modifier < num_tokens; ++modifier) {
Token *token = sentence->mutable_token(modifier);
const int head = heads[modifier];
if (head == modifier) {
token->clear_head();
} else {
token->set_head(head);
}
}
return tensorflow::Status::OK();
}
DRAGNN_RUNTIME_REGISTER_COMPONENT(SyntaxNetMstSolverComponent);
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/compute_session.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/io/sentence_input_batch.h"
#include "dragnn/io/syntaxnet_sentence.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/sentence.pb.h"
#include <gmock/gmock.h>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
using ::testing::Return;
constexpr char kPreviousComponentName[] = "previous_component";
constexpr char kAdjacencyLayerName[] = "adjacency_layer";
// Returns a ComponentSpec that works with the head selection component.
ComponentSpec MakeGoodSpec() {
ComponentSpec component_spec;
component_spec.mutable_component_builder()->set_registered_name(
"SyntaxNetMstSolverComponent");
component_spec.mutable_backend()->set_registered_name("SyntaxNetComponent");
component_spec.mutable_transition_system()->set_registered_name("heads");
component_spec.mutable_network_unit()->set_registered_name(
"some.path.to.MstSolverNetwork");
LinkedFeatureChannel *link = component_spec.add_linked_feature();
link->set_source_component(kPreviousComponentName);
link->set_source_layer(kAdjacencyLayerName);
return component_spec;
}
// Returns a sentence containing |num_tokens| tokens. All heads are set to
// self-loops, which are normally invalid, to ensure that the head selector
// touches all tokens.
Sentence MakeSentence(int num_tokens) {
Sentence sentence;
for (int i = 0; i < num_tokens; ++i) {
Token *token = sentence.add_token();
token->set_start(0); // never used; set because required field
token->set_end(0); // never used; set because required field
token->set_word("foo"); // never used; set because required field
token->set_head(i);
}
return sentence;
}
class SyntaxNetMstSolverComponentTest : public NetworkTestBase {
protected:
// Initializes a parser head selection component from the |component_spec|,
// feeds it the |adjacency| matrix, and applies the resulting heads to the
// |sentence|. Returs non-OK on error.
tensorflow::Status Run(const ComponentSpec &component_spec,
const std::vector<std::vector<float>> &adjacency,
Sentence *sentence) {
AddComponent(kPreviousComponentName);
AddPairwiseLayer(kAdjacencyLayerName, 1);
std::unique_ptr<Component> component;
TF_RETURN_IF_ERROR(Component::CreateOrError(
"SyntaxNetMstSolverComponent", &component));
TF_RETURN_IF_ERROR(component->Initialize(component_spec, &variable_store_,
&network_state_manager_,
&extension_manager_));
network_states_.Reset(&network_state_manager_);
const int num_steps = adjacency.size();
StartComponent(num_steps);
MutableMatrix<float> adjacency_layer =
GetPairwiseLayer(kPreviousComponentName, kAdjacencyLayerName);
for (size_t target = 0; target < num_steps; ++target) {
for (size_t source = 0; source < num_steps; ++source) {
adjacency_layer.row(target)[source] = adjacency[target][source];
}
}
string data;
CHECK(sentence->SerializeToString(&data));
InputBatchCache input(data);
EXPECT_CALL(compute_session_, GetInputBatchCache())
.WillRepeatedly(Return(&input));
session_state_.extensions.Reset(&extension_manager_);
TF_RETURN_IF_ERROR(
component->Evaluate(&session_state_, &compute_session_, nullptr));
CHECK(sentence->ParseFromString(input.SerializedData()[0]));
return tensorflow::Status::OK();
}
};
// Tests the head selector on a single-token input.
TEST_F(SyntaxNetMstSolverComponentTest, ParseOneToken) {
const std::vector<std::vector<float>> adjacency = {{0.0}};
Sentence sentence = MakeSentence(1);
TF_ASSERT_OK(Run(MakeGoodSpec(), adjacency, &sentence));
EXPECT_FALSE(sentence.token(0).has_head());
}
// Tests the head selector on a two-token input.
TEST_F(SyntaxNetMstSolverComponentTest, ParseTwoTokens) {
const std::vector<std::vector<float>> adjacency = {{0.0, 1.0}, //
{0.9, 1.0}};
Sentence sentence = MakeSentence(2);
TF_ASSERT_OK(Run(MakeGoodSpec(), adjacency, &sentence));
EXPECT_EQ(sentence.token(0).head(), 1);
EXPECT_EQ(sentence.token(1).head(), -1);
}
// Tests the head selector on a three-token input.
TEST_F(SyntaxNetMstSolverComponentTest, ParseThreeTokens) {
// This adjacency matrix forms a left-headed chain.
const std::vector<std::vector<float>> adjacency = {{1.0, 0.0, 0.0}, //
{1.0, 0.0, 0.0}, //
{0.0, 1.0, 0.0}};
Sentence sentence = MakeSentence(3);
TF_ASSERT_OK(Run(MakeGoodSpec(), adjacency, &sentence));
EXPECT_FALSE(sentence.token(0).has_head());
EXPECT_EQ(sentence.token(1).head(), 0);
EXPECT_EQ(sentence.token(2).head(), 1);
}
// Tests the head selector on a four-token input.
TEST_F(SyntaxNetMstSolverComponentTest, ParseFourTokens) {
// This adjacency matrix forms a right-headed chain.
const std::vector<std::vector<float>> adjacency = {{0.0, 1.0, 0.0, 0.0}, //
{0.0, 0.0, 1.0, 0.0}, //
{0.0, 0.0, 0.0, 1.0}, //
{0.0, 0.0, 0.0, 1.0}};
Sentence sentence = MakeSentence(4);
TF_ASSERT_OK(Run(MakeGoodSpec(), adjacency, &sentence));
EXPECT_EQ(sentence.token(0).head(), 1);
EXPECT_EQ(sentence.token(1).head(), 2);
EXPECT_EQ(sentence.token(2).head(), 3);
EXPECT_FALSE(sentence.token(3).has_head());
}
// Tests that the component supports the good spec.
TEST_F(SyntaxNetMstSolverComponentTest, Supported) {
const ComponentSpec component_spec = MakeGoodSpec();
string name;
TF_ASSERT_OK(Component::Select(component_spec, &name));
EXPECT_EQ(name, "SyntaxNetMstSolverComponent");
}
// Tests that the component requires the proper backend.
TEST_F(SyntaxNetMstSolverComponentTest, WrongComponentBuilder) {
ComponentSpec component_spec = MakeGoodSpec();
component_spec.mutable_component_builder()->set_registered_name("bad");
string name;
EXPECT_THAT(
Component::Select(component_spec, &name),
test::IsErrorWithSubstr("Could not find a best spec for component"));
}
// Tests that the component requires the proper backend.
TEST_F(SyntaxNetMstSolverComponentTest, WrongBackend) {
ComponentSpec component_spec = MakeGoodSpec();
component_spec.mutable_backend()->set_registered_name("bad");
string name;
EXPECT_THAT(
Component::Select(component_spec, &name),
test::IsErrorWithSubstr("Could not find a best spec for component"));
}
// Tests that Evaluate() fails if the batch is null.
TEST_F(SyntaxNetMstSolverComponentTest, NullBatch) {
std::unique_ptr<Component> component;
TF_ASSERT_OK(
Component::CreateOrError("SyntaxNetMstSolverComponent", &component));
EXPECT_CALL(compute_session_, GetInputBatchCache())
.WillRepeatedly(Return(nullptr));
EXPECT_THAT(component->Evaluate(&session_state_, &compute_session_, nullptr),
test::IsErrorWithSubstr("Null input batch"));
}
// Tests that Evaluate() fails if the batch is the wrong size.
TEST_F(SyntaxNetMstSolverComponentTest, WrongBatchSize) {
std::unique_ptr<Component> component;
TF_ASSERT_OK(
Component::CreateOrError("SyntaxNetMstSolverComponent", &component));
InputBatchCache input({MakeSentence(1).SerializeAsString(),
MakeSentence(2).SerializeAsString(),
MakeSentence(3).SerializeAsString(),
MakeSentence(4).SerializeAsString()});
EXPECT_CALL(compute_session_, GetInputBatchCache())
.WillRepeatedly(Return(&input));
EXPECT_THAT(component->Evaluate(&session_state_, &compute_session_, nullptr),
test::IsErrorWithSubstr("Non-singleton batch: got 4 elements"));
}
// Tests that Evaluate() fails if the adjacency matrix and sentence disagree on
// the number of tokens.
TEST_F(SyntaxNetMstSolverComponentTest, WrongNumTokens) {
const std::vector<std::vector<float>> adjacency = {{1.0, 0.0, 0.0, 0.0}, //
{0.0, 1.0, 0.0, 0.0}, //
{0.0, 0.0, 1.0, 0.0}, //
{0.0, 0.0, 0.0, 1.0}};
// 4-token adjacency matrix with 3-token sentence.
Sentence sentence = MakeSentence(3);
EXPECT_THAT(Run(MakeGoodSpec(), adjacency, &sentence),
test::IsErrorWithSubstr(
"Sentence size mismatch: expected 4 tokens but got 3"));
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment