Commit a4bb31d0 authored by Terry Koo's avatar Terry Koo
Browse files

Export @195097388.

parent dea7ecf6
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/network_unit.h"
#include <vector>
#include "tensorflow/core/lib/strings/str_util.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
string NetworkUnit::GetClassName(const ComponentSpec &component_spec) {
// The Python registration API is based on (relative) module paths, such as
// "some.module.FooNetwork". Therefore, we discard the module path prefix and
// use only the final segment, which is the subclass name.
const std::vector<string> segments = tensorflow::str_util::Split(
component_spec.network_unit().registered_name(), ".");
CHECK_GT(segments.size(), 0) << "No network unit name for component spec: "
<< component_spec.ShortDebugString();
return segments.back();
}
} // namespace runtime
} // namespace dragnn
REGISTER_SYNTAXNET_CLASS_REGISTRY("DRAGNN Runtime Network Unit",
dragnn::runtime::NetworkUnit);
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_NETWORK_UNIT_H_
#define DRAGNN_RUNTIME_NETWORK_UNIT_H_
#include <stddef.h>
#include <string>
#include "dragnn/core/compute_session.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "syntaxnet/registry.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Interface for network units for sequential inference.
class NetworkUnit : public RegisterableClass<NetworkUnit> {
public:
NetworkUnit(const NetworkUnit &that) = delete;
NetworkUnit &operator=(const NetworkUnit &that) = delete;
virtual ~NetworkUnit() = default;
// Returns the network unit class name specified in the |component_spec|.
static string GetClassName(const ComponentSpec &component_spec);
// Initializes this to the configuration in the |component_spec|. Retrieves
// pre-trained variables from the |variable_store|, which must outlive this.
// Adds layers and local operands to the |network_state_manager|, which must
// be positioned at the current component. Requests SessionState extensions
// from the |extension_manager|. On error, returns non-OK.
virtual tensorflow::Status Initialize(
const ComponentSpec &component_spec, VariableStore *variable_store,
NetworkStateManager *network_state_manager,
ExtensionManager *extension_manager) = 0;
// Returns the name of the layer that contains classification logits, or an
// empty string if this does not produce logits. Requires that Initialize()
// was called.
virtual string GetLogitsName() const = 0;
// Evaluates this network unit on the |session_state| and |compute_session|.
// Requires that:
// * The network states in the |session_state| is positioned at the current
// component, which must have at least |step_index|+1 steps.
// * The same component in the |compute_session| must have traversed
// |step_index| transitions.
// * Initialize() was called.
// On error, returns non-OK.
virtual tensorflow::Status Evaluate(
size_t step_index, SessionState *session_state,
ComputeSession *compute_session) const = 0;
protected:
NetworkUnit() = default;
private:
// Helps prevent use of the Create() method; use CreateOrError() instead.
using RegisterableClass<NetworkUnit>::Create;
};
} // namespace runtime
} // namespace dragnn
DECLARE_SYNTAXNET_CLASS_REGISTRY("DRAGNN Runtime Network Unit",
dragnn::runtime::NetworkUnit);
} // namespace syntaxnet
// Registers a subclass using its class name as a string.
#define DRAGNN_RUNTIME_REGISTER_NETWORK_UNIT(subclass) \
REGISTER_SYNTAXNET_CLASS_COMPONENT( \
::syntaxnet::dragnn::runtime::NetworkUnit, #subclass, subclass)
#endif // DRAGNN_RUNTIME_NETWORK_UNIT_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/network_unit_base.h"
#include <string.h>
#include "tensorflow/core/lib/core/errors.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Returns the sum of the dimensions of all channels in the |manager|. The
// EmbeddingManager template type should be either FixedEmbeddingManager or
// LinkedEmbeddingManager; note that both share the same API.
template <class EmbeddingManager>
size_t SumEmbeddingDimensions(const EmbeddingManager &manager) {
size_t sum = 0;
for (size_t i = 0; i < manager.num_channels(); ++i) {
sum += manager.embedding_dim(i);
}
return sum;
}
// Copies each channel of the |embeddings| into the region starting at |data|.
// Returns a pointer to one past the last element of the copied region. The
// Embeddings type should be FixedEmbeddings or LinkedEmbeddings; note that both
// have the same API.
//
// TODO(googleuser): Try a vectorized copy instead of memcpy(). Unclear whether
// we can do better, though. For one, the memcpy() implementation may already
// be vectorized. Also, while the input embeddings are aligned, the output is
// not; e.g., consider concatenating inputs with dims 7 and 9. This could be
// addressed by requiring that embedding dims are aligned, or by handling the
// unaligned prefix separately.
//
// TODO(googleuser): Consider alternatives for handling fixed feature channels
// with size>1. The least surprising approach is to concatenate the size>1
// embeddings inside FixedEmbeddings, so the channel IDs still correspond to
// positions in the ComponentSpec.fixed_feature list. However, that means the
// same embedding gets copied twice, once there and once here. Conversely, we
// could split the size>1 embeddings into separate channels, eliding a copy
// while obfuscating the channel IDs. IMO, separate channels seem better
// because very few bits of DRAGNN actually access individual channels, and I
// wrote many of those bits.
template <class Embeddings>
float *CopyEmbeddings(const Embeddings &embeddings, float *data) {
for (size_t i = 0; i < embeddings.num_embeddings(); ++i) {
const Vector<float> vector = embeddings.embedding(i);
memcpy(data, vector.data(), vector.size() * sizeof(float));
data += vector.size();
}
return data;
}
} // namespace
tensorflow::Status NetworkUnitBase::InitializeBase(
bool use_concatenated_input, const ComponentSpec &component_spec,
VariableStore *variable_store, NetworkStateManager *network_state_manager,
ExtensionManager *extension_manager) {
use_concatenated_input_ = use_concatenated_input;
num_actions_ = component_spec.num_actions();
TF_RETURN_IF_ERROR(fixed_embedding_manager_.Reset(
component_spec, variable_store, network_state_manager));
TF_RETURN_IF_ERROR(linked_embedding_manager_.Reset(
component_spec, variable_store, network_state_manager));
concatenated_input_dim_ = SumEmbeddingDimensions(fixed_embedding_manager_) +
SumEmbeddingDimensions(linked_embedding_manager_);
if (use_concatenated_input_) {
// If there is <= 1 input embedding, then the concatenation is trivial and
// we don't need a local vector; see ConcatenateInput().
const size_t num_embeddings = fixed_embedding_manager_.num_embeddings() +
linked_embedding_manager_.num_embeddings();
if (num_embeddings > 1) {
TF_RETURN_IF_ERROR(network_state_manager->AddLocal(
concatenated_input_dim_, &concatenated_input_handle_));
}
// Check that all fixed features are embedded.
for (size_t i = 0; i < fixed_embedding_manager_.num_channels(); ++i) {
if (!fixed_embedding_manager_.is_embedded(i)) {
return tensorflow::errors::InvalidArgument(
"Non-embedded fixed features cannot be concatenated");
}
}
}
extension_manager->GetShared(&fixed_embeddings_handle_);
extension_manager->GetShared(&linked_embeddings_handle_);
return tensorflow::Status::OK();
}
tensorflow::Status NetworkUnitBase::EvaluateBase(
SessionState *session_state, ComputeSession *compute_session,
Vector<float> *concatenated_input) const {
FixedEmbeddings &fixed_embeddings =
session_state->extensions.Get(fixed_embeddings_handle_);
LinkedEmbeddings &linked_embeddings =
session_state->extensions.Get(linked_embeddings_handle_);
TF_RETURN_IF_ERROR(fixed_embeddings.Reset(&fixed_embedding_manager_,
session_state->network_states,
compute_session));
TF_RETURN_IF_ERROR(linked_embeddings.Reset(&linked_embedding_manager_,
session_state->network_states,
compute_session));
if (use_concatenated_input_ && concatenated_input != nullptr) {
*concatenated_input = ConcatenateInput(session_state);
}
return tensorflow::Status::OK();
}
Vector<float> NetworkUnitBase::ConcatenateInput(
SessionState *session_state) const {
DCHECK(use_concatenated_input_);
const FixedEmbeddings &fixed_embeddings =
session_state->extensions.Get(fixed_embeddings_handle_);
const LinkedEmbeddings &linked_embeddings =
session_state->extensions.Get(linked_embeddings_handle_);
const size_t num_embeddings =
fixed_embeddings.num_embeddings() + linked_embeddings.num_embeddings();
// Special cases where no actual concatenation is required.
if (num_embeddings == 0) return {};
if (num_embeddings == 1) {
return fixed_embeddings.num_embeddings() > 0
? fixed_embeddings.embedding(0)
: linked_embeddings.embedding(0);
}
// General case; concatenate into a local vector. The ordering of embeddings
// must be exactly the same as in the Python codebase, which is:
// 1. Fixed embeddings before linked embeddings (see get_input_tensor() in
// network_units.py).
// 2. In each type, ordered as listed in ComponentSpec.fixed/linked_feature
// (see DynamicComponentBuilder._feedforward_unit() in component.py).
//
// Since FixedEmbeddings and LinkedEmbeddings already follow the order defined
// in the ComponentSpec, it suffices to append each fixed embedding, then each
// linked embedding.
const MutableVector<float> concatenation =
session_state->network_states.GetLocal(concatenated_input_handle_);
float *data = concatenation.data();
data = CopyEmbeddings(fixed_embeddings, data);
data = CopyEmbeddings(linked_embeddings, data);
DCHECK_EQ(data, concatenation.end());
return Vector<float>(concatenation);
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_NETWORK_UNIT_BASE_H_
#define DRAGNN_RUNTIME_NETWORK_UNIT_BASE_H_
#include <stddef.h>
#include <utility>
#include <vector>
#include "dragnn/core/compute_session.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/fixed_embeddings.h"
#include "dragnn/runtime/linked_embeddings.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/network_unit.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// A base class for network units that provides common functionality, analogous
// to NetworkUnitInterface.__init__() in network_units.py. Specifically, this
// class manages and builds input embeddings and, as an convenience, optionally
// concatenates the input embeddings into a single vector.
//
// Since recurrent layers are both outputs and inputs, they complicate network
// unit initialization. In particular, the linked embeddings cannot be set up
// until the charateristics of all recurrently-accessible layers are known. On
// the other hand, some layers cannot be initialized until all inputs, including
// the linked embeddings, are set up. For example, the IdentityNetwork outputs
// a layer whose dimension is the sum of all input dimensions.
//
// To accommodate recurrent layers, network unit initialization is organized
// into three phases:
// 1. (Subclass) Initialize all recurrently-accessible layers.
// 2. (This class) Initialize embedding managers and other common state.
// 3. (Subclass) Initialize any non-recurrent layers.
//
// Concretely, the subclass's Initialize() should first add recurrent layers,
// then call InitializeBase(), and finally finish initializing. Evaluation is
// simpler: the subclass's Evaluate() may call EvaluateBase() at any time.
//
// Note: Network unit initialization is similarly interleaved between base and
// subclasses in the Python codebase; see NetworkUnitInterface.get_layer_size()
// and the "init_layers" argument to NetworkUnitInterface.__init__().
class NetworkUnitBase : public NetworkUnit {
public:
// Initializes common state as configured in the |component_spec|. Retrieves
// pre-trained embedding matrices from the |variable_store|. Looks up linked
// embeddings in the |network_state_manager|, which must contain all recurrent
// layers. Requests any required extensions from the |extension_manager|. If
// |use_concatenated_input| is true, prepares to concatenate input embeddings
// in EvaluateBase(). On error, returns non-OK.
tensorflow::Status InitializeBase(bool use_concatenated_input,
const ComponentSpec &component_spec,
VariableStore *variable_store,
NetworkStateManager *network_state_manager,
ExtensionManager *extension_manager);
// Resets the fixed and linked embeddings in the |session_state| using its
// network states and the |compute_session|. Requires that InitializeBase()
// was called. If this was prepared for concatenation (see InitializeBase())
// and if |concatenated_input| is non-null, points it at the concatenation of
// the fixed and linked embeddings. Otherwise, no concatenation occurs. On
// error, returns non-OK.
tensorflow::Status EvaluateBase(SessionState *session_state,
ComputeSession *compute_session,
Vector<float> *concatenated_input) const;
// Accessors. All require that InitializeBase() was called.
const FixedEmbeddingManager &fixed_embedding_manager() const;
const LinkedEmbeddingManager &linked_embedding_manager() const;
size_t num_actions() const { return num_actions_; }
size_t concatenated_input_dim() const { return concatenated_input_dim_; }
private:
// Returns the concatenation of the fixed and linked embeddings in the
// |seesion_state|. Requires that |use_concatenated_input_| is true.
Vector<float> ConcatenateInput(SessionState *session_state) const;
// Managers for fixed and linked embeddings in this component.
FixedEmbeddingManager fixed_embedding_manager_;
LinkedEmbeddingManager linked_embedding_manager_;
// Fixed and linked embeddings.
SharedExtensionHandle<FixedEmbeddings> fixed_embeddings_handle_;
SharedExtensionHandle<LinkedEmbeddings> linked_embeddings_handle_;
// Number of actions supported by the transition system.
size_t num_actions_ = 0;
// Sum of dimensions of all fixed and linked embeddings.
size_t concatenated_input_dim_ = 0;
// Whether to concatenate the input embeddings.
bool use_concatenated_input_ = false;
// Handle of the vector that holds the concatenated input, or invalid if no
// concatenation is required.
LocalVectorHandle<float> concatenated_input_handle_;
};
// Implementation details below.
inline const FixedEmbeddingManager &NetworkUnitBase::fixed_embedding_manager()
const {
return fixed_embedding_manager_;
}
inline const LinkedEmbeddingManager &NetworkUnitBase::linked_embedding_manager()
const {
return linked_embedding_manager_;
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_NETWORK_UNIT_BASE_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/network_unit_base.h"
#include <stddef.h>
#include <memory>
#include <string>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/fixed_embeddings.h"
#include "dragnn/runtime/linked_embeddings.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include <gmock/gmock.h>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
using ::testing::_;
using ::testing::Invoke;
using ::testing::Return;
// Dimensions of the layers in the network.
static constexpr size_t kPreviousDim = 77;
static constexpr size_t kRecurrentDim = 123;
// Contents of the layers in the network.
static constexpr float kPreviousValue = -2.75;
static constexpr float kRecurrentValue = 6.25;
// Number of steps taken in each component.
static constexpr size_t kNumSteps = 10;
// A trivial network unit that exposes the concatenated inputs. Note that
// NetworkUnitBase does not override the interface methods, so we need a
// concrete subclass for testing.
class FooNetwork : public NetworkUnitBase {
public:
void RequestConcatenation() { request_concatenation_ = true; }
void ProvideConcatenatedInput() { provide_concatenated_input_ = true; }
Vector<float> concatenated_input() const { return concatenated_input_; }
// Implements NetworkUnit.
tensorflow::Status Initialize(const ComponentSpec &component_spec,
VariableStore *variable_store,
NetworkStateManager *network_state_manager,
ExtensionManager *extension_manager) override {
TF_RETURN_IF_ERROR(network_state_manager->AddLayer(
"recurrent_layer", kRecurrentDim, &recurrent_handle_));
return InitializeBase(request_concatenation_, component_spec,
variable_store, network_state_manager,
extension_manager);
}
string GetLogitsName() const override { return ""; }
tensorflow::Status Evaluate(size_t unused_step_index,
SessionState *session_state,
ComputeSession *compute_session) const override {
return EvaluateBase(
session_state, compute_session,
provide_concatenated_input_ ? &concatenated_input_ : nullptr);
}
private:
bool request_concatenation_ = false;
bool provide_concatenated_input_ = false;
LayerHandle<float> recurrent_handle_;
mutable Vector<float> concatenated_input_; // Evaluate() sets this
};
class NetworkUnitBaseTest : public NetworkTestBase {
protected:
// Initializes the |network_unit_| based on the |component_spec_text| and
// evaluates it. On error, returns non-OK.
tensorflow::Status Run(const string &component_spec_text) {
ComponentSpec component_spec;
CHECK(TextFormat::ParseFromString(component_spec_text, &component_spec));
component_spec.set_name(kTestComponentName);
AddComponent("previous_component");
AddLayer("previous_layer", kPreviousDim);
AddComponent(kTestComponentName);
TF_RETURN_IF_ERROR(
network_unit_.Initialize(component_spec, &variable_store_,
&network_state_manager_, &extension_manager_));
// Create and populate the network states.
network_states_.Reset(&network_state_manager_);
StartComponent(kNumSteps);
StartComponent(kNumSteps);
FillLayer("previous_component", "previous_layer", kPreviousValue);
FillLayer(kTestComponentName, "recurrent_layer", kRecurrentValue);
session_state_.extensions.Reset(&extension_manager_);
// Neither FooNetwork nor NetworkUnitBase look at the step index, so use an
// arbitrary value.
return network_unit_.Evaluate(0, &session_state_, &compute_session_);
}
FooNetwork network_unit_;
std::vector<std::vector<float>> concatenated_inputs_;
};
// Tests that NetworkUnitBase produces an empty vector when concatenating and
// there are no input embeddings.
TEST_F(NetworkUnitBaseTest, ConcatenateNoInputs) {
network_unit_.RequestConcatenation();
network_unit_.ProvideConcatenatedInput();
TF_ASSERT_OK(Run(""));
EXPECT_EQ(network_unit_.fixed_embedding_manager().num_channels(), 0);
EXPECT_EQ(network_unit_.linked_embedding_manager().num_channels(), 0);
EXPECT_EQ(network_unit_.num_actions(), 0);
EXPECT_EQ(network_unit_.concatenated_input_dim(), 0);
EXPECT_TRUE(network_unit_.concatenated_input().empty());
}
// Tests that NetworkUnitBase produces a copy of the single input embedding when
// concatenating a single fixed channel.
TEST_F(NetworkUnitBaseTest, ConcatenateOneFixedChannel) {
const float kEmbedding = 1.5;
const float kFeature = 0.5;
const size_t kDim = 13;
const string kSpec = R"(num_actions: 42
fixed_feature {
vocabulary_size: 11
embedding_dim: 13
size: 1
})";
AddFixedEmbeddingMatrix(0, 11, kDim, kEmbedding);
EXPECT_CALL(compute_session_, GetInputFeatures(_, _, _, _, _))
.WillOnce(Invoke(ExtractFeatures(0, {{1, kFeature}})));
const float kValue = kEmbedding * kFeature;
network_unit_.RequestConcatenation();
network_unit_.ProvideConcatenatedInput();
TF_ASSERT_OK(Run(kSpec));
EXPECT_EQ(network_unit_.fixed_embedding_manager().num_channels(), 1);
EXPECT_EQ(network_unit_.linked_embedding_manager().num_channels(), 0);
EXPECT_EQ(network_unit_.num_actions(), 42);
EXPECT_EQ(network_unit_.concatenated_input_dim(), kDim);
ExpectVector(network_unit_.concatenated_input(),
network_unit_.concatenated_input_dim(), kValue);
}
// Tests that NetworkUnitBase does not concatenate if concatenation is requested
// and the concatenated input vector is not provided.
TEST_F(NetworkUnitBaseTest, ConcatenatedInputVectorNotProvided) {
const float kEmbedding = 1.5;
const float kFeature = 0.5;
const size_t kDim = 13;
const string kSpec = R"(num_actions: 37
fixed_feature {
vocabulary_size: 11
embedding_dim: 13
size: 1
})";
AddFixedEmbeddingMatrix(0, 11, kDim, kEmbedding);
EXPECT_CALL(compute_session_, GetInputFeatures(_, _, _, _, _))
.WillOnce(Invoke(ExtractFeatures(0, {{1, kFeature}})));
network_unit_.RequestConcatenation();
TF_ASSERT_OK(Run(kSpec));
// Embedding managers and other config is set up properly.
EXPECT_EQ(network_unit_.fixed_embedding_manager().num_channels(), 1);
EXPECT_EQ(network_unit_.linked_embedding_manager().num_channels(), 0);
EXPECT_EQ(network_unit_.num_actions(), 37);
EXPECT_EQ(network_unit_.concatenated_input_dim(), kDim);
// But the concatenation was not performed.
EXPECT_TRUE(network_unit_.concatenated_input().empty());
}
// As above, but with the converse condition: does not request concatenation,
// but does provide the concatenated input vector.
TEST_F(NetworkUnitBaseTest, ConcatenationNotRequested) {
const float kEmbedding = 1.5;
const float kFeature = 0.5;
const size_t kDim = 13;
const string kSpec = R"(num_actions: 31
fixed_feature {
vocabulary_size: 11
embedding_dim: 13
size: 1
})";
AddFixedEmbeddingMatrix(0, 11, kDim, kEmbedding);
EXPECT_CALL(compute_session_, GetInputFeatures(_, _, _, _, _))
.WillOnce(Invoke(ExtractFeatures(0, {{1, kFeature}})));
network_unit_.ProvideConcatenatedInput();
TF_ASSERT_OK(Run(kSpec));
// Embedding managers and other config is set up properly.
EXPECT_EQ(network_unit_.fixed_embedding_manager().num_channels(), 1);
EXPECT_EQ(network_unit_.linked_embedding_manager().num_channels(), 0);
EXPECT_EQ(network_unit_.num_actions(), 31);
EXPECT_EQ(network_unit_.concatenated_input_dim(), kDim);
// But the concatenation was not performed.
EXPECT_TRUE(network_unit_.concatenated_input().empty());
}
// Tests that NetworkUnitBase produces a copy of the single input embedding when
// concatenating a single linked channel.
TEST_F(NetworkUnitBaseTest, ConcatenateOneLinkedChannel) {
const string kSpec = R"(num_actions: 37
linked_feature {
embedding_dim: -1
source_component: 'previous_component'
source_layer: 'previous_layer'
size: 1
})";
EXPECT_CALL(compute_session_, GetTranslatedLinkFeatures(_, _))
.WillOnce(Invoke(ExtractLinks(0, {"step_idx: 5"})));
EXPECT_CALL(compute_session_, SourceComponentBeamSize(_, _))
.WillRepeatedly(Return(1));
network_unit_.RequestConcatenation();
network_unit_.ProvideConcatenatedInput();
TF_ASSERT_OK(Run(kSpec));
EXPECT_EQ(network_unit_.fixed_embedding_manager().num_channels(), 0);
EXPECT_EQ(network_unit_.linked_embedding_manager().num_channels(), 1);
EXPECT_EQ(network_unit_.num_actions(), 37);
EXPECT_EQ(network_unit_.concatenated_input_dim(), kPreviousDim);
ExpectVector(network_unit_.concatenated_input(),
network_unit_.concatenated_input_dim(), kPreviousValue);
}
// Tests that NetworkUnitBase concatenates a fixed and linked channel in that
// order.
TEST_F(NetworkUnitBaseTest, ConcatenateOneChannelOfEachType) {
const float kEmbedding = 1.25;
const float kFeature = 0.75;
const size_t kFixedDim = 13;
const string kSpec = R"(num_actions: 77
fixed_feature {
vocabulary_size: 11
embedding_dim: 13
size: 1
}
linked_feature {
embedding_dim: -1
source_component: 'previous_component'
source_layer: 'previous_layer'
size: 1
})";
AddFixedEmbeddingMatrix(0, 11, kFixedDim, kEmbedding);
EXPECT_CALL(compute_session_, GetInputFeatures(_, _, _, _, _))
.WillOnce(Invoke(ExtractFeatures(0, {{1, kFeature}})));
const float kFixedValue = kEmbedding * kFeature;
EXPECT_CALL(compute_session_, GetTranslatedLinkFeatures(_, _))
.WillOnce(Invoke(ExtractLinks(0, {"step_idx: 5"})));
EXPECT_CALL(compute_session_, SourceComponentBeamSize(_, _))
.WillRepeatedly(Return(1));
network_unit_.RequestConcatenation();
network_unit_.ProvideConcatenatedInput();
TF_ASSERT_OK(Run(kSpec));
EXPECT_EQ(network_unit_.fixed_embedding_manager().num_channels(), 1);
EXPECT_EQ(network_unit_.linked_embedding_manager().num_channels(), 1);
EXPECT_EQ(network_unit_.num_actions(), 77);
EXPECT_EQ(network_unit_.concatenated_input_dim(), kFixedDim + kPreviousDim);
// Check that each sub-segment is equal to one of the input embeddings.
const Vector<float> input = network_unit_.concatenated_input();
EXPECT_EQ(input.size(), network_unit_.concatenated_input_dim());
size_t index = 0;
size_t end = kFixedDim;
for (; index < end; ++index) EXPECT_EQ(input[index], kFixedValue);
end += kPreviousDim;
for (; index < end; ++index) EXPECT_EQ(input[index], kPreviousValue);
}
// Tests that NetworkUnitBase produces a properly-ordered concatenation of
// multiple fixed and linked channels, including a recurrent channel.
TEST_F(NetworkUnitBaseTest, ConcatenateMultipleChannelsOfEachType) {
const float kEmbedding0 = 1.25;
const float kEmbedding1 = -0.125;
const float kFeature0 = 0.75;
const float kFeature1 = -2.5;
const size_t kFixedDim0 = 13;
const size_t kFixedDim1 = 19;
const string kSpec = R"(num_actions: 99
fixed_feature {
vocabulary_size: 11
embedding_dim: 13
size: 1
}
fixed_feature {
vocabulary_size: 17
embedding_dim: 19
size: 1
}
linked_feature {
embedding_dim: -1
source_component: 'previous_component'
source_layer: 'previous_layer'
size: 1
}
linked_feature {
embedding_dim: -1
source_component: 'test_component'
source_layer: 'recurrent_layer'
size: 1
})";
AddFixedEmbeddingMatrix(0, 11, kFixedDim0, kEmbedding0);
AddFixedEmbeddingMatrix(1, 17, kFixedDim1, kEmbedding1);
EXPECT_CALL(compute_session_, GetInputFeatures(_, _, _, _, _))
.WillOnce(Invoke(ExtractFeatures(0, {{1, kFeature0}})))
.WillOnce(Invoke(ExtractFeatures(1, {{1, kFeature1}})));
const float kFixedValue0 = kEmbedding0 * kFeature0;
const float kFixedValue1 = kEmbedding1 * kFeature1;
EXPECT_CALL(compute_session_, GetTranslatedLinkFeatures(_, _))
.WillOnce(Invoke(ExtractLinks(0, {"step_idx: 5"})))
.WillOnce(Invoke(ExtractLinks(1, {"step_idx: 6"})));
EXPECT_CALL(compute_session_, SourceComponentBeamSize(_, _))
.WillRepeatedly(Return(1));
network_unit_.RequestConcatenation();
network_unit_.ProvideConcatenatedInput();
TF_ASSERT_OK(Run(kSpec));
EXPECT_EQ(network_unit_.fixed_embedding_manager().num_channels(), 2);
EXPECT_EQ(network_unit_.linked_embedding_manager().num_channels(), 2);
EXPECT_EQ(network_unit_.num_actions(), 99);
EXPECT_EQ(network_unit_.concatenated_input_dim(),
kFixedDim0 + kFixedDim1 + kPreviousDim + kRecurrentDim);
// Check that each sub-segment is equal to one of the input embeddings. For
// compatibility with the Python codebase, fixed channels must appear before
// linked channels, and among each type order follows the ComponentSpec.
const Vector<float> input = network_unit_.concatenated_input();
EXPECT_EQ(input.size(), network_unit_.concatenated_input_dim());
size_t index = 0;
size_t end = kFixedDim0;
for (; index < end; ++index) EXPECT_EQ(input[index], kFixedValue0);
end += kFixedDim1;
for (; index < end; ++index) EXPECT_EQ(input[index], kFixedValue1);
end += kPreviousDim;
for (; index < end; ++index) EXPECT_EQ(input[index], kPreviousValue);
end += kRecurrentDim;
for (; index < end; ++index) EXPECT_EQ(input[index], kRecurrentValue);
}
// Tests that NetworkUnitBase refuses to concatenate if there are non-embedded
// fixed embeddings.
TEST_F(NetworkUnitBaseTest, CannotConcatenateNonEmbeddedFixedFeatures) {
const string kBadSpec = R"(fixed_feature {
embedding_dim: -1
size: 1
})";
network_unit_.RequestConcatenation();
network_unit_.ProvideConcatenatedInput();
EXPECT_THAT(Run(kBadSpec),
test::IsErrorWithSubstr(
"Non-embedded fixed features cannot be concatenated"));
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/network_unit.h"
#include <memory>
#include <string>
#include "dragnn/core/compute_session.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Expects that the two pointers have the same address.
void ExpectSameAddress(const void *pointer1, const void *pointer2) {
EXPECT_EQ(pointer1, pointer2);
}
// A trivial implementation for tests.
class FooNetwork : public NetworkUnit {
public:
// Implements NetworkUnit.
tensorflow::Status Initialize(const ComponentSpec &component_spec,
VariableStore *variable_store,
NetworkStateManager *network_state_manager,
ExtensionManager *extension_manager) override {
return tensorflow::Status::OK();
}
string GetLogitsName() const override { return "foo_logits"; }
tensorflow::Status Evaluate(size_t step_index, SessionState *session_state,
ComputeSession *compute_session) const override {
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_NETWORK_UNIT(FooNetwork);
// Tests that a human-friendly error is produced for empty network units.
TEST(NetworkUnitTest, GetClassNameDegenerateName) {
ComponentSpec component_spec;
EXPECT_DEATH(NetworkUnit::GetClassName(component_spec),
"No network unit name for component spec");
}
// Tests that NetworkUnit::GetClassName() resolves names properly.
TEST(NetworkUnitTest, GetClassName) {
for (const string &registered_name :
{"FooNetwork",
"module.FooNetwork",
"some.long.path.to.module.FooNetwork"}) {
ComponentSpec component_spec;
component_spec.mutable_network_unit()->set_registered_name(registered_name);
EXPECT_EQ(NetworkUnit::GetClassName(component_spec), "FooNetwork");
}
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/operands.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
OperandHandle OperandManager::Add(const OperandSpec &spec) {
const size_t index = specs_.size();
specs_.push_back(spec);
switch (spec.type) {
case OperandType::kSingular:
handle_index_to_typed_index_.push_back(singular_spans_.size());
singular_spans_.emplace_back(singular_size_, spec.size);
singular_size_ += PadToAlignment(spec.size);
break;
case OperandType::kStepwise:
handle_index_to_typed_index_.push_back(stepwise_spans_.size());
stepwise_spans_.emplace_back(stepwise_stride_, spec.size);
stepwise_stride_ += PadToAlignment(spec.size);
break;
case OperandType::kPairwise:
handle_index_to_typed_index_.push_back(pairwise_sizes_.size());
pairwise_sizes_.push_back(spec.size);
break;
}
return OperandHandle(index);
}
void Operands::Reset(const OperandManager *manager,
size_t pre_allocate_num_steps) {
manager_ = manager;
handle_index_to_typed_index_ = manager_->handle_index_to_typed_index_;
stepwise_spans_ = manager_->stepwise_spans_;
stepwise_stride_ = manager_->stepwise_stride_;
pairwise_sizes_ = manager_->pairwise_sizes_;
// Allocate and parcel out singular operands.
singular_operands_.clear();
singular_operands_.reserve(manager_->singular_spans_.size());
singular_array_.Reserve(manager_->singular_size_);
char *data = singular_array_.view().data();
for (const auto &span : manager_->singular_spans_) {
singular_operands_.push_back(
MutableAlignedView(data + span.first, span.second));
}
// Pre-allocate and parcel out stepwise operands.
stepwise_operands_.clear();
stepwise_operands_.reserve(stepwise_spans_.size());
stepwise_array_.Reserve(stepwise_stride_ * pre_allocate_num_steps);
data = stepwise_array_.view().data();
for (const auto &span : stepwise_spans_) {
stepwise_operands_.push_back(MutableAlignedArea(
data + span.first, 0, span.second, stepwise_stride_));
}
// Create empty pairwise operands.
pairwise_operands_.clear();
pairwise_operands_.resize(pairwise_sizes_.size());
}
void Operands::AddSteps(size_t num_steps) {
AddStepwiseSteps(num_steps);
AddPairwiseSteps(num_steps);
}
void Operands::AddStepwiseSteps(size_t num_steps) {
if (stepwise_operands_.empty()) return;
// Make room for the new steps.
const size_t new_num_views = stepwise_operands_[0].num_views_ + num_steps;
const bool actually_reallocated =
stepwise_array_.Resize(new_num_views * stepwise_stride_);
// Update the base pointers for stepwise operands, if changed.
if (actually_reallocated) {
char *data = stepwise_array_.view().data();
for (size_t i = 0; i < stepwise_operands_.size(); ++i) {
stepwise_operands_[i].data_ = data + stepwise_spans_[i].first;
}
}
// Update the number of views in each stepwise operand.
for (MutableAlignedArea &operand : stepwise_operands_) {
operand.num_views_ = new_num_views;
}
}
void Operands::AddPairwiseSteps(size_t num_steps) {
if (pairwise_operands_.empty()) return;
const size_t new_num_steps = pairwise_operands_[0].num_views_ + num_steps;
// Set dimensions for each pairwise operand and accumulate their total stride.
size_t new_stride = 0;
for (size_t i = 0; i < pairwise_operands_.size(); ++i) {
const size_t new_view_size = new_num_steps * pairwise_sizes_[i];
pairwise_operands_[i].num_views_ = new_num_steps;
pairwise_operands_[i].view_size_ = new_view_size;
new_stride += PadToAlignment(new_view_size);
}
// Note that Reset() does not preserve the existing array and its contents.
// Although preserving existing data would be nice, it is complex because
// pairwise operands grow in both dimensions. In addition, users should be
// allocating pairwise operands in one shot for speed reasons, in which case
// there is no existing data anyways.
pairwise_array_.Reset(new_num_steps * new_stride);
// Set the new base pointer and stride on each pairwise operand.
char *data = pairwise_array_.view().data();
for (MutableAlignedArea &operand : pairwise_operands_) {
operand.data_ = data;
operand.view_stride_ = new_stride;
data += PadToAlignment(operand.view_size_);
}
DCHECK_EQ(data - pairwise_array_.view().data(), new_stride);
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Utils for declaring and allocating operands. An operand is made up of
// aligned byte arrays, and can be used as an input, output, or intermediate
// value in some computation.
#ifndef DRAGNN_RUNTIME_OPERANDS_H_
#define DRAGNN_RUNTIME_OPERANDS_H_
#include <stddef.h>
#include <stdint.h>
#include <string>
#include <utility>
#include <vector>
#include "dragnn/runtime/alignment.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Possible types of operands.
enum class OperandType {
// A single byte array. For example, an intermediate value that is computed
// once per transition step. Since it is not an output, the same storage
// could be reused across all steps.
kSingular,
// A sequence of identically-sized byte arrays, one per transition step. For
// example, a layer containing one activation vector per step.
kStepwise,
// A grid with one byte array for each pair of transition steps, including
// self pairings. The byte arrays are grouped and concatenated in "rows",
// forming one byte array per step. For example, if there are N steps and D
// bytes per pair, the operand would have N arrays of size N*D bytes. In a
// basic attention model with one "similarity" between pairs of steps, one
// might use a pairwise operand with D=sizeof(float). For best performance,
// use Operands::AddSteps() to allocate all steps at once when working with
// pairwise operands.
kPairwise,
};
// A specification of a operand.
struct OperandSpec {
// Creates a trivial specification.
OperandSpec() = default;
// Creates a specification with the |type| and |size|.
OperandSpec(OperandType type, size_t size) : type(type), size(size) {}
// Type of the operand.
OperandType type = OperandType::kSingular;
// Size of each aligned byte array in the operand.
size_t size = 0;
};
// An opaque handle to an operand.
class OperandHandle;
// A class that manages a set of operand specifications and associates each
// operand with a handle. Operand contents can be retrieved using these
// handles; see Operands below.
class OperandManager {
public:
// Creates an empty manager.
OperandManager() = default;
// Adds an operand configured according to the |spec| and returns its handle.
OperandHandle Add(const OperandSpec &spec);
// Accessors.
const OperandSpec &spec(OperandHandle handle) const;
private:
friend class Operands;
// Specification of each operand.
std::vector<OperandSpec> specs_;
// Mapping from the handle index of an operand to its index amongst operands
// of the same type.
std::vector<size_t> handle_index_to_typed_index_;
// Span of each singular operand, as a (start-offset,size) pair, relative to
// the byte array containing all singular operands.
std::vector<std::pair<size_t, size_t>> singular_spans_;
// Span of each stepwise operand, as a (start-offset,size) pair, relative to
// the byte array for each step.
std::vector<std::pair<size_t, size_t>> stepwise_spans_;
// Size of each pairwise operand.
std::vector<size_t> pairwise_sizes_;
// Number of bytes used by all singular operands, including alignment padding.
size_t singular_size_ = 0;
// Number of bytes used by all stepwise operands on each step, including
// alignment padding.
size_t stepwise_stride_ = 0;
};
// A set of operands. The structure of the operands is configured by an
// OperandManager, and operand contents can be accessed using the handles
// produced by the manager.
//
// Multiple Operands instances can share the same OperandManager. In addition,
// an Operands instance can be reused by repeatedly Reset()-ing it, potentially
// with different OperandManagers. Such reuse can reduce allocation overhead.
class Operands {
public:
// Creates an empty set.
Operands() = default;
// Resets this to the operands defined by the |manager|. The |manager| must
// live until this is destroyed or Reset() again, and should not be modified
// during that time. Stepwise and pairwise operands start with 0 steps; use
// AddStep() to extend them. Pre-allocates stepwise operands so that they
// will not be reallocated during the first |pre_allocate_num_steps| calls to
// AddStep(). Invalidates all previously-returned operands.
void Reset(const OperandManager *manager, size_t pre_allocate_num_steps);
// Extends stepwise and pairwise operands by one or more steps. Requires that
// Reset() was called. Invalidates any previously-returned views of stepwise
// and pairwise operands. Preserves data for pre-existing steps of stepwise
// operands, but not for pre-existing pairwise operands. In general, pairwise
// operands should be allocated in one shot, not incrementally.
void AddStep() { AddSteps(1); }
void AddSteps(size_t num_steps);
// Returns the singular operand associated with the |handle|. The returned
// view is invalidated by Reset().
MutableAlignedView GetSingular(OperandHandle handle) const;
// Returns the stepwise operand associated with the |handle|. The returned
// area is invalidated by Reset() and AddStep().
MutableAlignedArea GetStepwise(OperandHandle handle) const;
// Returns the pairwise operand associated with the |handle|. The returned
// area is invalidated by Reset() and AddStep().
MutableAlignedArea GetPairwise(OperandHandle handle) const;
private:
// Extends stepwise operands only; see AddSteps().
void AddStepwiseSteps(size_t num_steps);
// Extends pairwise operands only; see AddSteps().
void AddPairwiseSteps(size_t num_steps);
// Manager of the operands in this set.
const OperandManager *manager_ = nullptr;
// Cached members from the |manager_|.
tensorflow::gtl::ArraySlice<size_t> handle_index_to_typed_index_;
tensorflow::gtl::ArraySlice<std::pair<size_t, size_t>> stepwise_spans_;
size_t stepwise_stride_ = 0;
tensorflow::gtl::ArraySlice<size_t> pairwise_sizes_;
// Byte arrays holding operands of each type. Storage is separated because
// each type grows differently with the number of steps.
UniqueAlignedArray singular_array_;
UniqueAlignedArray stepwise_array_;
UniqueAlignedArray pairwise_array_;
// Lists of operands of each type.
std::vector<MutableAlignedView> singular_operands_;
std::vector<MutableAlignedArea> stepwise_operands_;
std::vector<MutableAlignedArea> pairwise_operands_;
};
// Implementation details below.
// An opaque handle to an operand.
class OperandHandle {
public:
// Creates an invalid handle.
OperandHandle() = default;
private:
friend class OperandManager;
friend class Operands;
// Creates a handle that points to the |index|.
explicit OperandHandle(size_t index) : index_(index) {}
// Index of the operand in its manager.
size_t index_ = SIZE_MAX;
};
inline const OperandSpec &OperandManager::spec(OperandHandle handle) const {
return specs_[handle.index_];
}
inline MutableAlignedView Operands::GetSingular(OperandHandle handle) const {
DCHECK(manager_->spec(handle).type == OperandType::kSingular)
<< "Actual type: " << static_cast<int>(manager_->spec(handle).type);
DCHECK_LE(handle.index_, handle_index_to_typed_index_.size());
return singular_operands_[handle_index_to_typed_index_[handle.index_]];
}
inline MutableAlignedArea Operands::GetStepwise(OperandHandle handle) const {
DCHECK(manager_->spec(handle).type == OperandType::kStepwise)
<< "Actual type: " << static_cast<int>(manager_->spec(handle).type);
DCHECK_LE(handle.index_, handle_index_to_typed_index_.size());
return stepwise_operands_[handle_index_to_typed_index_[handle.index_]];
}
inline MutableAlignedArea Operands::GetPairwise(OperandHandle handle) const {
DCHECK(manager_->spec(handle).type == OperandType::kPairwise)
<< "Actual type: " << static_cast<int>(manager_->spec(handle).type);
DCHECK_LE(handle.index_, handle_index_to_typed_index_.size());
return pairwise_operands_[handle_index_to_typed_index_[handle.index_]];
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_OPERANDS_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/operands.h"
#include <string.h>
#include <tuple>
#include <utility>
#include <vector>
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/math/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Expects that the two pointers are the same.
void ExpectSameAddress(const void *pointer1, const void *pointer2) {
EXPECT_EQ(pointer1, pointer2);
}
// Sets the |vector| to |size| copies of the |value|.
template <class T>
void Fill(MutableVector<T> vector, size_t size, T value) {
ASSERT_EQ(vector.size(), size);
for (T &element : vector) element = value;
}
// Expects that the |vector| contains |size| copies of the |expected_value|.
template <class T>
void ExpectFilled(Vector<T> vector, size_t size, T expected_value) {
ASSERT_EQ(vector.size(), size);
for (const T element : vector) EXPECT_EQ(element, expected_value);
}
// Tests that OperandManager can add operands and remember their configuration.
TEST(OperandManagerTest, Add) {
OperandManager manager;
const OperandHandle handle1 = manager.Add({OperandType::kSingular, 7});
const OperandHandle handle2 = manager.Add({OperandType::kStepwise, 11});
EXPECT_EQ(manager.spec(handle1).type, OperandType::kSingular);
EXPECT_EQ(manager.spec(handle1).size, 7);
EXPECT_EQ(manager.spec(handle2).type, OperandType::kStepwise);
EXPECT_EQ(manager.spec(handle2).size, 11);
}
// Tests that Operands contains operands whose dimensions match its manager.
TEST(OperandsTest, Dimensions) {
const size_t kDim1 = 3, kDim2 = 41, kDim3 = 19, kDim4 = 77;
OperandManager manager;
const OperandHandle handle1 =
manager.Add({OperandType::kSingular, kDim1 * sizeof(float)});
const OperandHandle handle2 =
manager.Add({OperandType::kStepwise, kDim2 * sizeof(double)});
const OperandHandle handle3 =
manager.Add({OperandType::kSingular, kDim3 * sizeof(float)});
const OperandHandle handle4 =
manager.Add({OperandType::kStepwise, kDim4 * sizeof(int)});
AlignedView view;
AlignedArea area;
Operands operands;
operands.Reset(&manager, 10);
view = operands.GetSingular(handle1);
EXPECT_EQ(view.size(), kDim1 * sizeof(float));
EXPECT_EQ(Vector<float>(view).size(), kDim1);
area = operands.GetStepwise(handle2);
EXPECT_EQ(area.num_views(), 0); // no steps yet
EXPECT_EQ(area.view_size(), kDim2 * sizeof(double));
EXPECT_EQ(Matrix<double>(area).num_rows(), 0); // starts with no steps
EXPECT_EQ(Matrix<double>(area).num_columns(), kDim2);
view = operands.GetSingular(handle3);
EXPECT_EQ(view.size(), kDim3 * sizeof(float));
EXPECT_EQ(Vector<float>(view).size(), kDim3);
area = operands.GetStepwise(handle4);
EXPECT_EQ(area.num_views(), 0); // no steps yet
EXPECT_EQ(area.view_size(), kDim4 * sizeof(int));
EXPECT_EQ(Matrix<int>(area).num_rows(), 0); // starts with no steps
EXPECT_EQ(Matrix<int>(area).num_columns(), kDim4);
}
// Tests that Operands can incrementally extend stepwise operands while
// preserving existing values.
TEST(OperandsTest, AddStepToStepwise) {
const size_t kDim1 = 23, kDim2 = 29;
OperandManager manager;
const OperandHandle handle1 =
manager.Add({OperandType::kStepwise, kDim1 * sizeof(double)});
const OperandHandle handle2 =
manager.Add({OperandType::kStepwise, kDim2 * sizeof(int)});
Operands operands;
operands.Reset(&manager, 10);
// Repeatedly add a step and fill it with values.
for (int i = 0; i < 100; ++i) {
operands.AddStep();
Fill(MutableVector<double>(operands.GetStepwise(handle1).view(i)), kDim1,
1000.0 + i);
Fill(MutableVector<int>(operands.GetStepwise(handle2).view(i)), kDim2,
2000 + i);
}
// Check that data from earlier steps is preserved across reallocations.
for (int i = 0; i < 100; ++i) {
ExpectFilled(Vector<double>(operands.GetStepwise(handle1).view(i)), kDim1,
1000.0 + i);
ExpectFilled(Vector<int>(operands.GetStepwise(handle2).view(i)), kDim2,
2000 + i);
}
}
// Tests that Operands can add multiple steps at once.
TEST(OperandsTest, AddStepsToStepwise) {
const size_t kDim1 = 23, kDim2 = 29;
OperandManager manager;
const OperandHandle handle1 =
manager.Add({OperandType::kStepwise, kDim1 * sizeof(double)});
const OperandHandle handle2 =
manager.Add({OperandType::kStepwise, kDim2 * sizeof(int)});
Operands operands;
operands.Reset(&manager, 10);
// Repeatedly add blocks of steps and fill them with values.
for (int i = 0; i < 100; ++i) {
if (i % 10 == 0) operands.AddSteps(10); // occasionally add a block
Fill(MutableVector<double>(operands.GetStepwise(handle1).view(i)), kDim1,
1000.0 + i);
Fill(MutableVector<int>(operands.GetStepwise(handle2).view(i)), kDim2,
2000 + i);
}
// Check that data from earlier steps is preserved across reallocations.
for (int i = 0; i < 100; ++i) {
ExpectFilled(Vector<double>(operands.GetStepwise(handle1).view(i)), kDim1,
1000.0 + i);
ExpectFilled(Vector<int>(operands.GetStepwise(handle2).view(i)), kDim2,
2000 + i);
}
}
// Tests that Operands can add multiple steps to a pairwise operand.
TEST(OperandsTest, AddStepsPairwise) {
const size_t kDim1 = 4, kDim2 = 31;
OperandManager manager;
const OperandHandle handle1 = manager.Add({OperandType::kPairwise, kDim1});
const OperandHandle handle2 = manager.Add({OperandType::kPairwise, kDim2});
Operands operands;
operands.Reset(&manager, 10);
{ // A 1x1 pairwise operand.
operands.AddSteps(1);
const MutableAlignedArea area1 = operands.GetPairwise(handle1);
const MutableAlignedArea area2 = operands.GetPairwise(handle2);
EXPECT_EQ(area1.num_views(), 1);
EXPECT_EQ(area2.num_views(), 1);
EXPECT_EQ(area1.view_size(), kDim1);
EXPECT_EQ(area2.view_size(), kDim2);
// Write to operands to test the validity of the underlying memory region.
memset(area1.view(0).data(), 0, kDim1);
memset(area2.view(0).data(), 0, kDim2);
}
{ // A 10x10 pairwise operand.
operands.AddSteps(9);
const MutableAlignedArea area1 = operands.GetPairwise(handle1);
const MutableAlignedArea area2 = operands.GetPairwise(handle2);
EXPECT_EQ(area1.num_views(), 10);
EXPECT_EQ(area2.num_views(), 10);
EXPECT_EQ(area1.view_size(), 10 * kDim1);
EXPECT_EQ(area2.view_size(), 10 * kDim2);
// Infer the stride by comparing pointers between consecutive views.
const size_t expected_stride =
PadToAlignment(10 * kDim1) + PadToAlignment(10 * kDim2);
EXPECT_EQ(area1.view(1).data() - area1.view(0).data(), expected_stride);
EXPECT_EQ(area2.view(1).data() - area2.view(0).data(), expected_stride);
// Write to operands to test the validity of the underlying memory region.
memset(area1.view(9).data(), 0, 10 * kDim1);
memset(area2.view(9).data(), 0, 10 * kDim2);
}
}
// Tests that Operands can be reused by resetting them repeatedly, possibly
// switching between different managers.
TEST(OperandsTest, ResetWithDifferentManagers) {
std::vector<OperandManager> managers;
std::vector<std::tuple<OperandHandle, OperandHandle, OperandHandle>> handles;
for (int dim = 0; dim < 10; ++dim) {
managers.emplace_back();
handles.emplace_back(
managers.back().Add({OperandType::kSingular, dim * sizeof(double)}),
managers.back().Add({OperandType::kStepwise, dim * sizeof(int)}),
managers.back().Add({OperandType::kPairwise, dim * sizeof(float)}));
}
Operands operands;
for (int trial = 0; trial < 10; ++trial) {
for (int dim = 0; dim < 10; ++dim) {
operands.Reset(&managers[dim], 10);
const OperandHandle singular_handle = std::get<0>(handles[dim]);
const OperandHandle stepwise_handle = std::get<1>(handles[dim]);
const OperandHandle pairwise_handle = std::get<2>(handles[dim]);
// Fill the singular operand.
Fill(MutableVector<double>(operands.GetSingular(singular_handle)), dim,
100.0 * trial + dim);
// Check the singular operands.
ExpectFilled(Vector<double>(operands.GetSingular(singular_handle)), dim,
100.0 * trial + dim);
// Repeatedly add a step and fill it with values.
for (int step = 0; step < 100; ++step) {
operands.AddStep();
Fill(MutableVector<int>(
operands.GetStepwise(stepwise_handle).view(step)),
dim, 1000 * trial + 100 * dim + step);
}
// Check that data from earlier steps is preserved across reallocations.
for (int step = 0; step < 100; ++step) {
ExpectFilled(
Vector<int>(operands.GetStepwise(stepwise_handle).view(step)), dim,
1000 * trial + 100 * dim + step);
}
// Check the dimensions of pairwise operands.
Matrix<float> pairwise(operands.GetPairwise(pairwise_handle));
EXPECT_EQ(pairwise.num_rows(), 100);
EXPECT_EQ(pairwise.num_columns(), 100 * dim);
}
}
}
// Tests that one OperandManager can be shared simultaneously between multiple
// Operands instances.
TEST(OperandsTest, SharedManager) {
const size_t kDim = 17;
OperandManager manager;
const OperandHandle singular_handle =
manager.Add({OperandType::kSingular, kDim * sizeof(double)});
const OperandHandle stepwise_handle =
manager.Add({OperandType::kStepwise, kDim * sizeof(int)});
std::vector<Operands> operands_vec(10);
for (Operands &operands : operands_vec) operands.Reset(&manager, 10);
// Fill all singular operands.
for (int trial = 0; trial < operands_vec.size(); ++trial) {
const Operands &operands = operands_vec[trial];
Fill(MutableVector<double>(operands.GetSingular(singular_handle)), kDim,
3.0 * trial);
}
// Check all singular operands.
for (int trial = 0; trial < operands_vec.size(); ++trial) {
const Operands &operands = operands_vec[trial];
ExpectFilled(Vector<double>(operands.GetSingular(singular_handle)), kDim,
3.0 * trial);
}
// Fill all stepwise operands. Interleave operations on the operands on each
// step, so all operands are "active" at the same time.
for (int step = 0; step < 100; ++step) {
for (int trial = 0; trial < 10; ++trial) {
Operands &operands = operands_vec[trial];
operands.AddStep();
Fill(MutableVector<int>(operands.GetStepwise(stepwise_handle).view(step)),
kDim, trial * 999 + step);
}
}
// Check all stepwise operands.
for (int step = 0; step < 100; ++step) {
for (int trial = 0; trial < 10; ++trial) {
const Operands &operands = operands_vec[trial];
ExpectFilled(
Vector<int>(operands.GetStepwise(stepwise_handle).view(step)), kDim,
trial * 999 + step);
}
}
}
// Tests that an Operands uses all of the pre-allocated steps and reallocates
// exactly when it exhausts the pre-allocated array.
TEST(OperandsTest, UsesPreAllocatedSteps) {
const size_t kBytes = 5;
const size_t kPreAllocateNumSteps = 10;
OperandManager manager;
const OperandHandle handle = manager.Add({OperandType::kStepwise, kBytes});
Operands operands;
operands.Reset(&manager, kPreAllocateNumSteps);
// The first N steps fit exactly in the pre-allocated array. Access the base
// of the stepwise array via the first view.
operands.AddStep();
char *const pre_allocated_data = operands.GetStepwise(handle).view(0).data();
for (size_t step = 1; step < kPreAllocateNumSteps; ++step) {
operands.AddStep();
ASSERT_EQ(operands.GetStepwise(handle).view(0).data(), pre_allocated_data);
}
// The N+1'st step triggers a reallocation, which is guaranteed to yield a new
// pointer because it creates a separate array and copies into it.
operands.AddStep();
ASSERT_NE(operands.GetStepwise(handle).view(0).data(), pre_allocated_data);
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/sequence_linker.h"
#include "dragnn/runtime/transition_system_traits.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Links to the previous step in the same component. Templated on a bool that
// indicates the direction that the transition system runs in.
template <bool left_to_right>
class RecurrentSequenceLinker : public SequenceLinker {
public:
// Implements SequenceLinker.
bool Supports(const LinkedFeatureChannel &channel,
const ComponentSpec &component_spec) const override;
tensorflow::Status Initialize(const LinkedFeatureChannel &channel,
const ComponentSpec &component_spec) override;
tensorflow::Status GetLinks(size_t source_num_steps, InputBatchCache *input,
std::vector<int32> *links) const override;
};
template <bool left_to_right>
bool RecurrentSequenceLinker<left_to_right>::Supports(
const LinkedFeatureChannel &channel,
const ComponentSpec &component_spec) const {
TransitionSystemTraits traits(component_spec);
// Here, fml="bias" and source_translator="history" are a DRAGNN recipe for
// linking to the previous transition step. More concretely,
// * "bias" always extracts index 0.
// * "history" subtracts the index it is given from (#steps - 1).
// Putting the two together, we link to (#steps - 1 - 0); i.e., the previous
// transition step.
return (channel.fml() == "bias" || channel.fml() == "bias(0)") &&
channel.source_component() == component_spec.name() &&
channel.source_translator() == "history" &&
traits.is_left_to_right == left_to_right && traits.is_sequential;
}
template <bool left_to_right>
tensorflow::Status RecurrentSequenceLinker<left_to_right>::Initialize(
const LinkedFeatureChannel &channel, const ComponentSpec &component_spec) {
return tensorflow::Status::OK();
}
template <bool left_to_right>
tensorflow::Status RecurrentSequenceLinker<left_to_right>::GetLinks(
size_t source_num_steps, InputBatchCache *input,
std::vector<int32> *links) const {
links->resize(source_num_steps);
if (left_to_right) {
int32 index = -1;
for (int32 &link : *links) link = index++;
} else {
int32 index = static_cast<int32>(source_num_steps) - 1;
for (int32 &link : *links) link = --index;
}
return tensorflow::Status::OK();
}
using LeftToRightRecurrentSequenceLinker = RecurrentSequenceLinker<true>;
using RightToLeftRecurrentSequenceLinker = RecurrentSequenceLinker<false>;
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(LeftToRightRecurrentSequenceLinker);
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(RightToLeftRecurrentSequenceLinker);
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/sequence_linker.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Returns a ComponentSpec that the linker will support.
ComponentSpec MakeSupportedSpec() {
ComponentSpec component_spec;
component_spec.set_name("test_component");
component_spec.mutable_transition_system()->set_registered_name("shift-only");
LinkedFeatureChannel *channel = component_spec.add_linked_feature();
channel->set_fml("bias");
channel->set_source_component("test_component");
channel->set_source_translator("history");
return component_spec;
}
// Tests that the linker supports appropriate specs.
TEST(RecurrentSequenceLinkerTest, Supported) {
string name;
ComponentSpec component_spec = MakeSupportedSpec();
LinkedFeatureChannel &channel = *component_spec.mutable_linked_feature(0);
TF_ASSERT_OK(SequenceLinker::Select(channel, component_spec, &name));
EXPECT_EQ(name, "LeftToRightRecurrentSequenceLinker");
(*component_spec.mutable_transition_system()
->mutable_parameters())["left_to_right"] = "false";
TF_ASSERT_OK(SequenceLinker::Select(channel, component_spec, &name));
EXPECT_EQ(name, "RightToLeftRecurrentSequenceLinker");
channel.set_fml("bias(0)");
TF_ASSERT_OK(SequenceLinker::Select(channel, component_spec, &name));
EXPECT_EQ(name, "RightToLeftRecurrentSequenceLinker");
(*component_spec.mutable_transition_system()
->mutable_parameters())["left_to_right"] = "true";
TF_ASSERT_OK(SequenceLinker::Select(channel, component_spec, &name));
EXPECT_EQ(name, "LeftToRightRecurrentSequenceLinker");
}
// Tests that the linker requires the right transition system.
TEST(RecurrentSequenceLinkerTest, WrongTransitionSystem) {
string name;
ComponentSpec component_spec = MakeSupportedSpec();
const LinkedFeatureChannel &channel = component_spec.linked_feature(0);
component_spec.mutable_transition_system()->set_registered_name("bad");
EXPECT_THAT(SequenceLinker::Select(channel, component_spec, &name),
test::IsErrorWithSubstr("No SequenceLinker supports channel"));
}
// Tests that the linker requires the right FML.
TEST(RecurrentSequenceLinkerTest, WrongFml) {
string name;
ComponentSpec component_spec = MakeSupportedSpec();
LinkedFeatureChannel &channel = *component_spec.mutable_linked_feature(0);
channel.set_fml("bad");
EXPECT_THAT(SequenceLinker::Select(channel, component_spec, &name),
test::IsErrorWithSubstr("No SequenceLinker supports channel"));
}
// Tests that the linker requires a recurrent link.
TEST(RecurrentSequenceLinkerTest, WrongSource) {
string name;
ComponentSpec component_spec = MakeSupportedSpec();
LinkedFeatureChannel &channel = *component_spec.mutable_linked_feature(0);
channel.set_source_component("bad");
EXPECT_THAT(SequenceLinker::Select(channel, component_spec, &name),
test::IsErrorWithSubstr("No SequenceLinker supports channel"));
}
// Tests that the linker requires the right translator.
TEST(RecurrentSequenceLinkerTest, WrongTranslator) {
string name;
ComponentSpec component_spec = MakeSupportedSpec();
LinkedFeatureChannel &channel = *component_spec.mutable_linked_feature(0);
channel.set_source_translator("bad");
EXPECT_THAT(SequenceLinker::Select(channel, component_spec, &name),
test::IsErrorWithSubstr("No SequenceLinker supports channel"));
}
// Tests that the linker can be initialized and used to extract links.
TEST(RecurrentSequenceLinkerTest, InitializeAndGetLinks) {
const ComponentSpec component_spec = MakeSupportedSpec();
const LinkedFeatureChannel &channel = component_spec.linked_feature(0);
std::unique_ptr<SequenceLinker> linker;
TF_ASSERT_OK(SequenceLinker::New("LeftToRightRecurrentSequenceLinker",
channel, component_spec, &linker));
InputBatchCache input;
std::vector<int32> links = {123, 456, 789}; // gets overwritten
TF_ASSERT_OK(linker->GetLinks(10, &input, &links));
const std::vector<int32> expected_links = {-1, 0, 1, 2, 3, 4, 5, 6, 7, 8};
EXPECT_EQ(links, expected_links);
}
// Tests that the links are reversed for right-to-left components.
TEST(RecurrentSequenceLinkerTest, InitializeAndGetLinksRightToLeft) {
ComponentSpec component_spec = MakeSupportedSpec();
const LinkedFeatureChannel &channel = component_spec.linked_feature(0);
std::unique_ptr<SequenceLinker> linker;
TF_ASSERT_OK(SequenceLinker::New("RightToLeftRecurrentSequenceLinker",
channel, component_spec, &linker));
InputBatchCache input;
std::vector<int32> links = {123, 456, 789}; // gets overwritten
TF_ASSERT_OK(linker->GetLinks(10, &input, &links));
const std::vector<int32> expected_links = {8, 7, 6, 5, 4, 3, 2, 1, 0, -1};
EXPECT_EQ(links, expected_links);
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/sequence_linker.h"
#include "dragnn/runtime/transition_system_traits.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Applies a reversed identity function.
class ReversedSequenceLinker : public SequenceLinker {
public:
// Implements SequenceLinker.
bool Supports(const LinkedFeatureChannel &channel,
const ComponentSpec &component_spec) const override;
tensorflow::Status Initialize(const LinkedFeatureChannel &channel,
const ComponentSpec &component_spec) override;
tensorflow::Status GetLinks(size_t source_num_steps, InputBatchCache *input,
std::vector<int32> *links) const override;
};
bool ReversedSequenceLinker::Supports(
const LinkedFeatureChannel &channel,
const ComponentSpec &component_spec) const {
TransitionSystemTraits traits(component_spec);
// Note: Add more "||" clauses as needed.
return ((channel.fml() == "input.focus" &&
channel.source_translator() == "reverse-token") ||
(channel.fml() == "char-input.focus" &&
channel.source_translator() == "reverse-char")) &&
traits.is_sequential;
}
tensorflow::Status ReversedSequenceLinker::Initialize(
const LinkedFeatureChannel &channel, const ComponentSpec &component_spec) {
return tensorflow::Status::OK();
}
tensorflow::Status ReversedSequenceLinker::GetLinks(
size_t source_num_steps, InputBatchCache *input,
std::vector<int32> *links) const {
links->resize(source_num_steps);
int32 index = links->size();
for (int32 &link : *links) link = --index;
return tensorflow::Status::OK();
}
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(ReversedSequenceLinker);
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/sequence_linker.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Returns a ComponentSpec that the linker will support.
ComponentSpec MakeSupportedSpec() {
ComponentSpec component_spec;
component_spec.mutable_transition_system()->set_registered_name("shift-only");
LinkedFeatureChannel *channel = component_spec.add_linked_feature();
channel->set_fml("input.focus");
channel->set_source_translator("reverse-token");
return component_spec;
}
// Tests that the linker supports appropriate specs.
TEST(ReversedSequenceLinkerTest, Supported) {
string name;
ComponentSpec component_spec = MakeSupportedSpec();
LinkedFeatureChannel &channel = *component_spec.mutable_linked_feature(0);
TF_ASSERT_OK(SequenceLinker::Select(channel, component_spec, &name));
EXPECT_EQ(name, "ReversedSequenceLinker");
channel.set_fml("char-input.focus");
channel.set_source_translator("reverse-char");
TF_ASSERT_OK(SequenceLinker::Select(channel, component_spec, &name));
EXPECT_EQ(name, "ReversedSequenceLinker");
}
// Tests that the linker requires the right transition system.
TEST(IdentitySequenceLinkerTest, WrongTransitionSystem) {
string name;
ComponentSpec component_spec = MakeSupportedSpec();
const LinkedFeatureChannel &channel = component_spec.linked_feature(0);
component_spec.mutable_transition_system()->set_registered_name("bad");
EXPECT_THAT(SequenceLinker::Select(channel, component_spec, &name),
test::IsErrorWithSubstr("No SequenceLinker supports channel"));
}
// Tests that the linker requires the right FML.
TEST(ReversedSequenceLinkerTest, WrongFml) {
string name;
ComponentSpec component_spec = MakeSupportedSpec();
LinkedFeatureChannel &channel = *component_spec.mutable_linked_feature(0);
channel.set_fml("bad");
EXPECT_THAT(SequenceLinker::Select(channel, component_spec, &name),
test::IsErrorWithSubstr("No SequenceLinker supports channel"));
}
// Tests that the linker requires the right translator.
TEST(ReversedSequenceLinkerTest, WrongTranslator) {
string name;
ComponentSpec component_spec = MakeSupportedSpec();
LinkedFeatureChannel &channel = *component_spec.mutable_linked_feature(0);
channel.set_source_translator("bad");
EXPECT_THAT(SequenceLinker::Select(channel, component_spec, &name),
test::IsErrorWithSubstr("No SequenceLinker supports channel"));
}
// Tests that the linker requires the right combination of FML and translator.
TEST(ReversedSequenceLinkerTest, MismatchedFmlAndTranslator) {
string name;
ComponentSpec component_spec = MakeSupportedSpec();
LinkedFeatureChannel &channel = *component_spec.mutable_linked_feature(0);
channel.set_fml("input.focus");
channel.set_source_translator("reverse-char");
EXPECT_THAT(SequenceLinker::Select(channel, component_spec, &name),
test::IsErrorWithSubstr("No SequenceLinker supports channel"));
channel.set_fml("char-input.focus");
channel.set_source_translator("reverse-token");
EXPECT_THAT(SequenceLinker::Select(channel, component_spec, &name),
test::IsErrorWithSubstr("No SequenceLinker supports channel"));
}
// Tests that the linker can be initialized and used to extract links.
TEST(ReversedSequenceLinkerTest, InitializeAndGetLinks) {
const ComponentSpec component_spec = MakeSupportedSpec();
const LinkedFeatureChannel &channel = component_spec.linked_feature(0);
std::unique_ptr<SequenceLinker> linker;
TF_ASSERT_OK(SequenceLinker::New("ReversedSequenceLinker", channel,
component_spec, &linker));
InputBatchCache input;
std::vector<int32> links = {123, 456, 789}; // gets overwritten
TF_ASSERT_OK(linker->GetLinks(10, &input, &links));
const std::vector<int32> expected_links = {9, 8, 7, 6, 5, 4, 3, 2, 1, 0};
EXPECT_EQ(links, expected_links);
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <string>
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/component_transformation.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Transformer that selects the best component subclass for the ComponentSpec.
class SelectBestComponentTransformer : public ComponentTransformer {
public:
// Implements ComponentTransformer.
tensorflow::Status Transform(const string &component_type,
ComponentSpec *component_spec) override {
string best_component_type;
TF_RETURN_IF_ERROR(
Component::Select(*component_spec, &best_component_type));
component_spec->mutable_component_builder()->set_registered_name(
best_component_type);
if (component_type != best_component_type) {
LOG(INFO) << "Component '" << component_spec->name()
<< "' builder updated from " << component_type << " to "
<< best_component_type << ".";
} else {
VLOG(2) << "Component '" << component_spec->name() << "' builder type "
<< component_type << " unchanged.";
}
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_COMPONENT_TRANSFORMER(SelectBestComponentTransformer);
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <memory>
#include <string>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/component_transformation.h"
#include "dragnn/runtime/extensions.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Base class for test components.
class TestComponentBase : public Component {
public:
// Partially implements Component.
tensorflow::Status Initialize(const ComponentSpec &, VariableStore *,
NetworkStateManager *,
ExtensionManager *) override {
return tensorflow::Status::OK();
}
tensorflow::Status Evaluate(SessionState *, ComputeSession *,
ComponentTrace *) const override {
return tensorflow::Status::OK();
}
bool PreferredTo(const Component &) const override { return false; }
};
// Supports components whose builder name includes "Foo".
class ContainsFoo : public TestComponentBase {
public:
// Implements Component.
bool Supports(const ComponentSpec &,
const string &normalized_builder_name) const override {
return normalized_builder_name.find("Foo") != string::npos;
}
};
DRAGNN_RUNTIME_REGISTER_COMPONENT(ContainsFoo);
// Supports components whose builder name includes "Bar".
class ContainsBar : public TestComponentBase {
public:
// Implements Component.
bool Supports(const ComponentSpec &,
const string &normalized_builder_name) const override {
return normalized_builder_name.find("Bar") != string::npos;
}
};
DRAGNN_RUNTIME_REGISTER_COMPONENT(ContainsBar);
// Tests that a spec with an unknown builder name causes an error.
TEST(SelectBestComponentTransformerTest, Unknown) {
ComponentSpec component_spec;
component_spec.mutable_component_builder()->set_registered_name("unknown");
EXPECT_THAT(ComponentTransformer::ApplyAll(&component_spec),
test::IsErrorWithSubstr("Could not find a best"));
}
// Tests that a spec with builder "Foo" is changed to "ContainsFoo".
TEST(SelectBestComponentTransformerTest, ChangeToContainsFoo) {
ComponentSpec component_spec;
component_spec.mutable_component_builder()->set_registered_name("Foo");
ComponentSpec expected_spec = component_spec;
expected_spec.mutable_component_builder()->set_registered_name("ContainsFoo");
TF_ASSERT_OK(ComponentTransformer::ApplyAll(&component_spec));
EXPECT_THAT(component_spec, test::EqualsProto(expected_spec));
}
// Tests that a spec with builder "Bar" is changed to "ContainsBar".
TEST(SelectBestComponentTransformerTest, ChangeToContainsBar) {
ComponentSpec component_spec;
component_spec.mutable_component_builder()->set_registered_name("Bar");
ComponentSpec expected_spec = component_spec;
expected_spec.mutable_component_builder()->set_registered_name("ContainsBar");
TF_ASSERT_OK(ComponentTransformer::ApplyAll(&component_spec));
EXPECT_THAT(component_spec, test::EqualsProto(expected_spec));
}
// Tests that a spec with builder "FooBar" causes a conflict.
TEST(SelectBestComponentTransformerTest, Conflict) {
ComponentSpec component_spec;
component_spec.mutable_component_builder()->set_registered_name("FooBar");
EXPECT_THAT(
ComponentTransformer::ApplyAll(&component_spec),
test::IsErrorWithSubstr("both think they should be dis-preferred"));
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_backend.h"
#include "dragnn/core/component_registry.h"
#include "tensorflow/core/platform/logging.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
std::function<int(int, int, int)> SequenceBackend::GetStepLookupFunction(
const string &method) {
if (method == "reverse-char" || method == "reverse-token") {
// Reverses the |index| in the sequence. We are agnostic to whether the
// input is a sequence of tokens or chars.
return [this](int unused_batch_index, int unused_beam_index, int index) {
index = sequence_size_ - index - 1;
return index >= 0 && index < sequence_size_ ? index : -1;
};
}
LOG(FATAL) << "[" << name_ << "] Unknown step lookup function: " << method;
}
void SequenceBackend::InitializeComponent(const ComponentSpec &spec) {
name_ = spec.name();
}
void SequenceBackend::InitializeData(
const std::vector<std::vector<const TransitionState *>> &parent_states,
int max_beam_size, InputBatchCache *input_data) {
// Store the |parent_states| for forwarding to downstream components.
parent_states_ = parent_states;
}
std::vector<std::vector<const TransitionState *>> SequenceBackend::GetBeam() {
// Forward the states of the previous component.
return parent_states_;
}
int SequenceBackend::GetSourceBeamIndex(int current_index, int batch) const {
// Forward the |current_index| to the previous component.
return current_index;
}
int SequenceBackend::GetBeamIndexAtStep(int step, int current_index,
int batch) const {
// Always return 0 since there is only one beam.
return 0;
}
std::vector<std::vector<ComponentTrace>> SequenceBackend::GetTraceProtos()
const {
// Return a single trace, since the beam and batch sizes are fixed at 1.
return {{ComponentTrace()}};
}
string SequenceBackend::Name() const { return name_; }
int SequenceBackend::BeamSize() const { return 1; }
int SequenceBackend::BatchSize() const { return 1; }
bool SequenceBackend::IsReady() const { return true; }
bool SequenceBackend::IsTerminal() const { return true; }
void SequenceBackend::FinalizeData() {}
void SequenceBackend::ResetComponent() {}
void SequenceBackend::InitializeTracing() {}
void SequenceBackend::DisableTracing() {}
int SequenceBackend::StepsTaken(int batch_index) const {
LOG(FATAL) << "[" << name_ << "] Not supported";
}
bool SequenceBackend::AdvanceFromPrediction(const float *transition_matrix,
int num_items, int num_actions) {
LOG(FATAL) << "[" << name_ << "] Not supported";
}
void SequenceBackend::AdvanceFromOracle() {
LOG(FATAL) << "[" << name_ << "] Not supported";
}
std::vector<std::vector<std::vector<Label>>> SequenceBackend::GetOracleLabels()
const {
LOG(FATAL) << "[" << name_ << "] Not supported";
}
int SequenceBackend::GetFixedFeatures(
std::function<int32 *(int)> allocate_indices,
std::function<int64 *(int)> allocate_ids,
std::function<float *(int)> allocate_weights, int channel_id) const {
LOG(FATAL) << "[" << name_ << "] Not supported";
}
int SequenceBackend::BulkGetFixedFeatures(
const BulkFeatureExtractor &extractor) {
LOG(FATAL) << "[" << name_ << "] Not supported";
}
void SequenceBackend::BulkEmbedFixedFeatures(
int batch_size_padding, int num_steps_padding, int output_array_size,
const vector<const float *> &per_channel_embeddings,
float *embedding_output) {
LOG(FATAL) << "[" << name_ << "] Not supported";
}
void SequenceBackend::BulkEmbedDenseFixedFeatures(
const vector<const float *> &per_channel_embeddings,
float *embedding_output, int embedding_output_size,
int *offset_array_output, int offset_array_size) {
LOG(FATAL) << "[" << name_ << "] Not supported";
}
int SequenceBackend::BulkDenseFeatureSize() const {
LOG(FATAL) << "[" << name_ << "] Not supported";
}
std::vector<LinkFeatures> SequenceBackend::GetRawLinkFeatures(
int channel_id) const {
LOG(FATAL) << "[" << name_ << "] Not supported";
}
void SequenceBackend::AddTranslatedLinkFeaturesToTrace(
const std::vector<LinkFeatures> &features, int channel_id) {
LOG(FATAL) << "[" << name_ << "] Not supported";
}
REGISTER_DRAGNN_COMPONENT(SequenceBackend);
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_SEQUENCE_BACKEND_H_
#define DRAGNN_RUNTIME_SEQUENCE_BACKEND_H_
#include <functional>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/interfaces/component.h"
#include "dragnn/core/interfaces/transition_state.h"
#include "dragnn/core/util/label.h"
#include "dragnn/protos/data.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "syntaxnet/base.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Runtime-only component backend for sequence-based models. This is not used
// at training time, and provides trivial implementations of most methods. This
// is intended to be used with alternative feature extraction approaches, such
// as SequenceExtractor.
class SequenceBackend : public dragnn::Component {
public:
// Sets the size of the sequence in the current input.
void SetSequenceSize(int size) { sequence_size_ = size; }
// Implements dragnn::Component.
std::function<int(int, int, int)> GetStepLookupFunction(
const string &method) override;
void InitializeComponent(const ComponentSpec &spec) override;
void InitializeData(
const std::vector<std::vector<const TransitionState *>> &parent_states,
int max_beam_size, InputBatchCache *input_data) override;
std::vector<std::vector<const TransitionState *>> GetBeam() override;
int GetSourceBeamIndex(int current_index, int batch) const override;
int GetBeamIndexAtStep(int step, int current_index, int batch) const override;
std::vector<std::vector<ComponentTrace>> GetTraceProtos() const override;
string Name() const override;
int BeamSize() const override;
int BatchSize() const override;
bool IsReady() const override;
bool IsTerminal() const override;
void FinalizeData() override;
void ResetComponent() override;
void InitializeTracing() override;
void DisableTracing() override;
// Not implemented, crashes when called.
int StepsTaken(int batch_index) const override;
// Not implemented, crashes when called.
bool AdvanceFromPrediction(const float *transition_matrix, int num_items,
int num_actions) override;
// Not implemented, crashes when called.
void AdvanceFromOracle() override;
// Not implemented, crashes when called.
std::vector<std::vector<std::vector<Label>>> GetOracleLabels() const override;
// Not implemented, crashes when called.
int GetFixedFeatures(std::function<int32 *(int)> allocate_indices,
std::function<int64 *(int)> allocate_ids,
std::function<float *(int)> allocate_weights,
int channel_id) const override;
// Not implemented, crashes when called.
int BulkGetFixedFeatures(const BulkFeatureExtractor &extractor) override;
// Not implemented, crashes when called.
void BulkEmbedFixedFeatures(
int batch_size_padding, int num_steps_padding, int output_array_size,
const vector<const float *> &per_channel_embeddings,
float *embedding_output) override;
// Not implemented, crashes when called.
void BulkEmbedDenseFixedFeatures(
const vector<const float *> &per_channel_embeddings,
float *embedding_output, int embedding_output_size,
int *offset_array_output, int offset_array_size) override;
// Not implemented, crashes when called.
int BulkDenseFeatureSize() const override;
// Not implemented, crashes when called.
std::vector<LinkFeatures> GetRawLinkFeatures(int channel_id) const override;
// Not implemented, crashes when called.
void AddTranslatedLinkFeaturesToTrace(
const std::vector<LinkFeatures> &features, int channel_id) override;
private:
// Name of the component that this backend supports.
string name_;
// Size of the current input sequence.
int sequence_size_ = 0;
// Parent states passed to InitializeData(), and passed along in GetBeam().
std::vector<std::vector<const TransitionState *>> parent_states_;
};
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_SEQUENCE_BACKEND_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/sequence_backend.h"
#include "dragnn/components/util/bulk_feature_extractor.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/interfaces/transition_state.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Tests that the "reverse-*" step lookup functions ignore the batch and beam
// indices and return -1 if the sequence size was never set.
TEST(SequenceBackendTest, ReverseCharUninitialized) {
for (const string &reverse_method : {"reverse-char", "reverse-token"}) {
SequenceBackend backend;
const std::function<int(int, int, int)> reverse =
backend.GetStepLookupFunction(reverse_method);
EXPECT_EQ(reverse(0, 0, 0), -1);
EXPECT_EQ(reverse(1, 1, 1), -1);
EXPECT_EQ(reverse(-1, -1, -1), -1);
EXPECT_EQ(reverse(0, 0, 9999), -1);
EXPECT_EQ(reverse(0, 0, -9999), -1);
}
}
// Tests that the "reverse-*" step lookup functions ignore the batch and beam
// indices and return the reverse of the step index w.r.t. the most recent call
// to SetSequenceSize().
TEST(SequenceBackendTest, ReverseCharAfterSetSequenceSize) {
for (const string &reverse_method : {"reverse-char", "reverse-token"}) {
SequenceBackend backend;
const std::function<int(int, int, int)> reverse =
backend.GetStepLookupFunction(reverse_method);
backend.SetSequenceSize(10);
EXPECT_EQ(reverse(0, 0, -1), -1);
EXPECT_EQ(reverse(0, 0, 0), 9);
EXPECT_EQ(reverse(1, 1, 1), 8);
EXPECT_EQ(reverse(8, 8, 8), 1);
EXPECT_EQ(reverse(9, 9, 9), 0);
EXPECT_EQ(reverse(10, 10, 10), -1);
EXPECT_EQ(reverse(-1, -1, 5), 4);
EXPECT_EQ(reverse(0, 0, 9999), -1);
EXPECT_EQ(reverse(0, 0, -9999), -1);
backend.SetSequenceSize(11);
EXPECT_EQ(reverse(0, 0, -1), -1);
EXPECT_EQ(reverse(0, 0, 0), 10);
EXPECT_EQ(reverse(1, 1, 1), 9);
EXPECT_EQ(reverse(8, 8, 8), 2);
EXPECT_EQ(reverse(9, 9, 9), 1);
EXPECT_EQ(reverse(10, 10, 10), 0);
EXPECT_EQ(reverse(-1, -1, 5), 5);
EXPECT_EQ(reverse(0, 0, 9999), -1);
EXPECT_EQ(reverse(0, 0, -9999), -1);
}
}
// Tests that the input beam is forwarded.
TEST(SequenceBackendTest, BeamForwarding) {
SequenceBackend backend;
const TransitionState *parent_state = nullptr;
parent_state += 1234; // arbitrary non-null pointer
const std::vector<std::vector<const TransitionState *>> parent_states = {
{parent_state}};
const int ignored_max_beam_size = 999;
InputBatchCache *ignored_input = nullptr;
backend.InitializeData(parent_states, ignored_max_beam_size, ignored_input);
EXPECT_EQ(backend.GetBeam(), parent_states);
}
// Tests the accessors of the backend.
TEST(SequenceBackendTest, Accessors) {
SequenceBackend backend;
ComponentSpec spec;
spec.set_name("foo");
backend.InitializeComponent(spec);
EXPECT_EQ(backend.Name(), "foo");
EXPECT_EQ(backend.BeamSize(), 1);
EXPECT_EQ(backend.BatchSize(), 1);
EXPECT_TRUE(backend.IsReady());
EXPECT_TRUE(backend.IsTerminal());
}
// Tests the trivial mutators of the backend.
TEST(SequenceBackendTest, Mutators) {
SequenceBackend backend;
// These are NOPs and should not crash.
backend.FinalizeData();
backend.ResetComponent();
backend.InitializeTracing();
backend.DisableTracing();
}
// Tests the beam index accessors of the backend.
TEST(SequenceBackendTest, BeamIndex) {
SequenceBackend backend;
// This always returns the current_index (first arg).
EXPECT_EQ(backend.GetSourceBeamIndex(0, 0), 0);
EXPECT_EQ(backend.GetSourceBeamIndex(1, 2), 1);
EXPECT_EQ(backend.GetSourceBeamIndex(-1, -1), -1);
EXPECT_EQ(backend.GetSourceBeamIndex(10, 99), 10);
// This always returns 0.
EXPECT_EQ(backend.GetBeamIndexAtStep(0, 0, 0), 0);
EXPECT_EQ(backend.GetBeamIndexAtStep(1, 2, 3), 0);
EXPECT_EQ(backend.GetBeamIndexAtStep(-1, -1, -1), 0);
EXPECT_EQ(backend.GetBeamIndexAtStep(123, 456, 789), 0);
}
// Tests the that the backend produces a single empty trace.
TEST(SequenceBackendTest, Tracing) {
SequenceBackend backend;
const ComponentTrace empty_trace;
const auto actual_traces = backend.GetTraceProtos();
ASSERT_EQ(actual_traces.size(), 1);
ASSERT_EQ(actual_traces[0].size(), 1);
EXPECT_THAT(actual_traces[0][0], test::EqualsProto(empty_trace));
}
// Tests the unsupported methods of the backend.
TEST(SequenceBackendTest, UnsupportedMethods) {
SequenceBackend backend;
EXPECT_DEATH(backend.StepsTaken(0), "Not supported");
EXPECT_DEATH(backend.AdvanceFromPrediction(nullptr, 0, 0), "Not supported");
EXPECT_DEATH(backend.AdvanceFromOracle(), "Not supported");
EXPECT_DEATH(backend.GetOracleLabels(), "Not supported");
EXPECT_DEATH(backend.GetFixedFeatures(nullptr, nullptr, nullptr, 0),
"Not supported");
EXPECT_DEATH(backend.BulkGetFixedFeatures(
BulkFeatureExtractor(nullptr, nullptr, nullptr)),
"Not supported");
EXPECT_DEATH(backend.BulkEmbedFixedFeatures(0, 0, 0, {}, nullptr),
"Not supported");
EXPECT_DEATH(backend.GetRawLinkFeatures(0), "Not supported");
EXPECT_DEATH(backend.AddTranslatedLinkFeaturesToTrace({}, 0),
"Not supported");
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <string.h>
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/bulk_network_unit.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/fixed_embeddings.h"
#include "dragnn/runtime/linked_embeddings.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/sequence_model.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Sequence-based bulk version of DynamicComponent.
class SequenceBulkDynamicComponent : public Component {
public:
// Implements Component.
tensorflow::Status Initialize(const ComponentSpec &component_spec,
VariableStore *variable_store,
NetworkStateManager *network_state_manager,
ExtensionManager *extension_manager) override;
tensorflow::Status Evaluate(SessionState *session_state,
ComputeSession *compute_session,
ComponentTrace *component_trace) const override;
bool Supports(const ComponentSpec &component_spec,
const string &normalized_builder_name) const override;
bool PreferredTo(const Component &other) const override { return false; }
private:
// Evaluates all input features in the |state|, concatenates them into a
// matrix of inputs in the |network_states|, and returns the matrix.
Matrix<float> EvaluateInputs(const SequenceModel::EvaluateState &state,
const NetworkStates &network_states) const;
// Managers for input embeddings.
FixedEmbeddingManager fixed_embedding_manager_;
LinkedEmbeddingManager linked_embedding_manager_;
// Sequence-based model evaluator.
SequenceModel sequence_model_;
// Network unit for bulk inference.
std::unique_ptr<BulkNetworkUnit> bulk_network_unit_;
// Concatenated input matrix.
LocalMatrixHandle<float> inputs_handle_;
// Intermediate values used by sequence models.
SharedExtensionHandle<SequenceModel::EvaluateState> evaluate_state_handle_;
};
bool SequenceBulkDynamicComponent::Supports(
const ComponentSpec &component_spec,
const string &normalized_builder_name) const {
// Require embedded fixed features.
for (const FixedFeatureChannel &channel : component_spec.fixed_feature()) {
if (channel.embedding_dim() < 0) return false;
}
// Require non-transformed and non-recurrent linked features.
// TODO(googleuser): Make SequenceLinks support transformed linked features?
for (const LinkedFeatureChannel &channel : component_spec.linked_feature()) {
if (channel.embedding_dim() >= 0) return false;
if (channel.source_component() == component_spec.name()) return false;
}
return normalized_builder_name == "SequenceBulkDynamicComponent" &&
SequenceModel::Supports(component_spec);
}
// Returns the sum of the dimensions of all channels in the |manager|.
template <class EmbeddingManager>
size_t SumEmbeddingDimensions(const EmbeddingManager &manager) {
size_t sum = 0;
for (size_t i = 0; i < manager.num_channels(); ++i) {
sum += manager.embedding_dim(i);
}
return sum;
}
tensorflow::Status SequenceBulkDynamicComponent::Initialize(
const ComponentSpec &component_spec, VariableStore *variable_store,
NetworkStateManager *network_state_manager,
ExtensionManager *extension_manager) {
TF_RETURN_IF_ERROR(BulkNetworkUnit::CreateOrError(
BulkNetworkUnit::GetClassName(component_spec), &bulk_network_unit_));
TF_RETURN_IF_ERROR(
bulk_network_unit_->Initialize(component_spec, variable_store,
network_state_manager, extension_manager));
TF_RETURN_IF_ERROR(fixed_embedding_manager_.Reset(
component_spec, variable_store, network_state_manager));
TF_RETURN_IF_ERROR(linked_embedding_manager_.Reset(
component_spec, variable_store, network_state_manager));
const size_t concatenated_input_dim =
SumEmbeddingDimensions(fixed_embedding_manager_) +
SumEmbeddingDimensions(linked_embedding_manager_);
TF_RETURN_IF_ERROR(
bulk_network_unit_->ValidateInputDimension(concatenated_input_dim));
TF_RETURN_IF_ERROR(
network_state_manager->AddLocal(concatenated_input_dim, &inputs_handle_));
TF_RETURN_IF_ERROR(sequence_model_.Initialize(
component_spec, bulk_network_unit_->GetLogitsName(),
&fixed_embedding_manager_, &linked_embedding_manager_,
network_state_manager));
extension_manager->GetShared(&evaluate_state_handle_);
return tensorflow::Status::OK();
}
tensorflow::Status SequenceBulkDynamicComponent::Evaluate(
SessionState *session_state, ComputeSession *compute_session,
ComponentTrace *component_trace) const {
const NetworkStates &network_states = session_state->network_states;
SequenceModel::EvaluateState &state =
session_state->extensions.Get(evaluate_state_handle_);
TF_RETURN_IF_ERROR(
sequence_model_.Preprocess(session_state, compute_session, &state));
const Matrix<float> inputs = EvaluateInputs(state, network_states);
TF_RETURN_IF_ERROR(bulk_network_unit_->Evaluate(inputs, session_state));
return sequence_model_.Predict(network_states, &state);
}
Matrix<float> SequenceBulkDynamicComponent::EvaluateInputs(
const SequenceModel::EvaluateState &state,
const NetworkStates &network_states) const {
const MutableMatrix<float> inputs = network_states.GetLocal(inputs_handle_);
// Declared here for reuse in the loop below.
bool is_out_of_bounds = false;
Vector<float> embedding;
// Handle forward and reverse iteration via a start index and increment.
int target_index = sequence_model_.left_to_right() ? 0 : state.num_steps - 1;
const int target_increment = sequence_model_.left_to_right() ? 1 : -1;
for (size_t step_index = 0; step_index < state.num_steps;
++step_index, target_index += target_increment) {
const MutableVector<float> row = inputs.row(step_index);
float *output = row.data();
for (size_t channel_id = 0; channel_id < state.features.num_channels();
++channel_id) {
embedding = state.features.GetEmbedding(channel_id, target_index);
memcpy(output, embedding.data(), embedding.size() * sizeof(float));
output += embedding.size();
}
for (size_t channel_id = 0; channel_id < state.links.num_channels();
++channel_id) {
state.links.Get(channel_id, target_index, &embedding, &is_out_of_bounds);
memcpy(output, embedding.data(), embedding.size() * sizeof(float));
output += embedding.size();
}
DCHECK_EQ(output, row.end());
}
return Matrix<float>(inputs);
}
DRAGNN_RUNTIME_REGISTER_COMPONENT(SequenceBulkDynamicComponent);
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/bulk_network_unit.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/sequence_backend.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "dragnn/runtime/sequence_linker.h"
#include "dragnn/runtime/sequence_predictor.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include <gmock/gmock.h>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
using ::testing::Return;
constexpr size_t kNumSteps = 50;
constexpr size_t kFixedDim = 11;
constexpr size_t kFixedVocabularySize = 123;
constexpr float kFixedValue = 0.5;
constexpr size_t kLinkedDim = 13;
constexpr float kLinkedValue = 1.25;
constexpr char kPreviousComponentName[] = "previous_component";
constexpr char kPreviousLayerName[] = "previous_layer";
constexpr char kLogitsName[] = "logits";
constexpr size_t kLogitsDim = kFixedDim + kLinkedDim;
// Adds one to all inputs.
class BulkAddOne : public BulkNetworkUnit {
public:
// Implements BulkNetworkUnit.
tensorflow::Status Initialize(const ComponentSpec &component_spec,
VariableStore *variable_store,
NetworkStateManager *network_state_manager,
ExtensionManager *extension_manager) override {
return network_state_manager->AddLayer(kLogitsName, kLogitsDim,
&logits_handle_);
}
tensorflow::Status ValidateInputDimension(size_t dimension) const override {
return tensorflow::Status::OK();
}
string GetLogitsName() const override { return kLogitsName; }
tensorflow::Status Evaluate(Matrix<float> inputs,
SessionState *session_state) const override {
const MutableMatrix<float> logits =
session_state->network_states.GetLayer(logits_handle_);
for (size_t row = 0; row < inputs.num_rows(); ++row) {
for (size_t column = 0; column < inputs.num_columns(); ++column) {
logits.row(row)[column] = inputs.row(row)[column] + 1.0;
}
}
return tensorflow::Status::OK();
}
private:
// Output logits.
LayerHandle<float> logits_handle_;
};
DRAGNN_RUNTIME_REGISTER_BULK_NETWORK_UNIT(BulkAddOne);
// A component that also prefers other but is triggered on the presence of a
// resource. This can be used to cause a component selection conflict.
class ImTheWorst : public Component {
public:
// Implements Component.
tensorflow::Status Initialize(const ComponentSpec &component_spec,
VariableStore *variable_store,
NetworkStateManager *network_state_manager,
ExtensionManager *extension_manager) override {
return tensorflow::Status::OK();
}
tensorflow::Status Evaluate(SessionState *session_state,
ComputeSession *compute_session,
ComponentTrace *component_trace) const override {
return tensorflow::Status::OK();
}
bool Supports(const ComponentSpec &component_spec,
const string &normalized_builder_name) const override {
return component_spec.resource_size() > 0;
}
bool PreferredTo(const Component &other) const override { return false; }
};
DRAGNN_RUNTIME_REGISTER_COMPONENT(ImTheWorst);
// Extractor that produces a sequence of zeros.
class ExtractZeros : public SequenceExtractor {
public:
// Implements SequenceExtractor.
bool Supports(const FixedFeatureChannel &,
const ComponentSpec &) const override {
return true;
}
tensorflow::Status Initialize(const FixedFeatureChannel &,
const ComponentSpec &) override {
return tensorflow::Status::OK();
}
tensorflow::Status GetIds(InputBatchCache *,
std::vector<int32> *ids) const override {
ids->assign(kNumSteps, 0);
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR(ExtractZeros);
// Linker that produces a sequence of zeros.
class LinkZeros : public SequenceLinker {
public:
// Implements SequenceLinker.
bool Supports(const LinkedFeatureChannel &,
const ComponentSpec &) const override {
return true;
}
tensorflow::Status Initialize(const LinkedFeatureChannel &,
const ComponentSpec &) override {
return tensorflow::Status::OK();
}
tensorflow::Status GetLinks(size_t, InputBatchCache *,
std::vector<int32> *links) const override {
links->assign(kNumSteps, 0);
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(LinkZeros);
// Predictor that captures the logits.
class CaptureLogits : public SequencePredictor {
public:
// Implements SequencePredictor.
bool Supports(const ComponentSpec &) const override { return true; }
tensorflow::Status Initialize(const ComponentSpec &) override {
return tensorflow::Status::OK();
}
tensorflow::Status Predict(Matrix<float> logits,
InputBatchCache *) const override {
logits_ = logits;
return tensorflow::Status::OK();
}
// Returns the captured logits.
static Matrix<float> GetCapturedLogits() { return logits_; }
private:
// Logits from the most recent call to Predict().
static Matrix<float> logits_;
};
Matrix<float> CaptureLogits::logits_;
DRAGNN_RUNTIME_REGISTER_SEQUENCE_PREDICTOR(CaptureLogits);
class SequenceBulkDynamicComponentTest : public NetworkTestBase {
protected:
SequenceBulkDynamicComponentTest() {
EXPECT_CALL(compute_session_, GetInputBatchCache())
.WillRepeatedly(Return(&input_));
EXPECT_CALL(compute_session_, GetReadiedComponent(kTestComponentName))
.WillRepeatedly(Return(&backend_));
}
// Returns a spec that the network supports.
ComponentSpec GetSupportedSpec() {
ComponentSpec component_spec;
component_spec.set_name(kTestComponentName);
component_spec.set_num_actions(kLogitsDim);
component_spec.mutable_network_unit()->set_registered_name("AddOne");
component_spec.mutable_backend()->set_registered_name("SequenceBackend");
component_spec.mutable_component_builder()->set_registered_name(
"SequenceBulkDynamicComponent");
auto &component_parameters =
*component_spec.mutable_component_builder()->mutable_parameters();
component_parameters["sequence_extractors"] = "ExtractZeros";
component_parameters["sequence_linkers"] = "LinkZeros";
component_parameters["sequence_predictor"] = "CaptureLogits";
FixedFeatureChannel *fixed_feature = component_spec.add_fixed_feature();
fixed_feature->set_size(1);
fixed_feature->set_embedding_dim(kFixedDim);
fixed_feature->set_vocabulary_size(kFixedVocabularySize);
LinkedFeatureChannel *linked_feature = component_spec.add_linked_feature();
linked_feature->set_size(1);
linked_feature->set_embedding_dim(-1);
linked_feature->set_source_component(kPreviousComponentName);
linked_feature->set_source_layer(kPreviousLayerName);
return component_spec;
}
// Creates a network unit, initializes it based on the |component_spec_text|,
// and evaluates it. On error, returns non-OK.
tensorflow::Status Run(const ComponentSpec &component_spec) {
AddComponent(kPreviousComponentName);
AddLayer(kPreviousLayerName, kLinkedDim);
AddComponent(kTestComponentName);
AddFixedEmbeddingMatrix(0, kFixedVocabularySize, kFixedDim, kFixedValue);
std::unique_ptr<Component> component;
TF_RETURN_IF_ERROR(
Component::CreateOrError("SequenceBulkDynamicComponent", &component));
TF_RETURN_IF_ERROR(component->Initialize(component_spec, &variable_store_,
&network_state_manager_,
&extension_manager_));
// Allocates network states for a few steps.
network_states_.Reset(&network_state_manager_);
StartComponent(kNumSteps);
FillLayer(kPreviousComponentName, kPreviousLayerName, kLinkedValue);
StartComponent(0);
session_state_.extensions.Reset(&extension_manager_);
return component->Evaluate(&session_state_, &compute_session_, nullptr);
}
// Input batch injected into Evaluate() by default.
InputBatchCache input_;
// Backend injected into Evaluate().
SequenceBackend backend_;
};
// Tests that the supported spec is supported.
TEST_F(SequenceBulkDynamicComponentTest, Supported) {
const ComponentSpec component_spec = GetSupportedSpec();
string component_type;
TF_ASSERT_OK(Component::Select(component_spec, &component_type));
EXPECT_EQ(component_type, "SequenceBulkDynamicComponent");
TF_ASSERT_OK(Run(component_spec));
const Matrix<float> logits = CaptureLogits::GetCapturedLogits();
ASSERT_EQ(logits.num_rows(), kNumSteps);
ASSERT_EQ(logits.num_columns(), kFixedDim + kLinkedDim);
for (size_t row = 0; row < kNumSteps; ++row) {
size_t column = 0;
for (; column < kFixedDim; ++column) {
EXPECT_EQ(logits.row(row)[column], kFixedValue + 1.0);
}
for (; column < kFixedDim + kLinkedDim; ++column) {
EXPECT_EQ(logits.row(row)[column], kLinkedValue + 1.0);
}
}
}
// Tests that links cannot be recurrent.
TEST_F(SequenceBulkDynamicComponentTest, ForbidRecurrences) {
ComponentSpec component_spec = GetSupportedSpec();
component_spec.mutable_linked_feature(0)->set_source_component(
kTestComponentName);
string component_type;
EXPECT_THAT(
Component::Select(component_spec, &component_type),
test::IsErrorWithSubstr("Could not find a best spec for component"));
}
// Tests that the component prefers others.
TEST_F(SequenceBulkDynamicComponentTest, PrefersOthers) {
ComponentSpec component_spec = GetSupportedSpec();
component_spec.add_resource();
// Adding a resource triggers the ImTheWorst component, which also prefers
// itself and leads to a selection conflict.
string component_type;
EXPECT_THAT(
Component::Select(component_spec, &component_type),
test::IsErrorWithSubstr("both think they should be dis-preferred"));
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment