Commit edea2b67 authored by Terry Koo

Remove runtime because reasons.

parent a4bb31d0
component {
  name: "rnn"
  transition_system {
    registered_name: "shift-only"
    parameters {
      key: "left_to_right"
      value: "false"
    }
    parameters {
      key: "parser_skip_deterministic"
      value: "false"
    }
  }
  resource {
    name: "words-embedding-input"
    part {
      file_format: "tf-records"
      record_format: "syntaxnet.TokenEmbedding"
    }
  }
  resource {
    name: "words-vocab-input"
    part {
      file_format: "text"
      record_format: ""
    }
  }
  resource {
    name: "char-ngram-map"
    part {
      file_format: "text"
      record_format: ""
    }
  }
  resource {
    name: "word-map"
    part {
      file_format: "text"
      record_format: ""
    }
  }
  resource {
    name: "label-map"
    part {
      file_format: "text"
      record_format: ""
    }
  }
  resource {
    name: "myelin-flow"
    part {
      file_format: "model"
      record_format: "sling.myelin.Flow"
    }
  }
  fixed_feature {
    name: "char_ngrams"
    fml: "input.token { offset(-1).char-ngram(min-length=1,max-length=3,mark-boundaries=true) offset(0).char-ngram(min-length=1,max-length=3,mark-boundaries=true) offset(1).char-ngram(min-length=1,max-length=3,mark-boundaries=true) }"
    embedding_dim: -1
    vocabulary_size: 25788
    size: 3
  }
  fixed_feature {
    name: "words"
    fml: "input.token.word(min-freq=2)"
    embedding_dim: -1
    vocabulary_size: 23769
    size: 1
  }
  network_unit {
    registered_name: "LSTMNetwork"
    parameters {
      key: "hidden_layer_sizes"
      value: "128"
    }
    parameters {
      key: "omit_logits"
      value: "true"
    }
  }
  backend {
    registered_name: "SyntaxNetComponent"
  }
  num_actions: 1
  attention_component: ""
  component_builder {
    registered_name: "MyelinDynamicComponent"
  }
}
component {
  name: "tagger"
  transition_system {
    registered_name: "tagger"
    parameters {
      key: "parser_skip_deterministic"
      value: "false"
    }
  }
  resource {
    name: "tag-map"
    part {
      file_format: "text"
      record_format: ""
    }
  }
  resource {
    name: "tag-to-category"
    part {
      file_format: "text"
      record_format: ""
    }
  }
  resource {
    name: "label-map"
    part {
      file_format: "text"
      record_format: ""
    }
  }
  resource {
    name: "myelin-flow"
    part {
      file_format: "model"
      record_format: "sling.myelin.Flow"
    }
  }
  linked_feature {
    name: "recurrence"
    fml: "bias(0)"
    embedding_dim: -1
    size: 1
    source_component: "tagger"
    source_translator: "history"
    source_layer: "layer_0"
  }
  linked_feature {
    name: "rnn"
    fml: "input.focus"
    embedding_dim: -1
    size: 1
    source_component: "rnn"
    source_translator: "reverse-token"
    source_layer: "layer_0"
  }
  network_unit {
    registered_name: "FeedForwardNetwork"
    parameters {
      key: "hidden_layer_sizes"
      value: "64,64"
    }
  }
  backend {
    registered_name: "SyntaxNetComponent"
  }
  num_actions: 45
  attention_component: ""
  component_builder {
    registered_name: "MyelinDynamicComponent"
  }
}
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/network_states.h"
#include "tensorflow/core/lib/core/errors.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Returns the first value in |container| whose ".name" field is |name|, or null
// if not found.
template <class Container>
const typename Container::value_type *Find(const Container &container,
const string &name) {
for (auto &value : container) {
if (value.name == name) return &value;
}
return nullptr;
}
} // namespace
tensorflow::Status NetworkStateManager::AddComponent(const string &name) {
if (Find(components_, name) != nullptr) {
return tensorflow::errors::FailedPrecondition("Component '", name,
"' already exists");
}
// Success; make modifications.
components_.emplace_back(name);
return tensorflow::Status::OK();
}
tensorflow::Status NetworkStateManager::AddLayerImpl(
const string &name, std::type_index type, bool is_pairwise, size_t bytes,
size_t *component_index, OperandHandle *operand_handle) {
if (components_.empty()) {
return tensorflow::errors::FailedPrecondition("No current component");
}
ComponentConfig &component = components_.back();
if (Find(component.layers, name) != nullptr) {
return tensorflow::errors::FailedPrecondition(
"Layer '", name, "' already exists in component '", component.name,
"'");
}
if (component.aliases.find(name) != component.aliases.end()) {
return tensorflow::errors::FailedPrecondition(
"Layer '", name, "' conflicts with an existing alias in component '",
component.name, "'");
}
// Success; make modifications.
const OperandType operand_type =
is_pairwise ? OperandType::kPairwise : OperandType::kStepwise;
*component_index = components_.size() - 1;
*operand_handle = component.manager.Add({operand_type, bytes});
component.layers.emplace_back(name, type, *operand_handle);
return tensorflow::Status::OK();
}
tensorflow::Status NetworkStateManager::AddLayerAlias(const string &alias,
const string &name) {
if (components_.empty()) {
return tensorflow::errors::FailedPrecondition("No current component");
}
ComponentConfig &component = components_.back();
if (Find(component.layers, name) == nullptr) {
return tensorflow::errors::FailedPrecondition(
"Target layer '", name, "' of alias '", alias,
"' does not exist in component '", component.name, "'");
}
if (Find(component.layers, alias) != nullptr) {
return tensorflow::errors::FailedPrecondition(
"Alias '", alias, "' conflicts with an existing layer in component '",
component.name, "'");
}
if (component.aliases.find(alias) != component.aliases.end()) {
return tensorflow::errors::FailedPrecondition(
"Alias '", alias, "' already exists in component '", component.name,
"'");
}
// Success; make modifications.
component.aliases[alias] = name;
return tensorflow::Status::OK();
}
tensorflow::Status NetworkStateManager::AddLocalImpl(const OperandSpec &spec,
OperandHandle *handle) {
if (components_.empty()) {
return tensorflow::errors::FailedPrecondition("No current component");
}
ComponentConfig &component = components_.back();
// Success; make modifications.
*handle = component.manager.Add(spec);
return tensorflow::Status::OK();
}
tensorflow::Status NetworkStateManager::LookupLayerImpl(
const string &component_name, const string &layer_name_or_alias,
std::type_index type, bool is_pairwise, size_t *bytes,
size_t *component_index, OperandHandle *operand_handle) const {
const ComponentConfig *component = Find(components_, component_name);
if (component == nullptr) {
return tensorflow::errors::FailedPrecondition("Unknown component '",
component_name, "'");
}
// If necessary, resolve a layer alias into a layer name. Note that aliases
// are non-transitive, since AddLayerAlias() requires that the target of the
// alias is a layer.
const auto it = component->aliases.find(layer_name_or_alias);
const string &layer_name =
it != component->aliases.end() ? it->second : layer_name_or_alias;
const LayerConfig *layer = Find(component->layers, layer_name);
if (layer == nullptr) {
return tensorflow::errors::FailedPrecondition(
"Unknown layer '", layer_name, "' in component '", component_name, "'");
}
if (layer->type != type) {
return tensorflow::errors::InvalidArgument(
"Layer '", layer_name, "' in component '", component_name,
"' does not match its expected type");
}
const OperandType required_type =
is_pairwise ? OperandType::kPairwise : OperandType::kStepwise;
const OperandSpec &operand_spec = component->manager.spec(layer->handle);
if (operand_spec.type != required_type) {
return tensorflow::errors::InvalidArgument(
"Layer '", layer_name, "' in component '", component_name,
"' does not match its expected OperandType");
}
// Success; set the output arguments.
*bytes = operand_spec.size;
*component_index = component - components_.data();
*operand_handle = layer->handle;
return tensorflow::Status::OK();
}
void NetworkStates::Reset(const NetworkStateManager *manager) {
manager_ = manager;
num_active_components_ = 0;
// Never shrink the |component_operands_|, to avoid deallocating (and then
// eventually reallocating) operand arrays.
if (manager_->components_.size() > component_operands_.size()) {
component_operands_.resize(manager_->components_.size());
}
}
tensorflow::Status NetworkStates::StartNextComponent(
size_t pre_allocate_num_steps) {
if (manager_ == nullptr) {
return tensorflow::errors::FailedPrecondition("No manager");
}
if (num_active_components_ >= manager_->components_.size()) {
return tensorflow::errors::OutOfRange("No next component");
}
// Success; make modifications.
const OperandManager *operand_manager =
&manager_->components_[num_active_components_].manager;
component_operands_[num_active_components_].Reset(operand_manager,
pre_allocate_num_steps);
++num_active_components_;
return tensorflow::Status::OK();
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Utils for declaring, allocating, and retrieving network states, similar to
// the "NetworkState" class and the "network_states" argument to the build_*()
// methods of ComponentBuilderBase; see component.py.
//
// In brief, a DRAGNN network consists of a sequence of named components, each
// of which produces a set of named output layers. Each component can access
// its own layers as well as those of preceding components. Components can also
// access "local operands", which are like layers but private to that particular
// component. Local operands can be useful for, e.g., caching an intermediate
// result in a complex computation.
//
// For example, suppose a network has two components: "tagger" and "parser",
// where the parser uses the hidden activations of the tagger. In this case,
// the tagger can add a layer called "hidden" at init time and fill that layer
// at processing time. Correspondingly, the parser can look for a layer called
// "hidden" in the "tagger" component at init time, and read the activations at
// processing time. (Note that for convenience, such links should be handled
// using the utils in linked_embeddings.h).
//
// As another example, suppose we are implementing an LSTM and we wish to keep
// the cell state private. In this case, the LSTM component could add a layer
// for exporting the hidden activations and a local matrix for the sequence of
// cell states. A more compact approach is to use two local vectors instead,
// one for even steps and the other for odd steps.
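//
// A minimal usage sketch (hypothetical component and layer names; TF_CHECK_OK
// stands in for real error handling):
//
//   NetworkStateManager manager;
//   LayerHandle<float> hidden;
//   TF_CHECK_OK(manager.AddComponent("tagger"));
//   TF_CHECK_OK(manager.AddLayer("hidden", 128, &hidden));
//
//   // A later component can locate the layer by name (or alias).
//   size_t dimension = 0;
//   LayerHandle<float> link;
//   TF_CHECK_OK(manager.AddComponent("parser"));
//   TF_CHECK_OK(manager.LookupLayer("tagger", "hidden", &dimension, &link));
//
//   NetworkStates states;
//   states.Reset(&manager);
//   TF_CHECK_OK(states.StartNextComponent(/*pre_allocate_num_steps=*/16));
//   states.AddStep();  // the "tagger" layers now have one row each
//   MutableMatrix<float> activations = states.GetLayer(hidden);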
#ifndef DRAGNN_RUNTIME_NETWORK_STATES_H_
#define DRAGNN_RUNTIME_NETWORK_STATES_H_
#include <stddef.h>
#include <stdint.h>
#include <map>
#include <string>
#include <typeindex>
#include <typeinfo>
#include <vector>
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/operands.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Opaque handles used to access typed layers or local operands.
template <class T>
class LayerHandle;
template <class T>
class PairwiseLayerHandle;
template <class T>
class LocalVectorHandle;
template <class T>
class LocalMatrixHandle;
// A class that manages the state of a DRAGNN network and associates each layer
// and local operand with a handle. Layer and local operand contents can be
// retrieved using these handles; see NetworkStates below.
class NetworkStateManager {
public:
// Creates an empty manager.
NetworkStateManager() = default;
// Adds a component named |name| and makes it the current component. The
// |name| must be unique in the network. Components are sequenced in the
// order they are added. On error, returns non-OK and modifies nothing.
tensorflow::Status AddComponent(const string &name);
// Adds a layer named |name| to the current component and sets |handle| to its
// handle. The |name| must be unique in the current component. The layer is
// realized as a Matrix<T> with one row per step and |dimension| columns. On
// error, returns non-OK and modifies nothing.
template <class T>
tensorflow::Status AddLayer(const string &name, size_t dimension,
LayerHandle<T> *handle);
// As above, but for pairwise layers.
template <class T>
tensorflow::Status AddLayer(const string &name, size_t dimension,
PairwiseLayerHandle<T> *handle);
// As above, but for a local Vector<T> or Matrix<T> operand. The operand is
// "local" in the sense that only the caller knows its handle.
template <class T>
tensorflow::Status AddLocal(size_t dimension, LocalVectorHandle<T> *handle);
template <class T>
tensorflow::Status AddLocal(size_t dimension, LocalMatrixHandle<T> *handle);
// Makes |alias| an alias of the layer named |name| in the current component,
// so that lookups of |alias| resolve to |name|. The |name| must already
// exist as a layer, and layer names and aliases must be unique within each
// component. On error, returns non-OK and modifies nothing.
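// For example, a hypothetical alias: after AddLayer("logits", ...), calling
// AddLayerAlias("last_layer", "logits") lets consumers look up "last_layer".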
tensorflow::Status AddLayerAlias(const string &alias, const string &name);
// Finds the layer that matches |layer_name_or_alias| in the component named
// |component_name|. Sets |dimension| to its dimension and |handle| to its
// handle. On error, returns non-OK and modifies nothing.
template <class T>
tensorflow::Status LookupLayer(const string &component_name,
const string &layer_name_or_alias,
size_t *dimension,
LayerHandle<T> *handle) const;
// As above, but for pairwise layers.
template <class T>
tensorflow::Status LookupLayer(const string &component_name,
const string &layer_name_or_alias,
size_t *dimension,
PairwiseLayerHandle<T> *handle) const;
private:
friend class NetworkStates;
// Configuration information for a layer.
struct LayerConfig {
// Creates a config for a layer with the |name|, |type| ID, and |handle|.
LayerConfig(const string &name, std::type_index type, OperandHandle handle)
: name(name), type(type), handle(handle) {}
// Name of the layer.
string name;
// Type ID of the layer contents.
std::type_index type;
// Handle of the operand that holds the layer contents.
OperandHandle handle;
};
// Configuration information for a component.
struct ComponentConfig {
// Creates an empty config for a component with the |name|.
explicit ComponentConfig(const string &name) : name(name) {}
// Name of the component.
string name;
// Manager for the operands used by the component.
OperandManager manager;
// Configuration of each layer produced by the component.
std::vector<LayerConfig> layers;
// Mapping from layer alias to layer name in the component.
std::map<string, string> aliases;
};
// Implements the non-templated part of AddLayer(). Adds a layer with the
// |name|, |type| ID, and size in |bytes|. Sets the |component_index| and
// |operand_handle| according to the containing component and operand. If
// |is_pairwise| is true, then the new layer is pairwise (vs stepwise). On
// error, returns non-OK and modifies nothing.
tensorflow::Status AddLayerImpl(const string &name, std::type_index type,
bool is_pairwise, size_t bytes,
size_t *component_index,
OperandHandle *operand_handle);
// Implements the non-templated portion of AddLocal*(). Adds a local operand
// with the |spec| and sets |handle| to its handle. On error, returns non-OK
// and modifies nothing.
tensorflow::Status AddLocalImpl(const OperandSpec &spec,
OperandHandle *handle);
// Implements the non-templated portion of LookupLayer(). Finds the layer
// that matches the |component_name| and |layer_name_or_alias|. That layer
// must match the |type| ID. Sets |bytes| to its size, |component_index| to
// the index of its containing component, and |operand_handle| to the handle
// of its underlying operand. If |is_pairwise| is true, then the layer must
// be pairwise (vs stepwise). On error, returns non-OK and modifies nothing.
tensorflow::Status LookupLayerImpl(const string &component_name,
const string &layer_name_or_alias,
std::type_index type, bool is_pairwise,
size_t *bytes, size_t *component_index,
OperandHandle *operand_handle) const;
// Ordered list of configurations for the components in the network.
std::vector<ComponentConfig> components_;
};
// A set of network states. The structure of the network is configured by a
// NetworkStateManager, and layer and local operand contents can be accessed
// using the handles produced by the manager.
//
// Multiple NetworkStates instances can share the same NetworkStateManager. In
// addition, a NetworkStates instance can be reused by repeatedly Reset()-ing
// it, potentially with different NetworkStateManagers. Such reuse can reduce
// allocation overhead.
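//
// For example, a sketch of reuse across inputs, assuming |manager| is a fully
// configured NetworkStateManager:
//
//   NetworkStates states;
//   for (const Input &input : inputs) {  // hypothetical input collection
//     states.Reset(&manager);  // reuses previously-allocated operand arrays
//     TF_CHECK_OK(states.StartNextComponent(0));
//     ...
//   }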
class NetworkStates {
public:
// Creates an uninitialized set of states.
NetworkStates() = default;
// Resets this to an empty set configured by the |manager|. The |manager|
// must live until this is destroyed or Reset(), and should not be modified
// during that time. No current component is set; call StartNextComponent()
// to start the first component.
void Reset(const NetworkStateManager *manager);
// Starts the next component and makes it the current component. Initially,
// the component has zero steps but more can be added using AddStep(). Uses
// |pre_allocate_num_steps| to pre-allocate storage; see Operands::Reset().
// On error, returns non-OK and modifies nothing.
tensorflow::Status StartNextComponent(size_t pre_allocate_num_steps);
// Adds one or more steps to the current component. Invalidates all
// previously-returned matrices of the current component.
void AddStep() { AddSteps(1); }
void AddSteps(size_t num_steps);
// Returns the layer associated with the |handle|.
template <class T>
MutableMatrix<T> GetLayer(LayerHandle<T> handle) const;
// Returns the pairwise layer associated with the |handle|.
template <class T>
MutableMatrix<T> GetLayer(PairwiseLayerHandle<T> handle) const;
// Returns the local vector or matrix associated with the |handle| in the
// current component.
template <class T>
MutableVector<T> GetLocal(LocalVectorHandle<T> handle) const;
template <class T>
MutableMatrix<T> GetLocal(LocalMatrixHandle<T> handle) const;
private:
// Manager of this set of network states.
const NetworkStateManager *manager_ = nullptr;
// Number of active components in the |component_operands_|.
size_t num_active_components_ = 0;
// Ordered list of per-component operands. Only the first
// |num_active_components_| entries are valid.
std::vector<Operands> component_operands_;
};
// Implementation details below.
// An opaque handle to a typed layer of some component.
template <class T>
class LayerHandle {
public:
static_assert(IsAlignable<T>(), "T must be alignable");
// Creates an invalid handle.
LayerHandle() = default;
private:
friend class NetworkStateManager;
friend class NetworkStates;
// Index of the containing component in the network state manager.
size_t component_index_ = SIZE_MAX;
// Handle of the operand holding the layer.
OperandHandle operand_handle_;
};
// An opaque handle to a typed pairwise layer of some component.
template <class T>
class PairwiseLayerHandle {
public:
static_assert(IsAlignable<T>(), "T must be alignable");
// Creates an invalid handle.
PairwiseLayerHandle() = default;
private:
friend class NetworkStateManager;
friend class NetworkStates;
// Index of the containing component in the network state manager.
size_t component_index_ = SIZE_MAX;
// Handle of the operand holding the layer.
OperandHandle operand_handle_;
};
// An opaque handle to a typed local operand of some component.
template <class T>
class LocalVectorHandle {
public:
static_assert(IsAlignable<T>(), "T must be alignable");
// Creates an invalid handle.
LocalVectorHandle() = default;
private:
friend class NetworkStateManager;
friend class NetworkStates;
// Handle of the local operand.
OperandHandle operand_handle_;
};
// An opaque handle to a typed local operand of some component.
template <class T>
class LocalMatrixHandle {
public:
static_assert(IsAlignable<T>(), "T must be alignable");
// Creates an invalid handle.
LocalMatrixHandle() = default;
private:
friend class NetworkStateManager;
friend class NetworkStates;
// Handle of the local operand.
OperandHandle operand_handle_;
};
template <class T>
tensorflow::Status NetworkStateManager::AddLayer(const string &name,
size_t dimension,
LayerHandle<T> *handle) {
return AddLayerImpl(name, std::type_index(typeid(T)), /*is_pairwise=*/false,
dimension * sizeof(T), &handle->component_index_,
&handle->operand_handle_);
}
template <class T>
tensorflow::Status NetworkStateManager::AddLayer(
const string &name, size_t dimension, PairwiseLayerHandle<T> *handle) {
return AddLayerImpl(name, std::type_index(typeid(T)), /*is_pairwise=*/true,
dimension * sizeof(T), &handle->component_index_,
&handle->operand_handle_);
}
template <class T>
tensorflow::Status NetworkStateManager::AddLocal(size_t dimension,
LocalVectorHandle<T> *handle) {
return AddLocalImpl({OperandType::kSingular, dimension * sizeof(T)},
&handle->operand_handle_);
}
template <class T>
tensorflow::Status NetworkStateManager::AddLocal(size_t dimension,
LocalMatrixHandle<T> *handle) {
return AddLocalImpl({OperandType::kStepwise, dimension * sizeof(T)},
&handle->operand_handle_);
}
template <class T>
tensorflow::Status NetworkStateManager::LookupLayer(
const string &component_name, const string &layer_name_or_alias,
size_t *dimension, LayerHandle<T> *handle) const {
TF_RETURN_IF_ERROR(LookupLayerImpl(
component_name, layer_name_or_alias, std::type_index(typeid(T)),
/*is_pairwise=*/false, dimension, &handle->component_index_,
&handle->operand_handle_));
DCHECK_EQ(*dimension % sizeof(T), 0);
*dimension /= sizeof(T); // bytes => Ts
return tensorflow::Status::OK();
}
template <class T>
tensorflow::Status NetworkStateManager::LookupLayer(
const string &component_name, const string &layer_name_or_alias,
size_t *dimension, PairwiseLayerHandle<T> *handle) const {
TF_RETURN_IF_ERROR(LookupLayerImpl(
component_name, layer_name_or_alias, std::type_index(typeid(T)),
/*is_pairwise=*/true, dimension, &handle->component_index_,
&handle->operand_handle_));
DCHECK_EQ(*dimension % sizeof(T), 0);
*dimension /= sizeof(T); // bytes => Ts
return tensorflow::Status::OK();
}
inline void NetworkStates::AddSteps(size_t num_steps) {
component_operands_[num_active_components_ - 1].AddSteps(num_steps);
}
template <class T>
MutableMatrix<T> NetworkStates::GetLayer(LayerHandle<T> handle) const {
return MutableMatrix<T>(
component_operands_[handle.component_index_].GetStepwise(
handle.operand_handle_));
}
template <class T>
MutableMatrix<T> NetworkStates::GetLayer(PairwiseLayerHandle<T> handle) const {
return MutableMatrix<T>(
component_operands_[handle.component_index_].GetPairwise(
handle.operand_handle_));
}
template <class T>
MutableVector<T> NetworkStates::GetLocal(LocalVectorHandle<T> handle) const {
return MutableVector<T>(
component_operands_[num_active_components_ - 1].GetSingular(
handle.operand_handle_));
}
template <class T>
MutableMatrix<T> NetworkStates::GetLocal(LocalMatrixHandle<T> handle) const {
return MutableMatrix<T>(
component_operands_[num_active_components_ - 1].GetStepwise(
handle.operand_handle_));
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_NETWORK_STATES_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/network_states.h"
#include <stddef.h>
#include <string.h>
#include <string>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/math/types.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Expects that two objects have identical bit representations.
template <class T>
void ExpectBitwiseEqual(const T &object1, const T &object2) {
EXPECT_EQ(memcmp(&object1, &object2, sizeof(T)), 0);
}
// Expects that the |matrix| has the given dimensions.
template <class T>
void ExpectDimensions(MutableMatrix<T> matrix, size_t num_rows,
size_t num_columns) {
EXPECT_EQ(matrix.num_rows(), num_rows);
EXPECT_EQ(matrix.num_columns(), num_columns);
}
// Sets the |vector| to |size| copies of the |value|.
template <class T>
void Fill(MutableVector<T> vector, size_t size, T value) {
ASSERT_EQ(vector.size(), size);
for (T &element : vector) element = value;
}
// Expects that the |vector| contains |size| copies of the |expected_value|.
template <class T>
void ExpectFilled(MutableVector<T> vector, size_t size, T expected_value) {
ASSERT_EQ(vector.size(), size);
for (const T element : vector) EXPECT_EQ(element, expected_value);
}
// Tests that NetworkStateManager can add a named component.
TEST(NetworkStateManagerTest, AddComponent) {
NetworkStateManager manager;
TF_EXPECT_OK(manager.AddComponent("foo/bar"));
EXPECT_THAT(manager.AddComponent("foo/bar"),
test::IsErrorWithSubstr("Component 'foo/bar' already exists"));
// Empty component name is weird, but OK.
TF_EXPECT_OK(manager.AddComponent(""));
EXPECT_THAT(manager.AddComponent(""),
test::IsErrorWithSubstr("Component '' already exists"));
}
// Tests that NetworkStateManager can add a named layer to the current
// component.
TEST(NetworkStateManagerTest, AddLayer) {
NetworkStateManager manager;
LayerHandle<float> unused_layer_handle;
EXPECT_THAT(manager.AddLayer("layer", 1, &unused_layer_handle),
test::IsErrorWithSubstr("No current component"));
TF_EXPECT_OK(manager.AddComponent("component"));
TF_EXPECT_OK(manager.AddLayer("layer", 2, &unused_layer_handle));
EXPECT_THAT(manager.AddLayer("layer", 2, &unused_layer_handle),
test::IsErrorWithSubstr(
"Layer 'layer' already exists in component 'component'"));
}
// Tests that NetworkStateManager can add a named pairwise layer to the current
// component.
TEST(NetworkStateManagerTest, AddLayerPairwise) {
NetworkStateManager manager;
PairwiseLayerHandle<float> unused_layer_handle;
EXPECT_THAT(manager.AddLayer("layer", 1, &unused_layer_handle),
test::IsErrorWithSubstr("No current component"));
TF_EXPECT_OK(manager.AddComponent("component"));
TF_EXPECT_OK(manager.AddLayer("layer", 2, &unused_layer_handle));
EXPECT_THAT(manager.AddLayer("layer", 2, &unused_layer_handle),
test::IsErrorWithSubstr(
"Layer 'layer' already exists in component 'component'"));
}
// Tests that NetworkStateManager can add an alias to an existing layer. Also
// tests that layer and alias names are required to be unique.
TEST(NetworkStateManagerTest, AddLayerAlias) {
NetworkStateManager manager;
LayerHandle<float> unused_layer_handle;
EXPECT_THAT(manager.AddLayerAlias("alias", "layer"),
test::IsErrorWithSubstr("No current component"));
TF_EXPECT_OK(manager.AddComponent("component"));
EXPECT_THAT(
manager.AddLayerAlias("alias", "layer"),
test::IsErrorWithSubstr("Target layer 'layer' of alias 'alias' does not "
"exist in component 'component'"));
TF_EXPECT_OK(manager.AddLayer("layer", 2, &unused_layer_handle));
TF_EXPECT_OK(manager.AddLayerAlias("alias", "layer"));
EXPECT_THAT(manager.AddLayerAlias("alias", "layer"),
test::IsErrorWithSubstr(
"Alias 'alias' already exists in component 'component'"));
EXPECT_THAT(
manager.AddLayer("alias", 2, &unused_layer_handle),
test::IsErrorWithSubstr("Layer 'alias' conflicts with an existing alias "
"in component 'component'"));
TF_EXPECT_OK(manager.AddLayer("layer2", 2, &unused_layer_handle));
EXPECT_THAT(
manager.AddLayerAlias("layer2", "layer"),
test::IsErrorWithSubstr("Alias 'layer2' conflicts with an existing layer "
"in component 'component'"));
}
// Tests that NetworkStateManager can add a local matrix or vector to the
// current component.
TEST(NetworkStateManagerTest, AddLocal) {
NetworkStateManager manager;
LocalVectorHandle<float> unused_local_vector_handle;
LocalMatrixHandle<float> unused_local_matrix_handle;
EXPECT_THAT(manager.AddLocal(11, &unused_local_matrix_handle),
test::IsErrorWithSubstr("No current component"));
TF_EXPECT_OK(manager.AddComponent("component"));
TF_EXPECT_OK(manager.AddLocal(22, &unused_local_matrix_handle));
TF_EXPECT_OK(manager.AddLocal(33, &unused_local_vector_handle));
}
// Tests that NetworkStateManager can look up existing layers or aliases, and
// fails on invalid layer or component names and for mismatched types.
TEST(NetworkStateManagerTest, LookupLayer) {
NetworkStateManager manager;
LayerHandle<char> char_handle;
LayerHandle<int16> int16_handle;
LayerHandle<uint16> uint16_handle;
PairwiseLayerHandle<char> pairwise_char_handle;
size_t dimension = 0;
// Add some typed layers and aliases.
TF_ASSERT_OK(manager.AddComponent("foo"));
TF_ASSERT_OK(manager.AddLayer("char", 5, &char_handle));
TF_ASSERT_OK(manager.AddLayer("int16", 7, &int16_handle));
TF_ASSERT_OK(manager.AddLayerAlias("char_alias", "char"));
TF_ASSERT_OK(manager.AddLayerAlias("int16_alias", "int16"));
TF_ASSERT_OK(manager.AddComponent("bar"));
TF_ASSERT_OK(manager.AddLayer("uint16", 11, &uint16_handle));
TF_ASSERT_OK(manager.AddLayer("pairwise_char", 13, &pairwise_char_handle));
TF_ASSERT_OK(manager.AddLayerAlias("uint16_alias", "uint16"));
TF_ASSERT_OK(manager.AddLayerAlias("pairwise_char_alias", "pairwise_char"));
// Try looking up unknown components.
EXPECT_THAT(manager.LookupLayer("missing", "char", &dimension, &char_handle),
test::IsErrorWithSubstr("Unknown component 'missing'"));
EXPECT_THAT(manager.LookupLayer("baz", "float", &dimension, &char_handle),
test::IsErrorWithSubstr("Unknown component 'baz'"));
// Try looking up valid components but unknown layers.
EXPECT_THAT(
manager.LookupLayer("foo", "missing", &dimension, &char_handle),
test::IsErrorWithSubstr("Unknown layer 'missing' in component 'foo'"));
EXPECT_THAT(
manager.LookupLayer("bar", "missing", &dimension, &char_handle),
test::IsErrorWithSubstr("Unknown layer 'missing' in component 'bar'"));
// Try looking up valid components and the names of layers or aliases in the
// other components.
EXPECT_THAT(
manager.LookupLayer("foo", "uint16", &dimension, &uint16_handle),
test::IsErrorWithSubstr("Unknown layer 'uint16' in component 'foo'"));
EXPECT_THAT(
manager.LookupLayer("foo", "uint16_alias", &dimension, &uint16_handle),
test::IsErrorWithSubstr(
"Unknown layer 'uint16_alias' in component 'foo'"));
EXPECT_THAT(
manager.LookupLayer("bar", "char", &dimension, &char_handle),
test::IsErrorWithSubstr("Unknown layer 'char' in component 'bar'"));
EXPECT_THAT(
manager.LookupLayer("bar", "char_alias", &dimension, &char_handle),
test::IsErrorWithSubstr("Unknown layer 'char_alias' in component 'bar'"));
// Look up layers with incorrect types.
EXPECT_THAT(
manager.LookupLayer("foo", "char", &dimension, &int16_handle),
test::IsErrorWithSubstr(
"Layer 'char' in component 'foo' does not match its expected type"));
EXPECT_THAT(
manager.LookupLayer("foo", "char", &dimension, &uint16_handle),
test::IsErrorWithSubstr(
"Layer 'char' in component 'foo' does not match its expected type"));
EXPECT_THAT(
manager.LookupLayer("foo", "char", &dimension, &pairwise_char_handle),
test::IsErrorWithSubstr("Layer 'char' in component 'foo' does not match "
"its expected OperandType"));
EXPECT_THAT(
manager.LookupLayer("foo", "int16", &dimension, &char_handle),
test::IsErrorWithSubstr(
"Layer 'int16' in component 'foo' does not match its expected type"));
EXPECT_THAT(
manager.LookupLayer("foo", "int16", &dimension, &uint16_handle),
test::IsErrorWithSubstr(
"Layer 'int16' in component 'foo' does not match its expected type"));
EXPECT_THAT(
manager.LookupLayer("foo", "int16", &dimension, &pairwise_char_handle),
test::IsErrorWithSubstr(
"Layer 'int16' in component 'foo' does not match its expected type"));
EXPECT_THAT(manager.LookupLayer("bar", "uint16", &dimension, &char_handle),
test::IsErrorWithSubstr("Layer 'uint16' in component 'bar' does "
"not match its expected type"));
EXPECT_THAT(manager.LookupLayer("bar", "uint16", &dimension, &int16_handle),
test::IsErrorWithSubstr("Layer 'uint16' in component 'bar' does "
"not match its expected type"));
EXPECT_THAT(
manager.LookupLayer("bar", "uint16", &dimension, &pairwise_char_handle),
test::IsErrorWithSubstr("Layer 'uint16' in component 'bar' does "
"not match its expected type"));
EXPECT_THAT(
manager.LookupLayer("bar", "pairwise_char", &dimension, &char_handle),
test::IsErrorWithSubstr("Layer 'pairwise_char' in component 'bar' does "
"not match its expected OperandType"));
EXPECT_THAT(
manager.LookupLayer("bar", "pairwise_char", &dimension, &int16_handle),
test::IsErrorWithSubstr("Layer 'pairwise_char' in component 'bar' does "
"not match its expected type"));
EXPECT_THAT(
manager.LookupLayer("bar", "pairwise_char", &dimension, &uint16_handle),
test::IsErrorWithSubstr("Layer 'pairwise_char' in component 'bar' does "
"not match its expected type"));
// Look up layers properly, and check their dimensions. Also verify that the
// looked-up handles are identical to the original handles.
LayerHandle<char> lookup_char_handle;
LayerHandle<int16> lookup_int16_handle;
LayerHandle<uint16> lookup_uint16_handle;
PairwiseLayerHandle<char> lookup_pairwise_char_handle;
TF_EXPECT_OK(
manager.LookupLayer("foo", "char", &dimension, &lookup_char_handle));
EXPECT_EQ(dimension, 5);
ExpectBitwiseEqual(lookup_char_handle, char_handle);
TF_EXPECT_OK(
manager.LookupLayer("foo", "int16", &dimension, &lookup_int16_handle));
EXPECT_EQ(dimension, 7);
ExpectBitwiseEqual(lookup_int16_handle, int16_handle);
TF_EXPECT_OK(
manager.LookupLayer("bar", "uint16", &dimension, &lookup_uint16_handle));
EXPECT_EQ(dimension, 11);
ExpectBitwiseEqual(lookup_uint16_handle, uint16_handle);
TF_EXPECT_OK(manager.LookupLayer("bar", "pairwise_char", &dimension,
&lookup_pairwise_char_handle));
EXPECT_EQ(dimension, 13);
ExpectBitwiseEqual(lookup_pairwise_char_handle, pairwise_char_handle);
}
// Tests that NetworkStates cannot start components without a manager.
TEST(NetworkStatesTest, NoManager) {
NetworkStates network_states;
EXPECT_THAT(network_states.StartNextComponent(10),
test::IsErrorWithSubstr("No manager"));
}
// Tests that NetworkStates cannot start components when the manager is empty.
TEST(NetworkStatesTest, EmptyManager) {
NetworkStateManager empty_manager;
NetworkStates network_states;
network_states.Reset(&empty_manager);
EXPECT_THAT(network_states.StartNextComponent(10),
test::IsErrorWithSubstr("No next component"));
}
// Tests that NetworkStates can start the same number of components as were
// configured in its manager.
TEST(NetworkStatesTest, StartNextComponent) {
NetworkStateManager manager;
TF_EXPECT_OK(manager.AddComponent("foo"));
TF_EXPECT_OK(manager.AddComponent("bar"));
TF_EXPECT_OK(manager.AddComponent("baz"));
NetworkStates network_states;
network_states.Reset(&manager);
TF_EXPECT_OK(network_states.StartNextComponent(10));
TF_EXPECT_OK(network_states.StartNextComponent(11));
TF_EXPECT_OK(network_states.StartNextComponent(12));
EXPECT_THAT(network_states.StartNextComponent(13),
test::IsErrorWithSubstr("No next component"));
}
// Tests that NetworkStates contains layers and locals whose dimensions match
// the configuration of its manager.
TEST(NetworkStatesTest, Dimensions) {
NetworkStateManager manager;
// The "foo" component has two layers and a local vector.
LayerHandle<float> foo_hidden_handle;
LocalVectorHandle<int16> foo_local_handle;
PairwiseLayerHandle<float> foo_logits_handle;
TF_ASSERT_OK(manager.AddComponent("foo"));
TF_ASSERT_OK(manager.AddLayer("hidden", 10, &foo_hidden_handle));
TF_ASSERT_OK(manager.AddLocal(20, &foo_local_handle));
TF_ASSERT_OK(manager.AddLayer("logits", 30, &foo_logits_handle));
// The "bar" component has one layer and a local matrix.
LayerHandle<float> bar_logits_handle;
LocalMatrixHandle<bool> bar_local_handle;
TF_ASSERT_OK(manager.AddComponent("bar"));
TF_ASSERT_OK(manager.AddLayer("logits", 40, &bar_logits_handle));
TF_ASSERT_OK(manager.AddLocal(50, &bar_local_handle));
// Initialize a NetworkStates and check its dimensions. Note that matrices
// start with 0 rows since there are 0 steps.
NetworkStates network_states;
network_states.Reset(&manager);
TF_EXPECT_OK(network_states.StartNextComponent(13));
ExpectDimensions(network_states.GetLayer(foo_hidden_handle), 0, 10);
EXPECT_EQ(network_states.GetLocal(foo_local_handle).size(), 20);
ExpectDimensions(network_states.GetLayer(foo_logits_handle), 0, 0);
// Add some steps, and check that rows have been added to matrices, while
// vectors are unaffected.
network_states.AddSteps(19);
ExpectDimensions(network_states.GetLayer(foo_hidden_handle), 19, 10);
EXPECT_EQ(network_states.GetLocal(foo_local_handle).size(), 20);
ExpectDimensions(network_states.GetLayer(foo_logits_handle), 19, 19 * 30);
// Again for the next component.
TF_EXPECT_OK(network_states.StartNextComponent(9));
ExpectDimensions(network_states.GetLayer(bar_logits_handle), 0, 40);
ExpectDimensions(network_states.GetLocal(bar_local_handle), 0, 50);
// Add some steps, and check that rows have been added to matrices.
network_states.AddSteps(25);
ExpectDimensions(network_states.GetLayer(bar_logits_handle), 25, 40);
ExpectDimensions(network_states.GetLocal(bar_local_handle), 25, 50);
EXPECT_THAT(network_states.StartNextComponent(10),
test::IsErrorWithSubstr("No next component"));
// Check the layers of the first component. They should still have the same
// dimensions in spite of adding steps to the second component.
ExpectDimensions(network_states.GetLayer(foo_hidden_handle), 19, 10);
ExpectDimensions(network_states.GetLayer(foo_logits_handle), 19, 19 * 30);
}
// Tests that NetworkStates can be reused by resetting them repeatedly, possibly
// switching between different managers.
TEST(NetworkStatesTest, ResetWithDifferentManagers) {
std::vector<NetworkStateManager> managers(10);
std::vector<LayerHandle<int>> layer_handles(10);
std::vector<PairwiseLayerHandle<int>> pairwise_layer_handles(10);
std::vector<LocalVectorHandle<int>> vector_handles(10);
std::vector<LocalMatrixHandle<double>> matrix_handles(10);
for (int dim = 0; dim < 10; ++dim) {
TF_ASSERT_OK(managers[dim].AddComponent("foo"));
TF_ASSERT_OK(managers[dim].AddLayer(
tensorflow::strings::StrCat("layer", dim), dim, &layer_handles[dim]));
TF_ASSERT_OK(
managers[dim].AddLayer(tensorflow::strings::StrCat("pairwise", dim),
dim, &pairwise_layer_handles[dim]));
TF_ASSERT_OK(managers[dim].AddLocal(dim, &vector_handles[dim]));
TF_ASSERT_OK(managers[dim].AddLocal(dim, &matrix_handles[dim]));
}
NetworkStates network_states;
for (int trial = 0; trial < 10; ++trial) {
for (int dim = 0; dim < 10; ++dim) {
network_states.Reset(&managers[dim]);
TF_ASSERT_OK(network_states.StartNextComponent(10));
// Fill the vector local.
Fill(network_states.GetLocal(vector_handles[dim]), dim,
100 * trial + dim);
// Check the vector local.
ExpectFilled(network_states.GetLocal(vector_handles[dim]), dim,
100 * trial + dim);
// Repeatedly add a step and fill it with values.
for (int step = 0; step < 100; ++step) {
network_states.AddStep();
Fill(network_states.GetLayer(layer_handles[dim]).row(step), dim,
1000 * trial + 100 * dim + step);
Fill(network_states.GetLocal(matrix_handles[dim]).row(step), dim,
9876.0 * trial + 100 * dim + step);
}
// Check that data from earlier steps is preserved across reallocations.
for (int step = 0; step < 100; ++step) {
ExpectFilled(network_states.GetLayer(layer_handles[dim]).row(step), dim,
1000 * trial + 100 * dim + step);
ExpectFilled(network_states.GetLocal(matrix_handles[dim]).row(step),
dim, 9876.0 * trial + 100 * dim + step);
}
ExpectDimensions(network_states.GetLayer(pairwise_layer_handles[dim]),
100, 100 * dim);
}
}
}
// Tests that one NetworkStateManager can be shared simultaneously between
// multiple NetworkStates instances.
TEST(NetworkStatesTest, SharedManager) {
const size_t kDim = 17;
NetworkStateManager manager;
LayerHandle<int> layer_handle;
PairwiseLayerHandle<int> pairwise_layer_handle;
LocalVectorHandle<int> vector_handle;
LocalMatrixHandle<double> matrix_handle;
TF_ASSERT_OK(manager.AddComponent("foo"));
TF_ASSERT_OK(manager.AddLayer("layer", kDim, &layer_handle));
TF_ASSERT_OK(manager.AddLayer("pairwise", kDim, &pairwise_layer_handle));
TF_ASSERT_OK(manager.AddLocal(kDim, &vector_handle));
TF_ASSERT_OK(manager.AddLocal(kDim, &matrix_handle));
std::vector<NetworkStates> network_states_vec(10);
for (NetworkStates &network_states : network_states_vec) {
network_states.Reset(&manager);
TF_ASSERT_OK(network_states.StartNextComponent(10));
}
// Fill all vectors.
for (int trial = 0; trial < network_states_vec.size(); ++trial) {
const NetworkStates &network_states = network_states_vec[trial];
Fill(network_states.GetLocal(vector_handle), kDim, 3 * trial);
}
// Check all vectors.
for (int trial = 0; trial < network_states_vec.size(); ++trial) {
const NetworkStates &network_states = network_states_vec[trial];
ExpectFilled(network_states.GetLocal(vector_handle), kDim, 3 * trial);
}
// Fill all matrices. Interleave operations on the network states on each
// step, so all network states are "active" at the same time.
for (int step = 0; step < 100; ++step) {
for (int trial = 0; trial < 10; ++trial) {
NetworkStates &network_states = network_states_vec[trial];
network_states.AddStep();
Fill(network_states.GetLayer(layer_handle).row(step), kDim,
999 * trial + step);
Fill(network_states.GetLocal(matrix_handle).row(step), kDim,
1234.0 * trial + step);
ExpectDimensions(network_states.GetLayer(pairwise_layer_handle), step + 1,
kDim * (step + 1));
}
}
// Check all matrices.
for (int step = 0; step < 100; ++step) {
for (int trial = 0; trial < 10; ++trial) {
const NetworkStates &network_states = network_states_vec[trial];
ExpectFilled(network_states.GetLayer(layer_handle).row(step), kDim,
999 * trial + step);
ExpectFilled(network_states.GetLocal(matrix_handle).row(step), kDim,
1234.0 * trial + step);
ExpectDimensions(network_states.GetLayer(pairwise_layer_handle), 100,
kDim * 100);
}
}
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/network_unit.h"
#include <vector>
#include "tensorflow/core/lib/strings/str_util.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
string NetworkUnit::GetClassName(const ComponentSpec &component_spec) {
// The Python registration API is based on (relative) module paths, such as
// "some.module.FooNetwork". Therefore, we discard the module path prefix and
// use only the final segment, which is the subclass name.
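// For example, "some.module.FooNetwork" yields "FooNetwork", while a bare
// name such as "FeedForwardNetwork" is returned unchanged.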
const std::vector<string> segments = tensorflow::str_util::Split(
component_spec.network_unit().registered_name(), ".");
CHECK_GT(segments.size(), 0) << "No network unit name for component spec: "
<< component_spec.ShortDebugString();
return segments.back();
}
} // namespace runtime
} // namespace dragnn
REGISTER_SYNTAXNET_CLASS_REGISTRY("DRAGNN Runtime Network Unit",
dragnn::runtime::NetworkUnit);
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_NETWORK_UNIT_H_
#define DRAGNN_RUNTIME_NETWORK_UNIT_H_
#include <stddef.h>
#include <string>
#include "dragnn/core/compute_session.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "syntaxnet/registry.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Interface for network units for sequential inference.
class NetworkUnit : public RegisterableClass<NetworkUnit> {
public:
NetworkUnit(const NetworkUnit &that) = delete;
NetworkUnit &operator=(const NetworkUnit &that) = delete;
virtual ~NetworkUnit() = default;
// Returns the network unit class name specified in the |component_spec|.
static string GetClassName(const ComponentSpec &component_spec);
// Initializes this to the configuration in the |component_spec|. Retrieves
// pre-trained variables from the |variable_store|, which must outlive this.
// Adds layers and local operands to the |network_state_manager|, which must
// be positioned at the current component. Requests SessionState extensions
// from the |extension_manager|. On error, returns non-OK.
virtual tensorflow::Status Initialize(
const ComponentSpec &component_spec, VariableStore *variable_store,
NetworkStateManager *network_state_manager,
ExtensionManager *extension_manager) = 0;
// Returns the name of the layer that contains classification logits, or an
// empty string if this does not produce logits. Requires that Initialize()
// was called.
virtual string GetLogitsName() const = 0;
// Evaluates this network unit on the |session_state| and |compute_session|.
// Requires that:
// * The network states in the |session_state| is positioned at the current
// component, which must have at least |step_index|+1 steps.
// * The same component in the |compute_session| must have traversed
// |step_index| transitions.
// * Initialize() was called.
// On error, returns non-OK.
virtual tensorflow::Status Evaluate(
size_t step_index, SessionState *session_state,
ComputeSession *compute_session) const = 0;
protected:
NetworkUnit() = default;
private:
// Helps prevent use of the Create() method; use CreateOrError() instead.
using RegisterableClass<NetworkUnit>::Create;
};
} // namespace runtime
} // namespace dragnn
DECLARE_SYNTAXNET_CLASS_REGISTRY("DRAGNN Runtime Network Unit",
dragnn::runtime::NetworkUnit);
} // namespace syntaxnet
// Registers a subclass using its class name as a string.
#define DRAGNN_RUNTIME_REGISTER_NETWORK_UNIT(subclass) \
REGISTER_SYNTAXNET_CLASS_COMPONENT( \
::syntaxnet::dragnn::runtime::NetworkUnit, #subclass, subclass)
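// For example, registration of a hypothetical subclass, typically placed in
// the subclass's .cc file:
//   DRAGNN_RUNTIME_REGISTER_NETWORK_UNIT(MyFeedForwardNetwork);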
#endif // DRAGNN_RUNTIME_NETWORK_UNIT_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/network_unit_base.h"
#include <string.h>
#include "tensorflow/core/lib/core/errors.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Returns the sum of the dimensions of all channels in the |manager|. The
// EmbeddingManager template type should be either FixedEmbeddingManager or
// LinkedEmbeddingManager; note that both share the same API.
template <class EmbeddingManager>
size_t SumEmbeddingDimensions(const EmbeddingManager &manager) {
size_t sum = 0;
for (size_t i = 0; i < manager.num_channels(); ++i) {
sum += manager.embedding_dim(i);
}
return sum;
}
// Copies each channel of the |embeddings| into the region starting at |data|.
// Returns a pointer to one past the last element of the copied region. The
// Embeddings type should be FixedEmbeddings or LinkedEmbeddings; note that both
// have the same API.
//
// TODO(googleuser): Try a vectorized copy instead of memcpy(). Unclear whether
// we can do better, though. For one, the memcpy() implementation may already
// be vectorized. Also, while the input embeddings are aligned, the output is
// not; e.g., consider concatenating inputs with dims 7 and 9. This could be
// addressed by requiring that embedding dims are aligned, or by handling the
// unaligned prefix separately.
//
// TODO(googleuser): Consider alternatives for handling fixed feature channels
// with size>1. The least surprising approach is to concatenate the size>1
// embeddings inside FixedEmbeddings, so the channel IDs still correspond to
// positions in the ComponentSpec.fixed_feature list. However, that means the
// same embedding gets copied twice, once there and once here. Conversely, we
// could split the size>1 embeddings into separate channels, eliding a copy
// while obfuscating the channel IDs. IMO, separate channels seem better
// because very few bits of DRAGNN actually access individual channels, and I
// wrote many of those bits.
template <class Embeddings>
float *CopyEmbeddings(const Embeddings &embeddings, float *data) {
for (size_t i = 0; i < embeddings.num_embeddings(); ++i) {
const Vector<float> vector = embeddings.embedding(i);
memcpy(data, vector.data(), vector.size() * sizeof(float));
data += vector.size();
}
return data;
}
} // namespace
tensorflow::Status NetworkUnitBase::InitializeBase(
bool use_concatenated_input, const ComponentSpec &component_spec,
VariableStore *variable_store, NetworkStateManager *network_state_manager,
ExtensionManager *extension_manager) {
use_concatenated_input_ = use_concatenated_input;
num_actions_ = component_spec.num_actions();
TF_RETURN_IF_ERROR(fixed_embedding_manager_.Reset(
component_spec, variable_store, network_state_manager));
TF_RETURN_IF_ERROR(linked_embedding_manager_.Reset(
component_spec, variable_store, network_state_manager));
concatenated_input_dim_ = SumEmbeddingDimensions(fixed_embedding_manager_) +
SumEmbeddingDimensions(linked_embedding_manager_);
if (use_concatenated_input_) {
// If there is <= 1 input embedding, then the concatenation is trivial and
// we don't need a local vector; see ConcatenateInput().
const size_t num_embeddings = fixed_embedding_manager_.num_embeddings() +
linked_embedding_manager_.num_embeddings();
if (num_embeddings > 1) {
TF_RETURN_IF_ERROR(network_state_manager->AddLocal(
concatenated_input_dim_, &concatenated_input_handle_));
}
// Check that all fixed features are embedded.
for (size_t i = 0; i < fixed_embedding_manager_.num_channels(); ++i) {
if (!fixed_embedding_manager_.is_embedded(i)) {
return tensorflow::errors::InvalidArgument(
"Non-embedded fixed features cannot be concatenated");
}
}
}
extension_manager->GetShared(&fixed_embeddings_handle_);
extension_manager->GetShared(&linked_embeddings_handle_);
return tensorflow::Status::OK();
}
tensorflow::Status NetworkUnitBase::EvaluateBase(
SessionState *session_state, ComputeSession *compute_session,
Vector<float> *concatenated_input) const {
FixedEmbeddings &fixed_embeddings =
session_state->extensions.Get(fixed_embeddings_handle_);
LinkedEmbeddings &linked_embeddings =
session_state->extensions.Get(linked_embeddings_handle_);
TF_RETURN_IF_ERROR(fixed_embeddings.Reset(&fixed_embedding_manager_,
session_state->network_states,
compute_session));
TF_RETURN_IF_ERROR(linked_embeddings.Reset(&linked_embedding_manager_,
session_state->network_states,
compute_session));
if (use_concatenated_input_ && concatenated_input != nullptr) {
*concatenated_input = ConcatenateInput(session_state);
}
return tensorflow::Status::OK();
}
Vector<float> NetworkUnitBase::ConcatenateInput(
SessionState *session_state) const {
DCHECK(use_concatenated_input_);
const FixedEmbeddings &fixed_embeddings =
session_state->extensions.Get(fixed_embeddings_handle_);
const LinkedEmbeddings &linked_embeddings =
session_state->extensions.Get(linked_embeddings_handle_);
const size_t num_embeddings =
fixed_embeddings.num_embeddings() + linked_embeddings.num_embeddings();
// Special cases where no actual concatenation is required.
if (num_embeddings == 0) return {};
if (num_embeddings == 1) {
return fixed_embeddings.num_embeddings() > 0
? fixed_embeddings.embedding(0)
: linked_embeddings.embedding(0);
}
// General case; concatenate into a local vector. The ordering of embeddings
// must be exactly the same as in the Python codebase, which is:
// 1. Fixed embeddings before linked embeddings (see get_input_tensor() in
// network_units.py).
// 2. In each type, ordered as listed in ComponentSpec.fixed/linked_feature
// (see DynamicComponentBuilder._feedforward_unit() in component.py).
//
// Since FixedEmbeddings and LinkedEmbeddings already follow the order defined
// in the ComponentSpec, it suffices to append each fixed embedding, then each
// linked embedding.
const MutableVector<float> concatenation =
session_state->network_states.GetLocal(concatenated_input_handle_);
float *data = concatenation.data();
data = CopyEmbeddings(fixed_embeddings, data);
data = CopyEmbeddings(linked_embeddings, data);
DCHECK_EQ(data, concatenation.end());
return Vector<float>(concatenation);
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_NETWORK_UNIT_BASE_H_
#define DRAGNN_RUNTIME_NETWORK_UNIT_BASE_H_
#include <stddef.h>
#include <utility>
#include <vector>
#include "dragnn/core/compute_session.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/fixed_embeddings.h"
#include "dragnn/runtime/linked_embeddings.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/network_unit.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// A base class for network units that provides common functionality, analogous
// to NetworkUnitInterface.__init__() in network_units.py. Specifically, this
// class manages and builds input embeddings and, as a convenience, optionally
// concatenates the input embeddings into a single vector.
//
// Since recurrent layers are both outputs and inputs, they complicate network
// unit initialization. In particular, the linked embeddings cannot be set up
// until the characteristics of all recurrently-accessible layers are known. On
// the other hand, some layers cannot be initialized until all inputs, including
// the linked embeddings, are set up. For example, the IdentityNetwork outputs
// a layer whose dimension is the sum of all input dimensions.
//
// To accommodate recurrent layers, network unit initialization is organized
// into three phases:
// 1. (Subclass) Initialize all recurrently-accessible layers.
// 2. (This class) Initialize embedding managers and other common state.
// 3. (Subclass) Initialize any non-recurrent layers.
//
// Concretely, the subclass's Initialize() should first add recurrent layers,
// then call InitializeBase(), and finally finish initializing. Evaluation is
// simpler: the subclass's Evaluate() may call EvaluateBase() at any time.
//
// Note: Network unit initialization is similarly interleaved between base and
// subclasses in the Python codebase; see NetworkUnitInterface.get_layer_size()
// and the "init_layers" argument to NetworkUnitInterface.__init__().
class NetworkUnitBase : public NetworkUnit {
public:
// Initializes common state as configured in the |component_spec|. Retrieves
// pre-trained embedding matrices from the |variable_store|. Looks up linked
// embeddings in the |network_state_manager|, which must contain all recurrent
// layers. Requests any required extensions from the |extension_manager|. If
// |use_concatenated_input| is true, prepares to concatenate input embeddings
// in EvaluateBase(). On error, returns non-OK.
tensorflow::Status InitializeBase(bool use_concatenated_input,
const ComponentSpec &component_spec,
VariableStore *variable_store,
NetworkStateManager *network_state_manager,
ExtensionManager *extension_manager);
// Resets the fixed and linked embeddings in the |session_state| using its
// network states and the |compute_session|. Requires that InitializeBase()
// was called. If this was prepared for concatenation (see InitializeBase())
// and if |concatenated_input| is non-null, points it at the concatenation of
// the fixed and linked embeddings. Otherwise, no concatenation occurs. On
// error, returns non-OK.
tensorflow::Status EvaluateBase(SessionState *session_state,
ComputeSession *compute_session,
Vector<float> *concatenated_input) const;
// Accessors. All require that InitializeBase() was called.
const FixedEmbeddingManager &fixed_embedding_manager() const;
const LinkedEmbeddingManager &linked_embedding_manager() const;
size_t num_actions() const { return num_actions_; }
size_t concatenated_input_dim() const { return concatenated_input_dim_; }
private:
// Returns the concatenation of the fixed and linked embeddings in the
  // |session_state|. Requires that |use_concatenated_input_| is true.
Vector<float> ConcatenateInput(SessionState *session_state) const;
// Managers for fixed and linked embeddings in this component.
FixedEmbeddingManager fixed_embedding_manager_;
LinkedEmbeddingManager linked_embedding_manager_;
// Fixed and linked embeddings.
SharedExtensionHandle<FixedEmbeddings> fixed_embeddings_handle_;
SharedExtensionHandle<LinkedEmbeddings> linked_embeddings_handle_;
// Number of actions supported by the transition system.
size_t num_actions_ = 0;
// Sum of dimensions of all fixed and linked embeddings.
size_t concatenated_input_dim_ = 0;
// Whether to concatenate the input embeddings.
bool use_concatenated_input_ = false;
// Handle of the vector that holds the concatenated input, or invalid if no
// concatenation is required.
LocalVectorHandle<float> concatenated_input_handle_;
};
// Implementation details below.
inline const FixedEmbeddingManager &NetworkUnitBase::fixed_embedding_manager()
const {
return fixed_embedding_manager_;
}
inline const LinkedEmbeddingManager &NetworkUnitBase::linked_embedding_manager()
const {
return linked_embedding_manager_;
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_NETWORK_UNIT_BASE_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/network_unit_base.h"
#include <stddef.h>
#include <memory>
#include <string>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/fixed_embeddings.h"
#include "dragnn/runtime/linked_embeddings.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include <gmock/gmock.h>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
using ::testing::_;
using ::testing::Invoke;
using ::testing::Return;
// Dimensions of the layers in the network.
static constexpr size_t kPreviousDim = 77;
static constexpr size_t kRecurrentDim = 123;
// Contents of the layers in the network.
static constexpr float kPreviousValue = -2.75;
static constexpr float kRecurrentValue = 6.25;
// Number of steps taken in each component.
static constexpr size_t kNumSteps = 10;
// A trivial network unit that exposes the concatenated inputs. Note that
// NetworkUnitBase does not override the interface methods, so we need a
// concrete subclass for testing.
class FooNetwork : public NetworkUnitBase {
public:
void RequestConcatenation() { request_concatenation_ = true; }
void ProvideConcatenatedInput() { provide_concatenated_input_ = true; }
Vector<float> concatenated_input() const { return concatenated_input_; }
// Implements NetworkUnit.
tensorflow::Status Initialize(const ComponentSpec &component_spec,
VariableStore *variable_store,
NetworkStateManager *network_state_manager,
ExtensionManager *extension_manager) override {
TF_RETURN_IF_ERROR(network_state_manager->AddLayer(
"recurrent_layer", kRecurrentDim, &recurrent_handle_));
return InitializeBase(request_concatenation_, component_spec,
variable_store, network_state_manager,
extension_manager);
}
string GetLogitsName() const override { return ""; }
tensorflow::Status Evaluate(size_t unused_step_index,
SessionState *session_state,
ComputeSession *compute_session) const override {
return EvaluateBase(
session_state, compute_session,
provide_concatenated_input_ ? &concatenated_input_ : nullptr);
}
private:
bool request_concatenation_ = false;
bool provide_concatenated_input_ = false;
LayerHandle<float> recurrent_handle_;
mutable Vector<float> concatenated_input_; // Evaluate() sets this
};
class NetworkUnitBaseTest : public NetworkTestBase {
protected:
// Initializes the |network_unit_| based on the |component_spec_text| and
// evaluates it. On error, returns non-OK.
tensorflow::Status Run(const string &component_spec_text) {
ComponentSpec component_spec;
CHECK(TextFormat::ParseFromString(component_spec_text, &component_spec));
component_spec.set_name(kTestComponentName);
AddComponent("previous_component");
AddLayer("previous_layer", kPreviousDim);
AddComponent(kTestComponentName);
TF_RETURN_IF_ERROR(
network_unit_.Initialize(component_spec, &variable_store_,
&network_state_manager_, &extension_manager_));
// Create and populate the network states.
network_states_.Reset(&network_state_manager_);
StartComponent(kNumSteps);
StartComponent(kNumSteps);
FillLayer("previous_component", "previous_layer", kPreviousValue);
FillLayer(kTestComponentName, "recurrent_layer", kRecurrentValue);
session_state_.extensions.Reset(&extension_manager_);
    // Neither FooNetwork nor NetworkUnitBase looks at the step index, so use an
// arbitrary value.
return network_unit_.Evaluate(0, &session_state_, &compute_session_);
}
FooNetwork network_unit_;
std::vector<std::vector<float>> concatenated_inputs_;
};
// Tests that NetworkUnitBase produces an empty vector when concatenating and
// there are no input embeddings.
TEST_F(NetworkUnitBaseTest, ConcatenateNoInputs) {
network_unit_.RequestConcatenation();
network_unit_.ProvideConcatenatedInput();
TF_ASSERT_OK(Run(""));
EXPECT_EQ(network_unit_.fixed_embedding_manager().num_channels(), 0);
EXPECT_EQ(network_unit_.linked_embedding_manager().num_channels(), 0);
EXPECT_EQ(network_unit_.num_actions(), 0);
EXPECT_EQ(network_unit_.concatenated_input_dim(), 0);
EXPECT_TRUE(network_unit_.concatenated_input().empty());
}
// Tests that NetworkUnitBase produces a copy of the single input embedding when
// concatenating a single fixed channel.
TEST_F(NetworkUnitBaseTest, ConcatenateOneFixedChannel) {
const float kEmbedding = 1.5;
const float kFeature = 0.5;
const size_t kDim = 13;
const string kSpec = R"(num_actions: 42
fixed_feature {
vocabulary_size: 11
embedding_dim: 13
size: 1
})";
AddFixedEmbeddingMatrix(0, 11, kDim, kEmbedding);
EXPECT_CALL(compute_session_, GetInputFeatures(_, _, _, _, _))
.WillOnce(Invoke(ExtractFeatures(0, {{1, kFeature}})));
const float kValue = kEmbedding * kFeature;
network_unit_.RequestConcatenation();
network_unit_.ProvideConcatenatedInput();
TF_ASSERT_OK(Run(kSpec));
EXPECT_EQ(network_unit_.fixed_embedding_manager().num_channels(), 1);
EXPECT_EQ(network_unit_.linked_embedding_manager().num_channels(), 0);
EXPECT_EQ(network_unit_.num_actions(), 42);
EXPECT_EQ(network_unit_.concatenated_input_dim(), kDim);
ExpectVector(network_unit_.concatenated_input(),
network_unit_.concatenated_input_dim(), kValue);
}
// Tests that NetworkUnitBase does not concatenate if concatenation is requested
// and the concatenated input vector is not provided.
TEST_F(NetworkUnitBaseTest, ConcatenatedInputVectorNotProvided) {
const float kEmbedding = 1.5;
const float kFeature = 0.5;
const size_t kDim = 13;
const string kSpec = R"(num_actions: 37
fixed_feature {
vocabulary_size: 11
embedding_dim: 13
size: 1
})";
AddFixedEmbeddingMatrix(0, 11, kDim, kEmbedding);
EXPECT_CALL(compute_session_, GetInputFeatures(_, _, _, _, _))
.WillOnce(Invoke(ExtractFeatures(0, {{1, kFeature}})));
network_unit_.RequestConcatenation();
TF_ASSERT_OK(Run(kSpec));
  // Embedding managers and other config are set up properly.
EXPECT_EQ(network_unit_.fixed_embedding_manager().num_channels(), 1);
EXPECT_EQ(network_unit_.linked_embedding_manager().num_channels(), 0);
EXPECT_EQ(network_unit_.num_actions(), 37);
EXPECT_EQ(network_unit_.concatenated_input_dim(), kDim);
// But the concatenation was not performed.
EXPECT_TRUE(network_unit_.concatenated_input().empty());
}
// As above, but with the converse condition: does not request concatenation,
// but does provide the concatenated input vector.
TEST_F(NetworkUnitBaseTest, ConcatenationNotRequested) {
const float kEmbedding = 1.5;
const float kFeature = 0.5;
const size_t kDim = 13;
const string kSpec = R"(num_actions: 31
fixed_feature {
vocabulary_size: 11
embedding_dim: 13
size: 1
})";
AddFixedEmbeddingMatrix(0, 11, kDim, kEmbedding);
EXPECT_CALL(compute_session_, GetInputFeatures(_, _, _, _, _))
.WillOnce(Invoke(ExtractFeatures(0, {{1, kFeature}})));
network_unit_.ProvideConcatenatedInput();
TF_ASSERT_OK(Run(kSpec));
  // Embedding managers and other config are set up properly.
EXPECT_EQ(network_unit_.fixed_embedding_manager().num_channels(), 1);
EXPECT_EQ(network_unit_.linked_embedding_manager().num_channels(), 0);
EXPECT_EQ(network_unit_.num_actions(), 31);
EXPECT_EQ(network_unit_.concatenated_input_dim(), kDim);
// But the concatenation was not performed.
EXPECT_TRUE(network_unit_.concatenated_input().empty());
}
// Tests that NetworkUnitBase produces a copy of the single input embedding when
// concatenating a single linked channel.
TEST_F(NetworkUnitBaseTest, ConcatenateOneLinkedChannel) {
const string kSpec = R"(num_actions: 37
linked_feature {
embedding_dim: -1
source_component: 'previous_component'
source_layer: 'previous_layer'
size: 1
})";
EXPECT_CALL(compute_session_, GetTranslatedLinkFeatures(_, _))
.WillOnce(Invoke(ExtractLinks(0, {"step_idx: 5"})));
EXPECT_CALL(compute_session_, SourceComponentBeamSize(_, _))
.WillRepeatedly(Return(1));
network_unit_.RequestConcatenation();
network_unit_.ProvideConcatenatedInput();
TF_ASSERT_OK(Run(kSpec));
EXPECT_EQ(network_unit_.fixed_embedding_manager().num_channels(), 0);
EXPECT_EQ(network_unit_.linked_embedding_manager().num_channels(), 1);
EXPECT_EQ(network_unit_.num_actions(), 37);
EXPECT_EQ(network_unit_.concatenated_input_dim(), kPreviousDim);
ExpectVector(network_unit_.concatenated_input(),
network_unit_.concatenated_input_dim(), kPreviousValue);
}
// Tests that NetworkUnitBase concatenates a fixed and linked channel in that
// order.
TEST_F(NetworkUnitBaseTest, ConcatenateOneChannelOfEachType) {
const float kEmbedding = 1.25;
const float kFeature = 0.75;
const size_t kFixedDim = 13;
const string kSpec = R"(num_actions: 77
fixed_feature {
vocabulary_size: 11
embedding_dim: 13
size: 1
}
linked_feature {
embedding_dim: -1
source_component: 'previous_component'
source_layer: 'previous_layer'
size: 1
})";
AddFixedEmbeddingMatrix(0, 11, kFixedDim, kEmbedding);
EXPECT_CALL(compute_session_, GetInputFeatures(_, _, _, _, _))
.WillOnce(Invoke(ExtractFeatures(0, {{1, kFeature}})));
const float kFixedValue = kEmbedding * kFeature;
EXPECT_CALL(compute_session_, GetTranslatedLinkFeatures(_, _))
.WillOnce(Invoke(ExtractLinks(0, {"step_idx: 5"})));
EXPECT_CALL(compute_session_, SourceComponentBeamSize(_, _))
.WillRepeatedly(Return(1));
network_unit_.RequestConcatenation();
network_unit_.ProvideConcatenatedInput();
TF_ASSERT_OK(Run(kSpec));
EXPECT_EQ(network_unit_.fixed_embedding_manager().num_channels(), 1);
EXPECT_EQ(network_unit_.linked_embedding_manager().num_channels(), 1);
EXPECT_EQ(network_unit_.num_actions(), 77);
EXPECT_EQ(network_unit_.concatenated_input_dim(), kFixedDim + kPreviousDim);
// Check that each sub-segment is equal to one of the input embeddings.
const Vector<float> input = network_unit_.concatenated_input();
EXPECT_EQ(input.size(), network_unit_.concatenated_input_dim());
size_t index = 0;
size_t end = kFixedDim;
for (; index < end; ++index) EXPECT_EQ(input[index], kFixedValue);
end += kPreviousDim;
for (; index < end; ++index) EXPECT_EQ(input[index], kPreviousValue);
}
// Tests that NetworkUnitBase produces a properly-ordered concatenation of
// multiple fixed and linked channels, including a recurrent channel.
TEST_F(NetworkUnitBaseTest, ConcatenateMultipleChannelsOfEachType) {
const float kEmbedding0 = 1.25;
const float kEmbedding1 = -0.125;
const float kFeature0 = 0.75;
const float kFeature1 = -2.5;
const size_t kFixedDim0 = 13;
const size_t kFixedDim1 = 19;
const string kSpec = R"(num_actions: 99
fixed_feature {
vocabulary_size: 11
embedding_dim: 13
size: 1
}
fixed_feature {
vocabulary_size: 17
embedding_dim: 19
size: 1
}
linked_feature {
embedding_dim: -1
source_component: 'previous_component'
source_layer: 'previous_layer'
size: 1
}
linked_feature {
embedding_dim: -1
source_component: 'test_component'
source_layer: 'recurrent_layer'
size: 1
})";
AddFixedEmbeddingMatrix(0, 11, kFixedDim0, kEmbedding0);
AddFixedEmbeddingMatrix(1, 17, kFixedDim1, kEmbedding1);
EXPECT_CALL(compute_session_, GetInputFeatures(_, _, _, _, _))
.WillOnce(Invoke(ExtractFeatures(0, {{1, kFeature0}})))
.WillOnce(Invoke(ExtractFeatures(1, {{1, kFeature1}})));
const float kFixedValue0 = kEmbedding0 * kFeature0;
const float kFixedValue1 = kEmbedding1 * kFeature1;
EXPECT_CALL(compute_session_, GetTranslatedLinkFeatures(_, _))
.WillOnce(Invoke(ExtractLinks(0, {"step_idx: 5"})))
.WillOnce(Invoke(ExtractLinks(1, {"step_idx: 6"})));
EXPECT_CALL(compute_session_, SourceComponentBeamSize(_, _))
.WillRepeatedly(Return(1));
network_unit_.RequestConcatenation();
network_unit_.ProvideConcatenatedInput();
TF_ASSERT_OK(Run(kSpec));
EXPECT_EQ(network_unit_.fixed_embedding_manager().num_channels(), 2);
EXPECT_EQ(network_unit_.linked_embedding_manager().num_channels(), 2);
EXPECT_EQ(network_unit_.num_actions(), 99);
EXPECT_EQ(network_unit_.concatenated_input_dim(),
kFixedDim0 + kFixedDim1 + kPreviousDim + kRecurrentDim);
// Check that each sub-segment is equal to one of the input embeddings. For
// compatibility with the Python codebase, fixed channels must appear before
// linked channels, and among each type order follows the ComponentSpec.
const Vector<float> input = network_unit_.concatenated_input();
EXPECT_EQ(input.size(), network_unit_.concatenated_input_dim());
size_t index = 0;
size_t end = kFixedDim0;
for (; index < end; ++index) EXPECT_EQ(input[index], kFixedValue0);
end += kFixedDim1;
for (; index < end; ++index) EXPECT_EQ(input[index], kFixedValue1);
end += kPreviousDim;
for (; index < end; ++index) EXPECT_EQ(input[index], kPreviousValue);
end += kRecurrentDim;
for (; index < end; ++index) EXPECT_EQ(input[index], kRecurrentValue);
}
// Tests that NetworkUnitBase refuses to concatenate if there are non-embedded
// fixed features.
TEST_F(NetworkUnitBaseTest, CannotConcatenateNonEmbeddedFixedFeatures) {
const string kBadSpec = R"(fixed_feature {
embedding_dim: -1
size: 1
})";
network_unit_.RequestConcatenation();
network_unit_.ProvideConcatenatedInput();
EXPECT_THAT(Run(kBadSpec),
test::IsErrorWithSubstr(
"Non-embedded fixed features cannot be concatenated"));
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/network_unit.h"
#include <memory>
#include <string>
#include "dragnn/core/compute_session.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Expects that the two pointers have the same address.
void ExpectSameAddress(const void *pointer1, const void *pointer2) {
EXPECT_EQ(pointer1, pointer2);
}
// A trivial implementation for tests.
class FooNetwork : public NetworkUnit {
public:
// Implements NetworkUnit.
tensorflow::Status Initialize(const ComponentSpec &component_spec,
VariableStore *variable_store,
NetworkStateManager *network_state_manager,
ExtensionManager *extension_manager) override {
return tensorflow::Status::OK();
}
string GetLogitsName() const override { return "foo_logits"; }
tensorflow::Status Evaluate(size_t step_index, SessionState *session_state,
ComputeSession *compute_session) const override {
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_NETWORK_UNIT(FooNetwork);
// Tests that a human-friendly error is produced for empty network units.
TEST(NetworkUnitTest, GetClassNameDegenerateName) {
ComponentSpec component_spec;
EXPECT_DEATH(NetworkUnit::GetClassName(component_spec),
"No network unit name for component spec");
}
// Tests that NetworkUnit::GetClassName() resolves names properly.
TEST(NetworkUnitTest, GetClassName) {
for (const string &registered_name :
{"FooNetwork",
"module.FooNetwork",
"some.long.path.to.module.FooNetwork"}) {
ComponentSpec component_spec;
component_spec.mutable_network_unit()->set_registered_name(registered_name);
EXPECT_EQ(NetworkUnit::GetClassName(component_spec), "FooNetwork");
}
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/operands.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
OperandHandle OperandManager::Add(const OperandSpec &spec) {
const size_t index = specs_.size();
specs_.push_back(spec);
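  // Each case below records the mapping from the operand's handle index to
  // its index among operands of the same type. E.g., after adding
  // {kSingular, a}, {kStepwise, b}, {kSingular, c}, handle indices {0, 1, 2}
  // map to typed indices {0, 0, 1}.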
switch (spec.type) {
case OperandType::kSingular:
handle_index_to_typed_index_.push_back(singular_spans_.size());
singular_spans_.emplace_back(singular_size_, spec.size);
singular_size_ += PadToAlignment(spec.size);
break;
case OperandType::kStepwise:
handle_index_to_typed_index_.push_back(stepwise_spans_.size());
stepwise_spans_.emplace_back(stepwise_stride_, spec.size);
stepwise_stride_ += PadToAlignment(spec.size);
break;
case OperandType::kPairwise:
handle_index_to_typed_index_.push_back(pairwise_sizes_.size());
pairwise_sizes_.push_back(spec.size);
break;
}
return OperandHandle(index);
}
void Operands::Reset(const OperandManager *manager,
size_t pre_allocate_num_steps) {
manager_ = manager;
handle_index_to_typed_index_ = manager_->handle_index_to_typed_index_;
stepwise_spans_ = manager_->stepwise_spans_;
stepwise_stride_ = manager_->stepwise_stride_;
pairwise_sizes_ = manager_->pairwise_sizes_;
// Allocate and parcel out singular operands.
singular_operands_.clear();
singular_operands_.reserve(manager_->singular_spans_.size());
singular_array_.Reserve(manager_->singular_size_);
char *data = singular_array_.view().data();
for (const auto &span : manager_->singular_spans_) {
singular_operands_.push_back(
MutableAlignedView(data + span.first, span.second));
}
// Pre-allocate and parcel out stepwise operands.
stepwise_operands_.clear();
stepwise_operands_.reserve(stepwise_spans_.size());
stepwise_array_.Reserve(stepwise_stride_ * pre_allocate_num_steps);
data = stepwise_array_.view().data();
for (const auto &span : stepwise_spans_) {
stepwise_operands_.push_back(MutableAlignedArea(
data + span.first, 0, span.second, stepwise_stride_));
}
// Create empty pairwise operands.
pairwise_operands_.clear();
pairwise_operands_.resize(pairwise_sizes_.size());
}
void Operands::AddSteps(size_t num_steps) {
AddStepwiseSteps(num_steps);
AddPairwiseSteps(num_steps);
}
void Operands::AddStepwiseSteps(size_t num_steps) {
if (stepwise_operands_.empty()) return;
// Make room for the new steps.
const size_t new_num_views = stepwise_operands_[0].num_views_ + num_steps;
const bool actually_reallocated =
stepwise_array_.Resize(new_num_views * stepwise_stride_);
// Update the base pointers for stepwise operands, if changed.
if (actually_reallocated) {
char *data = stepwise_array_.view().data();
for (size_t i = 0; i < stepwise_operands_.size(); ++i) {
stepwise_operands_[i].data_ = data + stepwise_spans_[i].first;
}
}
// Update the number of views in each stepwise operand.
for (MutableAlignedArea &operand : stepwise_operands_) {
operand.num_views_ = new_num_views;
}
}
void Operands::AddPairwiseSteps(size_t num_steps) {
if (pairwise_operands_.empty()) return;
const size_t new_num_steps = pairwise_operands_[0].num_views_ + num_steps;
// Set dimensions for each pairwise operand and accumulate their total stride.
size_t new_stride = 0;
for (size_t i = 0; i < pairwise_operands_.size(); ++i) {
const size_t new_view_size = new_num_steps * pairwise_sizes_[i];
pairwise_operands_[i].num_views_ = new_num_steps;
pairwise_operands_[i].view_size_ = new_view_size;
new_stride += PadToAlignment(new_view_size);
}
// Note that Reset() does not preserve the existing array and its contents.
// Although preserving existing data would be nice, it is complex because
// pairwise operands grow in both dimensions. In addition, users should be
// allocating pairwise operands in one shot for speed reasons, in which case
  // there is no existing data anyway.
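  // For example (an illustrative case matching the unit tests), with two
  // pairwise operands of sizes D1 and D2 and N total steps, each step's "row"
  // occupies PadToAlignment(N * D1) + PadToAlignment(N * D2) bytes, and the
  // array holds N such rows.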
pairwise_array_.Reset(new_num_steps * new_stride);
// Set the new base pointer and stride on each pairwise operand.
char *data = pairwise_array_.view().data();
for (MutableAlignedArea &operand : pairwise_operands_) {
operand.data_ = data;
operand.view_stride_ = new_stride;
data += PadToAlignment(operand.view_size_);
}
DCHECK_EQ(data - pairwise_array_.view().data(), new_stride);
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Utils for declaring and allocating operands. An operand is made up of
// aligned byte arrays, and can be used as an input, output, or intermediate
// value in some computation.
#ifndef DRAGNN_RUNTIME_OPERANDS_H_
#define DRAGNN_RUNTIME_OPERANDS_H_
#include <stddef.h>
#include <stdint.h>
#include <string>
#include <utility>
#include <vector>
#include "dragnn/runtime/alignment.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Possible types of operands.
enum class OperandType {
// A single byte array. For example, an intermediate value that is computed
// once per transition step. Since it is not an output, the same storage
// could be reused across all steps.
kSingular,
// A sequence of identically-sized byte arrays, one per transition step. For
// example, a layer containing one activation vector per step.
kStepwise,
// A grid with one byte array for each pair of transition steps, including
// self pairings. The byte arrays are grouped and concatenated in "rows",
// forming one byte array per step. For example, if there are N steps and D
// bytes per pair, the operand would have N arrays of size N*D bytes. In a
// basic attention model with one "similarity" between pairs of steps, one
// might use a pairwise operand with D=sizeof(float). For best performance,
// use Operands::AddSteps() to allocate all steps at once when working with
// pairwise operands.
kPairwise,
};
// A specification of an operand.
struct OperandSpec {
// Creates a trivial specification.
OperandSpec() = default;
// Creates a specification with the |type| and |size|.
OperandSpec(OperandType type, size_t size) : type(type), size(size) {}
// Type of the operand.
OperandType type = OperandType::kSingular;
// Size of each aligned byte array in the operand.
size_t size = 0;
};
// An opaque handle to an operand.
class OperandHandle;
// A class that manages a set of operand specifications and associates each
// operand with a handle. Operand contents can be retrieved using these
// handles; see Operands below.
class OperandManager {
public:
// Creates an empty manager.
OperandManager() = default;
// Adds an operand configured according to the |spec| and returns its handle.
OperandHandle Add(const OperandSpec &spec);
// Accessors.
const OperandSpec &spec(OperandHandle handle) const;
private:
friend class Operands;
// Specification of each operand.
std::vector<OperandSpec> specs_;
// Mapping from the handle index of an operand to its index amongst operands
// of the same type.
std::vector<size_t> handle_index_to_typed_index_;
// Span of each singular operand, as a (start-offset,size) pair, relative to
// the byte array containing all singular operands.
std::vector<std::pair<size_t, size_t>> singular_spans_;
// Span of each stepwise operand, as a (start-offset,size) pair, relative to
// the byte array for each step.
std::vector<std::pair<size_t, size_t>> stepwise_spans_;
// Size of each pairwise operand.
std::vector<size_t> pairwise_sizes_;
// Number of bytes used by all singular operands, including alignment padding.
size_t singular_size_ = 0;
// Number of bytes used by all stepwise operands on each step, including
// alignment padding.
size_t stepwise_stride_ = 0;
};
// A set of operands. The structure of the operands is configured by an
// OperandManager, and operand contents can be accessed using the handles
// produced by the manager.
//
// Multiple Operands instances can share the same OperandManager. In addition,
// an Operands instance can be reused by repeatedly Reset()-ing it, potentially
// with different OperandManagers. Such reuse can reduce allocation overhead.
class Operands {
public:
// Creates an empty set.
Operands() = default;
// Resets this to the operands defined by the |manager|. The |manager| must
// live until this is destroyed or Reset() again, and should not be modified
  // during that time. Stepwise and pairwise operands start with 0 steps; use
  // AddStep() or AddSteps() to extend them. Pre-allocates stepwise operands so that they
// will not be reallocated during the first |pre_allocate_num_steps| calls to
// AddStep(). Invalidates all previously-returned operands.
void Reset(const OperandManager *manager, size_t pre_allocate_num_steps);
// Extends stepwise and pairwise operands by one or more steps. Requires that
// Reset() was called. Invalidates any previously-returned views of stepwise
// and pairwise operands. Preserves data for pre-existing steps of stepwise
// operands, but not for pre-existing pairwise operands. In general, pairwise
// operands should be allocated in one shot, not incrementally.
void AddStep() { AddSteps(1); }
void AddSteps(size_t num_steps);
// Returns the singular operand associated with the |handle|. The returned
// view is invalidated by Reset().
MutableAlignedView GetSingular(OperandHandle handle) const;
// Returns the stepwise operand associated with the |handle|. The returned
// area is invalidated by Reset() and AddStep().
MutableAlignedArea GetStepwise(OperandHandle handle) const;
// Returns the pairwise operand associated with the |handle|. The returned
// area is invalidated by Reset() and AddStep().
MutableAlignedArea GetPairwise(OperandHandle handle) const;
private:
// Extends stepwise operands only; see AddSteps().
void AddStepwiseSteps(size_t num_steps);
// Extends pairwise operands only; see AddSteps().
void AddPairwiseSteps(size_t num_steps);
// Manager of the operands in this set.
const OperandManager *manager_ = nullptr;
// Cached members from the |manager_|.
tensorflow::gtl::ArraySlice<size_t> handle_index_to_typed_index_;
tensorflow::gtl::ArraySlice<std::pair<size_t, size_t>> stepwise_spans_;
size_t stepwise_stride_ = 0;
tensorflow::gtl::ArraySlice<size_t> pairwise_sizes_;
// Byte arrays holding operands of each type. Storage is separated because
// each type grows differently with the number of steps.
UniqueAlignedArray singular_array_;
UniqueAlignedArray stepwise_array_;
UniqueAlignedArray pairwise_array_;
// Lists of operands of each type.
std::vector<MutableAlignedView> singular_operands_;
std::vector<MutableAlignedArea> stepwise_operands_;
std::vector<MutableAlignedArea> pairwise_operands_;
};
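// A minimal usage sketch (illustrative only; the names and sizes are invented,
// and MutableVector comes from math/types.h):
//
//   OperandManager manager;
//   const OperandHandle hidden =
//       manager.Add({OperandType::kStepwise, 64 * sizeof(float)});
//   Operands operands;
//   operands.Reset(&manager, /*pre_allocate_num_steps=*/32);
//   operands.AddStep();  // now step 0 exists
//   MutableVector<float> step0(operands.GetStepwise(hidden).view(0));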
// Implementation details below.
// An opaque handle to an operand.
class OperandHandle {
public:
// Creates an invalid handle.
OperandHandle() = default;
private:
friend class OperandManager;
friend class Operands;
// Creates a handle that points to the |index|.
explicit OperandHandle(size_t index) : index_(index) {}
// Index of the operand in its manager.
size_t index_ = SIZE_MAX;
};
inline const OperandSpec &OperandManager::spec(OperandHandle handle) const {
return specs_[handle.index_];
}
inline MutableAlignedView Operands::GetSingular(OperandHandle handle) const {
DCHECK(manager_->spec(handle).type == OperandType::kSingular)
<< "Actual type: " << static_cast<int>(manager_->spec(handle).type);
  DCHECK_LT(handle.index_, handle_index_to_typed_index_.size());
return singular_operands_[handle_index_to_typed_index_[handle.index_]];
}
inline MutableAlignedArea Operands::GetStepwise(OperandHandle handle) const {
DCHECK(manager_->spec(handle).type == OperandType::kStepwise)
<< "Actual type: " << static_cast<int>(manager_->spec(handle).type);
  DCHECK_LT(handle.index_, handle_index_to_typed_index_.size());
return stepwise_operands_[handle_index_to_typed_index_[handle.index_]];
}
inline MutableAlignedArea Operands::GetPairwise(OperandHandle handle) const {
DCHECK(manager_->spec(handle).type == OperandType::kPairwise)
<< "Actual type: " << static_cast<int>(manager_->spec(handle).type);
  DCHECK_LT(handle.index_, handle_index_to_typed_index_.size());
return pairwise_operands_[handle_index_to_typed_index_[handle.index_]];
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_OPERANDS_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/operands.h"
#include <string.h>
#include <tuple>
#include <utility>
#include <vector>
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/math/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Expects that the two pointers are the same.
void ExpectSameAddress(const void *pointer1, const void *pointer2) {
EXPECT_EQ(pointer1, pointer2);
}
// Sets the |vector| to |size| copies of the |value|.
template <class T>
void Fill(MutableVector<T> vector, size_t size, T value) {
ASSERT_EQ(vector.size(), size);
for (T &element : vector) element = value;
}
// Expects that the |vector| contains |size| copies of the |expected_value|.
template <class T>
void ExpectFilled(Vector<T> vector, size_t size, T expected_value) {
ASSERT_EQ(vector.size(), size);
for (const T element : vector) EXPECT_EQ(element, expected_value);
}
// Tests that OperandManager can add operands and remember their configuration.
TEST(OperandManagerTest, Add) {
OperandManager manager;
const OperandHandle handle1 = manager.Add({OperandType::kSingular, 7});
const OperandHandle handle2 = manager.Add({OperandType::kStepwise, 11});
EXPECT_EQ(manager.spec(handle1).type, OperandType::kSingular);
EXPECT_EQ(manager.spec(handle1).size, 7);
EXPECT_EQ(manager.spec(handle2).type, OperandType::kStepwise);
EXPECT_EQ(manager.spec(handle2).size, 11);
}
// Tests that Operands contains operands whose dimensions match its manager.
TEST(OperandsTest, Dimensions) {
const size_t kDim1 = 3, kDim2 = 41, kDim3 = 19, kDim4 = 77;
OperandManager manager;
const OperandHandle handle1 =
manager.Add({OperandType::kSingular, kDim1 * sizeof(float)});
const OperandHandle handle2 =
manager.Add({OperandType::kStepwise, kDim2 * sizeof(double)});
const OperandHandle handle3 =
manager.Add({OperandType::kSingular, kDim3 * sizeof(float)});
const OperandHandle handle4 =
manager.Add({OperandType::kStepwise, kDim4 * sizeof(int)});
AlignedView view;
AlignedArea area;
Operands operands;
operands.Reset(&manager, 10);
view = operands.GetSingular(handle1);
EXPECT_EQ(view.size(), kDim1 * sizeof(float));
EXPECT_EQ(Vector<float>(view).size(), kDim1);
area = operands.GetStepwise(handle2);
EXPECT_EQ(area.num_views(), 0); // no steps yet
EXPECT_EQ(area.view_size(), kDim2 * sizeof(double));
EXPECT_EQ(Matrix<double>(area).num_rows(), 0); // starts with no steps
EXPECT_EQ(Matrix<double>(area).num_columns(), kDim2);
view = operands.GetSingular(handle3);
EXPECT_EQ(view.size(), kDim3 * sizeof(float));
EXPECT_EQ(Vector<float>(view).size(), kDim3);
area = operands.GetStepwise(handle4);
EXPECT_EQ(area.num_views(), 0); // no steps yet
EXPECT_EQ(area.view_size(), kDim4 * sizeof(int));
EXPECT_EQ(Matrix<int>(area).num_rows(), 0); // starts with no steps
EXPECT_EQ(Matrix<int>(area).num_columns(), kDim4);
}
// Tests that Operands can incrementally extend stepwise operands while
// preserving existing values.
TEST(OperandsTest, AddStepToStepwise) {
const size_t kDim1 = 23, kDim2 = 29;
OperandManager manager;
const OperandHandle handle1 =
manager.Add({OperandType::kStepwise, kDim1 * sizeof(double)});
const OperandHandle handle2 =
manager.Add({OperandType::kStepwise, kDim2 * sizeof(int)});
Operands operands;
operands.Reset(&manager, 10);
// Repeatedly add a step and fill it with values.
for (int i = 0; i < 100; ++i) {
operands.AddStep();
Fill(MutableVector<double>(operands.GetStepwise(handle1).view(i)), kDim1,
1000.0 + i);
Fill(MutableVector<int>(operands.GetStepwise(handle2).view(i)), kDim2,
2000 + i);
}
// Check that data from earlier steps is preserved across reallocations.
for (int i = 0; i < 100; ++i) {
ExpectFilled(Vector<double>(operands.GetStepwise(handle1).view(i)), kDim1,
1000.0 + i);
ExpectFilled(Vector<int>(operands.GetStepwise(handle2).view(i)), kDim2,
2000 + i);
}
}
// Tests that Operands can add multiple steps at once.
TEST(OperandsTest, AddStepsToStepwise) {
const size_t kDim1 = 23, kDim2 = 29;
OperandManager manager;
const OperandHandle handle1 =
manager.Add({OperandType::kStepwise, kDim1 * sizeof(double)});
const OperandHandle handle2 =
manager.Add({OperandType::kStepwise, kDim2 * sizeof(int)});
Operands operands;
operands.Reset(&manager, 10);
// Repeatedly add blocks of steps and fill them with values.
for (int i = 0; i < 100; ++i) {
if (i % 10 == 0) operands.AddSteps(10); // occasionally add a block
Fill(MutableVector<double>(operands.GetStepwise(handle1).view(i)), kDim1,
1000.0 + i);
Fill(MutableVector<int>(operands.GetStepwise(handle2).view(i)), kDim2,
2000 + i);
}
// Check that data from earlier steps is preserved across reallocations.
for (int i = 0; i < 100; ++i) {
ExpectFilled(Vector<double>(operands.GetStepwise(handle1).view(i)), kDim1,
1000.0 + i);
ExpectFilled(Vector<int>(operands.GetStepwise(handle2).view(i)), kDim2,
2000 + i);
}
}
// Tests that Operands can add multiple steps to a pairwise operand.
TEST(OperandsTest, AddStepsPairwise) {
const size_t kDim1 = 4, kDim2 = 31;
OperandManager manager;
const OperandHandle handle1 = manager.Add({OperandType::kPairwise, kDim1});
const OperandHandle handle2 = manager.Add({OperandType::kPairwise, kDim2});
Operands operands;
operands.Reset(&manager, 10);
{ // A 1x1 pairwise operand.
operands.AddSteps(1);
const MutableAlignedArea area1 = operands.GetPairwise(handle1);
const MutableAlignedArea area2 = operands.GetPairwise(handle2);
EXPECT_EQ(area1.num_views(), 1);
EXPECT_EQ(area2.num_views(), 1);
EXPECT_EQ(area1.view_size(), kDim1);
EXPECT_EQ(area2.view_size(), kDim2);
// Write to operands to test the validity of the underlying memory region.
memset(area1.view(0).data(), 0, kDim1);
memset(area2.view(0).data(), 0, kDim2);
}
{ // A 10x10 pairwise operand.
operands.AddSteps(9);
const MutableAlignedArea area1 = operands.GetPairwise(handle1);
const MutableAlignedArea area2 = operands.GetPairwise(handle2);
EXPECT_EQ(area1.num_views(), 10);
EXPECT_EQ(area2.num_views(), 10);
EXPECT_EQ(area1.view_size(), 10 * kDim1);
EXPECT_EQ(area2.view_size(), 10 * kDim2);
// Infer the stride by comparing pointers between consecutive views.
const size_t expected_stride =
PadToAlignment(10 * kDim1) + PadToAlignment(10 * kDim2);
EXPECT_EQ(area1.view(1).data() - area1.view(0).data(), expected_stride);
EXPECT_EQ(area2.view(1).data() - area2.view(0).data(), expected_stride);
// Write to operands to test the validity of the underlying memory region.
memset(area1.view(9).data(), 0, 10 * kDim1);
memset(area2.view(9).data(), 0, 10 * kDim2);
}
}
// Tests that Operands can be reused by resetting them repeatedly, possibly
// switching between different managers.
TEST(OperandsTest, ResetWithDifferentManagers) {
std::vector<OperandManager> managers;
std::vector<std::tuple<OperandHandle, OperandHandle, OperandHandle>> handles;
for (int dim = 0; dim < 10; ++dim) {
managers.emplace_back();
handles.emplace_back(
managers.back().Add({OperandType::kSingular, dim * sizeof(double)}),
managers.back().Add({OperandType::kStepwise, dim * sizeof(int)}),
managers.back().Add({OperandType::kPairwise, dim * sizeof(float)}));
}
Operands operands;
for (int trial = 0; trial < 10; ++trial) {
for (int dim = 0; dim < 10; ++dim) {
operands.Reset(&managers[dim], 10);
const OperandHandle singular_handle = std::get<0>(handles[dim]);
const OperandHandle stepwise_handle = std::get<1>(handles[dim]);
const OperandHandle pairwise_handle = std::get<2>(handles[dim]);
// Fill the singular operand.
Fill(MutableVector<double>(operands.GetSingular(singular_handle)), dim,
100.0 * trial + dim);
// Check the singular operands.
ExpectFilled(Vector<double>(operands.GetSingular(singular_handle)), dim,
100.0 * trial + dim);
// Repeatedly add a step and fill it with values.
for (int step = 0; step < 100; ++step) {
operands.AddStep();
Fill(MutableVector<int>(
operands.GetStepwise(stepwise_handle).view(step)),
dim, 1000 * trial + 100 * dim + step);
}
// Check that data from earlier steps is preserved across reallocations.
for (int step = 0; step < 100; ++step) {
ExpectFilled(
Vector<int>(operands.GetStepwise(stepwise_handle).view(step)), dim,
1000 * trial + 100 * dim + step);
}
// Check the dimensions of pairwise operands.
Matrix<float> pairwise(operands.GetPairwise(pairwise_handle));
EXPECT_EQ(pairwise.num_rows(), 100);
EXPECT_EQ(pairwise.num_columns(), 100 * dim);
}
}
}
// Tests that one OperandManager can be shared simultaneously between multiple
// Operands instances.
TEST(OperandsTest, SharedManager) {
const size_t kDim = 17;
OperandManager manager;
const OperandHandle singular_handle =
manager.Add({OperandType::kSingular, kDim * sizeof(double)});
const OperandHandle stepwise_handle =
manager.Add({OperandType::kStepwise, kDim * sizeof(int)});
std::vector<Operands> operands_vec(10);
for (Operands &operands : operands_vec) operands.Reset(&manager, 10);
// Fill all singular operands.
for (int trial = 0; trial < operands_vec.size(); ++trial) {
const Operands &operands = operands_vec[trial];
Fill(MutableVector<double>(operands.GetSingular(singular_handle)), kDim,
3.0 * trial);
}
// Check all singular operands.
for (int trial = 0; trial < operands_vec.size(); ++trial) {
const Operands &operands = operands_vec[trial];
ExpectFilled(Vector<double>(operands.GetSingular(singular_handle)), kDim,
3.0 * trial);
}
// Fill all stepwise operands. Interleave operations on the operands on each
// step, so all operands are "active" at the same time.
for (int step = 0; step < 100; ++step) {
for (int trial = 0; trial < 10; ++trial) {
Operands &operands = operands_vec[trial];
operands.AddStep();
Fill(MutableVector<int>(operands.GetStepwise(stepwise_handle).view(step)),
kDim, trial * 999 + step);
}
}
// Check all stepwise operands.
for (int step = 0; step < 100; ++step) {
for (int trial = 0; trial < 10; ++trial) {
const Operands &operands = operands_vec[trial];
ExpectFilled(
Vector<int>(operands.GetStepwise(stepwise_handle).view(step)), kDim,
trial * 999 + step);
}
}
}
// Tests that an Operands uses all of the pre-allocated steps and reallocates
// exactly when it exhausts the pre-allocated array.
TEST(OperandsTest, UsesPreAllocatedSteps) {
const size_t kBytes = 5;
const size_t kPreAllocateNumSteps = 10;
OperandManager manager;
const OperandHandle handle = manager.Add({OperandType::kStepwise, kBytes});
Operands operands;
operands.Reset(&manager, kPreAllocateNumSteps);
// The first N steps fit exactly in the pre-allocated array. Access the base
// of the stepwise array via the first view.
operands.AddStep();
char *const pre_allocated_data = operands.GetStepwise(handle).view(0).data();
for (size_t step = 1; step < kPreAllocateNumSteps; ++step) {
operands.AddStep();
ASSERT_EQ(operands.GetStepwise(handle).view(0).data(), pre_allocated_data);
}
  // The (N+1)th step triggers a reallocation, which is guaranteed to yield a new
// pointer because it creates a separate array and copies into it.
operands.AddStep();
ASSERT_NE(operands.GetStepwise(handle).view(0).data(), pre_allocated_data);
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/sequence_linker.h"
#include "dragnn/runtime/transition_system_traits.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Links to the previous step in the same component. Templated on a bool that
// indicates the direction that the transition system runs in.
template <bool left_to_right>
class RecurrentSequenceLinker : public SequenceLinker {
public:
// Implements SequenceLinker.
bool Supports(const LinkedFeatureChannel &channel,
const ComponentSpec &component_spec) const override;
tensorflow::Status Initialize(const LinkedFeatureChannel &channel,
const ComponentSpec &component_spec) override;
tensorflow::Status GetLinks(size_t source_num_steps, InputBatchCache *input,
std::vector<int32> *links) const override;
};
template <bool left_to_right>
bool RecurrentSequenceLinker<left_to_right>::Supports(
const LinkedFeatureChannel &channel,
const ComponentSpec &component_spec) const {
TransitionSystemTraits traits(component_spec);
// Here, fml="bias" and source_translator="history" are a DRAGNN recipe for
// linking to the previous transition step. More concretely,
// * "bias" always extracts index 0.
// * "history" subtracts the index it is given from (#steps - 1).
// Putting the two together, we link to (#steps - 1 - 0); i.e., the previous
// transition step.
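  // For example, with 10 steps a left-to-right component produces links
  // {-1, 0, 1, ..., 8} (step i links to step i-1), while a right-to-left
  // component produces {8, 7, ..., 0, -1}.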
return (channel.fml() == "bias" || channel.fml() == "bias(0)") &&
channel.source_component() == component_spec.name() &&
channel.source_translator() == "history" &&
traits.is_left_to_right == left_to_right && traits.is_sequential;
}
template <bool left_to_right>
tensorflow::Status RecurrentSequenceLinker<left_to_right>::Initialize(
const LinkedFeatureChannel &channel, const ComponentSpec &component_spec) {
return tensorflow::Status::OK();
}
template <bool left_to_right>
tensorflow::Status RecurrentSequenceLinker<left_to_right>::GetLinks(
size_t source_num_steps, InputBatchCache *input,
std::vector<int32> *links) const {
links->resize(source_num_steps);
if (left_to_right) {
int32 index = -1;
for (int32 &link : *links) link = index++;
} else {
int32 index = static_cast<int32>(source_num_steps) - 1;
for (int32 &link : *links) link = --index;
}
return tensorflow::Status::OK();
}
using LeftToRightRecurrentSequenceLinker = RecurrentSequenceLinker<true>;
using RightToLeftRecurrentSequenceLinker = RecurrentSequenceLinker<false>;
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(LeftToRightRecurrentSequenceLinker);
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(RightToLeftRecurrentSequenceLinker);
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/sequence_linker.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Returns a ComponentSpec that the linker will support.
ComponentSpec MakeSupportedSpec() {
ComponentSpec component_spec;
component_spec.set_name("test_component");
component_spec.mutable_transition_system()->set_registered_name("shift-only");
LinkedFeatureChannel *channel = component_spec.add_linked_feature();
channel->set_fml("bias");
channel->set_source_component("test_component");
channel->set_source_translator("history");
return component_spec;
}
// Tests that the linker supports appropriate specs.
TEST(RecurrentSequenceLinkerTest, Supported) {
string name;
ComponentSpec component_spec = MakeSupportedSpec();
LinkedFeatureChannel &channel = *component_spec.mutable_linked_feature(0);
TF_ASSERT_OK(SequenceLinker::Select(channel, component_spec, &name));
EXPECT_EQ(name, "LeftToRightRecurrentSequenceLinker");
(*component_spec.mutable_transition_system()
->mutable_parameters())["left_to_right"] = "false";
TF_ASSERT_OK(SequenceLinker::Select(channel, component_spec, &name));
EXPECT_EQ(name, "RightToLeftRecurrentSequenceLinker");
channel.set_fml("bias(0)");
TF_ASSERT_OK(SequenceLinker::Select(channel, component_spec, &name));
EXPECT_EQ(name, "RightToLeftRecurrentSequenceLinker");
(*component_spec.mutable_transition_system()
->mutable_parameters())["left_to_right"] = "true";
TF_ASSERT_OK(SequenceLinker::Select(channel, component_spec, &name));
EXPECT_EQ(name, "LeftToRightRecurrentSequenceLinker");
}
// Tests that the linker requires the right transition system.
TEST(RecurrentSequenceLinkerTest, WrongTransitionSystem) {
string name;
ComponentSpec component_spec = MakeSupportedSpec();
const LinkedFeatureChannel &channel = component_spec.linked_feature(0);
component_spec.mutable_transition_system()->set_registered_name("bad");
EXPECT_THAT(SequenceLinker::Select(channel, component_spec, &name),
test::IsErrorWithSubstr("No SequenceLinker supports channel"));
}
// Tests that the linker requires the right FML.
TEST(RecurrentSequenceLinkerTest, WrongFml) {
string name;
ComponentSpec component_spec = MakeSupportedSpec();
LinkedFeatureChannel &channel = *component_spec.mutable_linked_feature(0);
channel.set_fml("bad");
EXPECT_THAT(SequenceLinker::Select(channel, component_spec, &name),
test::IsErrorWithSubstr("No SequenceLinker supports channel"));
}
// Tests that the linker requires a recurrent link.
TEST(RecurrentSequenceLinkerTest, WrongSource) {
string name;
ComponentSpec component_spec = MakeSupportedSpec();
LinkedFeatureChannel &channel = *component_spec.mutable_linked_feature(0);
channel.set_source_component("bad");
EXPECT_THAT(SequenceLinker::Select(channel, component_spec, &name),
test::IsErrorWithSubstr("No SequenceLinker supports channel"));
}
// Tests that the linker requires the right translator.
TEST(RecurrentSequenceLinkerTest, WrongTranslator) {
string name;
ComponentSpec component_spec = MakeSupportedSpec();
LinkedFeatureChannel &channel = *component_spec.mutable_linked_feature(0);
channel.set_source_translator("bad");
EXPECT_THAT(SequenceLinker::Select(channel, component_spec, &name),
test::IsErrorWithSubstr("No SequenceLinker supports channel"));
}
// Tests that the linker can be initialized and used to extract links.
TEST(RecurrentSequenceLinkerTest, InitializeAndGetLinks) {
const ComponentSpec component_spec = MakeSupportedSpec();
const LinkedFeatureChannel &channel = component_spec.linked_feature(0);
std::unique_ptr<SequenceLinker> linker;
TF_ASSERT_OK(SequenceLinker::New("LeftToRightRecurrentSequenceLinker",
channel, component_spec, &linker));
InputBatchCache input;
std::vector<int32> links = {123, 456, 789}; // gets overwritten
TF_ASSERT_OK(linker->GetLinks(10, &input, &links));
const std::vector<int32> expected_links = {-1, 0, 1, 2, 3, 4, 5, 6, 7, 8};
EXPECT_EQ(links, expected_links);
}
// Tests that the links are reversed for right-to-left components.
TEST(RecurrentSequenceLinkerTest, InitializeAndGetLinksRightToLeft) {
ComponentSpec component_spec = MakeSupportedSpec();
const LinkedFeatureChannel &channel = component_spec.linked_feature(0);
std::unique_ptr<SequenceLinker> linker;
TF_ASSERT_OK(SequenceLinker::New("RightToLeftRecurrentSequenceLinker",
channel, component_spec, &linker));
InputBatchCache input;
std::vector<int32> links = {123, 456, 789}; // gets overwritten
TF_ASSERT_OK(linker->GetLinks(10, &input, &links));
const std::vector<int32> expected_links = {8, 7, 6, 5, 4, 3, 2, 1, 0, -1};
EXPECT_EQ(links, expected_links);
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/sequence_linker.h"
#include "dragnn/runtime/transition_system_traits.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Applies a reversed identity function.
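// For example, with 10 source steps the links are {9, 8, ..., 1, 0}: target
// step i draws from source step (#steps - 1 - i).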
class ReversedSequenceLinker : public SequenceLinker {
public:
// Implements SequenceLinker.
bool Supports(const LinkedFeatureChannel &channel,
const ComponentSpec &component_spec) const override;
tensorflow::Status Initialize(const LinkedFeatureChannel &channel,
const ComponentSpec &component_spec) override;
tensorflow::Status GetLinks(size_t source_num_steps, InputBatchCache *input,
std::vector<int32> *links) const override;
};
bool ReversedSequenceLinker::Supports(
const LinkedFeatureChannel &channel,
const ComponentSpec &component_spec) const {
TransitionSystemTraits traits(component_spec);
// Note: Add more "||" clauses as needed.
return ((channel.fml() == "input.focus" &&
channel.source_translator() == "reverse-token") ||
(channel.fml() == "char-input.focus" &&
channel.source_translator() == "reverse-char")) &&
traits.is_sequential;
}
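// Nothing to initialize; applicability is fully determined by Supports().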
tensorflow::Status ReversedSequenceLinker::Initialize(
const LinkedFeatureChannel &channel, const ComponentSpec &component_spec) {
return tensorflow::Status::OK();
}
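// Fills |links| with a reversed identity mapping: step |i| of this component
// links to step |source_num_steps - 1 - i| of the source. For example,
// |source_num_steps| == 4 yields links {3, 2, 1, 0}.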
tensorflow::Status ReversedSequenceLinker::GetLinks(
size_t source_num_steps, InputBatchCache *input,
std::vector<int32> *links) const {
links->resize(source_num_steps);
int32 index = links->size();
for (int32 &link : *links) link = --index;
return tensorflow::Status::OK();
}
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(ReversedSequenceLinker);
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/sequence_linker.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Returns a ComponentSpec that the linker will support.
ComponentSpec MakeSupportedSpec() {
ComponentSpec component_spec;
component_spec.mutable_transition_system()->set_registered_name("shift-only");
LinkedFeatureChannel *channel = component_spec.add_linked_feature();
channel->set_fml("input.focus");
channel->set_source_translator("reverse-token");
return component_spec;
}
// Tests that the linker supports appropriate specs.
TEST(ReversedSequenceLinkerTest, Supported) {
string name;
ComponentSpec component_spec = MakeSupportedSpec();
LinkedFeatureChannel &channel = *component_spec.mutable_linked_feature(0);
TF_ASSERT_OK(SequenceLinker::Select(channel, component_spec, &name));
EXPECT_EQ(name, "ReversedSequenceLinker");
channel.set_fml("char-input.focus");
channel.set_source_translator("reverse-char");
TF_ASSERT_OK(SequenceLinker::Select(channel, component_spec, &name));
EXPECT_EQ(name, "ReversedSequenceLinker");
}
// Tests that the linker requires the right transition system.
TEST(ReversedSequenceLinkerTest, WrongTransitionSystem) {
string name;
ComponentSpec component_spec = MakeSupportedSpec();
const LinkedFeatureChannel &channel = component_spec.linked_feature(0);
component_spec.mutable_transition_system()->set_registered_name("bad");
EXPECT_THAT(SequenceLinker::Select(channel, component_spec, &name),
test::IsErrorWithSubstr("No SequenceLinker supports channel"));
}
// Tests that the linker requires the right FML.
TEST(ReversedSequenceLinkerTest, WrongFml) {
string name;
ComponentSpec component_spec = MakeSupportedSpec();
LinkedFeatureChannel &channel = *component_spec.mutable_linked_feature(0);
channel.set_fml("bad");
EXPECT_THAT(SequenceLinker::Select(channel, component_spec, &name),
test::IsErrorWithSubstr("No SequenceLinker supports channel"));
}
// Tests that the linker requires the right translator.
TEST(ReversedSequenceLinkerTest, WrongTranslator) {
string name;
ComponentSpec component_spec = MakeSupportedSpec();
LinkedFeatureChannel &channel = *component_spec.mutable_linked_feature(0);
channel.set_source_translator("bad");
EXPECT_THAT(SequenceLinker::Select(channel, component_spec, &name),
test::IsErrorWithSubstr("No SequenceLinker supports channel"));
}
// Tests that the linker requires the right combination of FML and translator.
TEST(ReversedSequenceLinkerTest, MismatchedFmlAndTranslator) {
string name;
ComponentSpec component_spec = MakeSupportedSpec();
LinkedFeatureChannel &channel = *component_spec.mutable_linked_feature(0);
channel.set_fml("input.focus");
channel.set_source_translator("reverse-char");
EXPECT_THAT(SequenceLinker::Select(channel, component_spec, &name),
test::IsErrorWithSubstr("No SequenceLinker supports channel"));
channel.set_fml("char-input.focus");
channel.set_source_translator("reverse-token");
EXPECT_THAT(SequenceLinker::Select(channel, component_spec, &name),
test::IsErrorWithSubstr("No SequenceLinker supports channel"));
}
// Tests that the linker can be initialized and used to extract links.
TEST(ReversedSequenceLinkerTest, InitializeAndGetLinks) {
const ComponentSpec component_spec = MakeSupportedSpec();
const LinkedFeatureChannel &channel = component_spec.linked_feature(0);
std::unique_ptr<SequenceLinker> linker;
TF_ASSERT_OK(SequenceLinker::New("ReversedSequenceLinker", channel,
component_spec, &linker));
InputBatchCache input;
std::vector<int32> links = {123, 456, 789}; // gets overwritten
TF_ASSERT_OK(linker->GetLinks(10, &input, &links));
const std::vector<int32> expected_links = {9, 8, 7, 6, 5, 4, 3, 2, 1, 0};
EXPECT_EQ(links, expected_links);
}
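// Additional check (a sketch, not in the original suite): an empty source
// sequence should yield an empty set of links. Uses only APIs exercised by
// the tests above.
TEST(ReversedSequenceLinkerTest, InitializeAndGetLinksEmpty) {
  const ComponentSpec component_spec = MakeSupportedSpec();
  const LinkedFeatureChannel &channel = component_spec.linked_feature(0);
  std::unique_ptr<SequenceLinker> linker;
  TF_ASSERT_OK(SequenceLinker::New("ReversedSequenceLinker", channel,
                                   component_spec, &linker));
  InputBatchCache input;
  std::vector<int32> links = {123, 456, 789};  // gets overwritten
  TF_ASSERT_OK(linker->GetLinks(0, &input, &links));
  EXPECT_TRUE(links.empty());
}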
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <string>
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/component_transformation.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Transformer that selects the best component subclass for the ComponentSpec.
class SelectBestComponentTransformer : public ComponentTransformer {
public:
// Implements ComponentTransformer.
tensorflow::Status Transform(const string &component_type,
ComponentSpec *component_spec) override {
string best_component_type;
TF_RETURN_IF_ERROR(
Component::Select(*component_spec, &best_component_type));
component_spec->mutable_component_builder()->set_registered_name(
best_component_type);
if (component_type != best_component_type) {
LOG(INFO) << "Component '" << component_spec->name()
<< "' builder updated from " << component_type << " to "
<< best_component_type << ".";
} else {
VLOG(2) << "Component '" << component_spec->name() << "' builder type "
<< component_type << " unchanged.";
}
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_COMPONENT_TRANSFORMER(SelectBestComponentTransformer);
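// A minimal usage sketch (hypothetical; applies the transformer directly
// rather than through the registry):
//   SelectBestComponentTransformer transformer;
//   ComponentSpec spec = ...;  // populated elsewhere
//   TF_CHECK_OK(transformer.Transform(
//       spec.component_builder().registered_name(), &spec));
//   // On success, the spec's component_builder names the best subclass.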
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet