Commit edea2b67 authored by Terry Koo's avatar Terry Koo
Browse files

Remove runtime because reasons.

parent a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Utils for extracting information from FML specifications.
#ifndef DRAGNN_RUNTIME_FML_PARSING_H_
#define DRAGNN_RUNTIME_FML_PARSING_H_
#include <string>
#include <vector>
#include "dragnn/runtime/attributes.h"
#include "syntaxnet/base.h"
#include "syntaxnet/feature_extractor.pb.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Attributes that can be parsed from a feature descriptor.
//
// Subclasses declare Optional<>/Mandatory<> attribute members registered with
// |this| (see attributes.h), then call Reset() to populate them from the
// name/value parameters of a FeatureFunctionDescriptor.
class FeatureFunctionAttributes : public Attributes {
 public:
  // Parses registered attributes from the parameters of the |function|. On
  // error, returns non-OK.
  tensorflow::Status Reset(const FeatureFunctionDescriptor &function);
};
// Parses the |fml| as a chain of nested features matching the |types|. All of
// the features must have no parameters, except the innermost, whose descriptor
// is set to |leaf|. On error, returns non-OK and modifies nothing.
//
// For example, fml="path.to.feature(foo=1)" matches types={"path", "to",
// "feature"}, and |leaf| receives the descriptor of "feature(foo=1)".
tensorflow::Status ParseFeatureChainFml(const string &fml,
                                        const std::vector<string> &types,
                                        FeatureFunctionDescriptor *leaf);
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_FML_PARSING_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/fml_parsing.h"
#include <string>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "syntaxnet/base.h"
#include "syntaxnet/feature_extractor.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Attributes for the test.
struct TestAttributes : public FeatureFunctionAttributes {
  // Optional attribute "foo" with default value -1.
  Optional<int32> foo{"foo", -1, this};
  // Attribute "bar"; presumably Reset() fails when it is absent — see
  // Mandatory<> in attributes.h.
  Mandatory<float> bar{"bar", this};
};
// Tests that attributes can be parsed from a valid feature descriptor.
TEST(FeatureFunctionAttributesTest, ValidDescriptor) {
  FeatureFunctionDescriptor descriptor;
  Parameter *param = descriptor.add_parameter();
  param->set_name("bar");
  param->set_value("1.75");

  TestAttributes attributes;
  TF_ASSERT_OK(attributes.Reset(descriptor));

  // "foo" keeps its default, while "bar" takes the parsed value.
  EXPECT_EQ(attributes.foo(), -1);
  EXPECT_EQ(attributes.bar(), 1.75);
}
// Tests that a feature chain can be parsed from valid FML, and the feature
// options can then be extracted as attributes.
TEST(ParseFeatureChainFmlTest, ValidFml) {
  FeatureFunctionDescriptor descriptor;
  TF_ASSERT_OK(ParseFeatureChainFml("path.to.feature(foo=123,bar=-0.5)",
                                    {"path", "to", "feature"}, &descriptor));

  // The options of the innermost feature become attributes.
  TestAttributes attributes;
  TF_ASSERT_OK(attributes.Reset(descriptor));
  EXPECT_EQ(attributes.foo(), 123);
  EXPECT_EQ(attributes.bar(), -0.5);
}
// Tests that an empty feature chain cannot be parsed.
TEST(ParseFeatureChainFmlTest, EmptyChain) {
  FeatureFunctionDescriptor descriptor;
  const tensorflow::Status status = ParseFeatureChainFml("foo", {}, &descriptor);
  EXPECT_THAT(status, test::IsErrorWithSubstr("Empty chain of feature types"));
}
// Tests that empty FML cannot be parsed as a chain.
TEST(ParseFeatureChainFmlTest, EmptyFml) {
  FeatureFunctionDescriptor descriptor;
  const tensorflow::Status status = ParseFeatureChainFml("", {"foo"}, &descriptor);
  EXPECT_THAT(status, test::IsErrorWithSubstr("Failed to parse feature chain"));
}
// Tests that feature chain parsing fails if the chain is too short.
TEST(ParseFeatureChainFmlTest, ChainTooShort) {
  FeatureFunctionDescriptor descriptor;
  // Three features in the FML, but only two expected types.
  const tensorflow::Status status =
      ParseFeatureChainFml("path.to.feature", {"path", "to"}, &descriptor);
  EXPECT_THAT(status, test::IsErrorWithSubstr("Failed to parse feature chain"));
}
// Tests that feature chain parsing fails if the chain is too long.
TEST(ParseFeatureChainFmlTest, ChainTooLong) {
  FeatureFunctionDescriptor descriptor;
  // Two features in the FML, but three expected types.
  const tensorflow::Status status =
      ParseFeatureChainFml("path.to", {"path", "to", "feature"}, &descriptor);
  EXPECT_THAT(status, test::IsErrorWithSubstr("Failed to parse feature chain"));
}
// Tests that initial elements of the chain must match the specified types.
TEST(ParseFeatureChainFmlTest, WrongTypeInPrefix) {
  FeatureFunctionDescriptor descriptor;
  const tensorflow::Status status = ParseFeatureChainFml(
      "path.to.feature", {"bad", "to", "feature"}, &descriptor);
  EXPECT_THAT(status, test::IsErrorWithSubstr("Failed to parse feature chain"));
}
// Tests that the last feature in the chain must match the specified type.
TEST(ParseFeatureChainFmlTest, WrongTypeInLeaf) {
  FeatureFunctionDescriptor descriptor;
  const tensorflow::Status status = ParseFeatureChainFml(
      "path.to.feature", {"path", "to", "bad"}, &descriptor);
  EXPECT_THAT(status, test::IsErrorWithSubstr("Failed to parse feature chain"));
}
// Tests that initial elements of the chain cannot have an argument.
TEST(ParseFeatureChainFmlTest, ArgumentInPrefix) {
  FeatureFunctionDescriptor descriptor;
  const tensorflow::Status status = ParseFeatureChainFml(
      "ok.bad(1).leaf", {"ok", "bad", "leaf"}, &descriptor);
  EXPECT_THAT(status, test::IsErrorWithSubstr("Failed to parse feature chain"));
}
// Tests that initial elements of the chain cannot have an option.
TEST(ParseFeatureChainFmlTest, OptionInPrefix) {
  FeatureFunctionDescriptor descriptor;
  const tensorflow::Status status = ParseFeatureChainFml(
      "ok.bad(foo=1).leaf", {"ok", "bad", "leaf"}, &descriptor);
  EXPECT_THAT(status, test::IsErrorWithSubstr("Failed to parse feature chain"));
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/head_selection_component_base.h"
#include <stddef.h>
#include <algorithm>
#include "tensorflow/core/lib/core/errors.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Records the builder and backend names that Supports() will accept.
HeadSelectionComponentBase::HeadSelectionComponentBase(const string &builder,
                                                       const string &backend)
    : builder_name_(builder), backend_name_(backend) {}
bool HeadSelectionComponentBase::Supports(
const ComponentSpec &component_spec,
const string &normalized_builder_name) const {
return (normalized_builder_name == "BulkAnnotatorComponent" ||
normalized_builder_name == builder_name_) &&
(component_spec.backend().registered_name() == "StatelessComponent" ||
component_spec.backend().registered_name() == backend_name_) &&
component_spec.transition_system().registered_name() == "heads" &&
component_spec.network_unit().registered_name() == "IdentityNetwork" &&
component_spec.fixed_feature_size() == 0 &&
component_spec.linked_feature_size() == 1;
}
tensorflow::Status HeadSelectionComponentBase::Initialize(
    const ComponentSpec &component_spec, VariableStore *variable_store,
    NetworkStateManager *network_state_manager,
    ExtensionManager *extension_manager) {
  // The single linked feature designates the adjacency matrix layer.
  const LinkedFeatureChannel &channel = component_spec.linked_feature(0);
  size_t adjacency_dim = 0;
  TF_RETURN_IF_ERROR(network_state_manager->LookupLayer(
      channel.source_component(), channel.source_layer(), &adjacency_dim,
      &adjacency_handle_));

  // Each row of a pairwise layer holds one score per source token, so the
  // layer itself must have dimension 1.
  if (adjacency_dim != 1) {
    return tensorflow::errors::InvalidArgument(
        "Adjacency matrix has dimension ", adjacency_dim, " but expected 1");
  }

  extension_manager->GetShared(&heads_handle_);
  return tensorflow::Status::OK();
}
// Returns the per-token argmax of each adjacency row as a head index, with -1
// for tokens whose best head is themselves (self-loops mark roots).
//
// Fix: the original compared a signed `int head` against the `size_t` loop
// index, an implicit signed/unsigned conversion. The index is now kept as
// size_t throughout and narrowed explicitly when stored.
const std::vector<int> &HeadSelectionComponentBase::ComputeHeads(
    SessionState *session_state) const {
  Matrix<float> adjacency(
      session_state->network_states.GetLayer(adjacency_handle_));
  std::vector<int> &heads = session_state->extensions.Get(heads_handle_);
  heads.resize(adjacency.num_rows());
  for (size_t i = 0; i < adjacency.num_rows(); ++i) {
    const Vector<float> row = adjacency.row(i);
    // Argmax over the row; max_element breaks ties toward the lowest index.
    const size_t head = std::max_element(row.begin(), row.end()) - row.begin();
    heads[i] = head != i ? static_cast<int>(head) : -1;  // self-loops are roots
  }
  return heads;
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_HEAD_SELECTION_COMPONENT_BASE_H_
#define DRAGNN_RUNTIME_HEAD_SELECTION_COMPONENT_BASE_H_
#include <vector>
#include "dragnn/core/compute_session.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Base class for head-selection components, which select heads independently
// per token. Although this process is not guaranteed to produce a tree, for
// accurate parsers it often produces a tree.
//
// This base class only computes the selected heads, while subclasses apply
// those heads to the annotations in the ComputeSession.
class HeadSelectionComponentBase : public Component {
 public:
  // Partially implements Component.
  //
  // Supports() additionally requires the "heads" transition system, the
  // "IdentityNetwork" network unit, no fixed features, and exactly one linked
  // feature (the adjacency matrix).
  bool Supports(const ComponentSpec &component_spec,
                const string &normalized_builder_name) const override;
  // Looks up the adjacency layer named by the linked feature and requires it
  // to have dimension 1.
  tensorflow::Status Initialize(const ComponentSpec &component_spec,
                                VariableStore *variable_store,
                                NetworkStateManager *network_state_manager,
                                ExtensionManager *extension_manager) override;
  // Never preferred; any competing component implementation wins.
  bool PreferredTo(const Component &other) const override { return false; }

 protected:
  // Creates a component that supports the |builder_name| and |backend_name|.
  HeadSelectionComponentBase(const string &builder_name,
                             const string &backend_name);

  // Returns the list of heads computed from the |session_state|, where -1
  // indicates a root.
  const std::vector<int> &ComputeHeads(SessionState *session_state) const;

 private:
  // Names of the supported component builder and backend.
  const string builder_name_;
  const string backend_name_;

  // Directed adjacency matrix input.
  PairwiseLayerHandle<float> adjacency_handle_;

  // List of selected head indices.
  SharedExtensionHandle<std::vector<int>> heads_handle_;
};
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_HEAD_SELECTION_COMPONENT_BASE_H_
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/head_selection_component_base.h"
#include <stddef.h>
#include <string>
#include <vector>
#include "dragnn/core/compute_session.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
constexpr size_t kNumSteps = 12;
constexpr size_t kRootIndex = 7; // the root and head of all other tokens
constexpr char kTestBuilder[] = "TestBuilder";
constexpr char kTestBackend[] = "TestBackend";
constexpr char kPreviousComponentName[] = "previous_component";
constexpr char kAdjacencyLayerName[] = "adjacency_layer";
constexpr char kBadDimLayerName[] = "bad_layer";
// A subclass for tests.
class BasicHeadSelectionComponent : public HeadSelectionComponentBase {
 public:
  BasicHeadSelectionComponent()
      : HeadSelectionComponentBase(kTestBuilder, kTestBackend) {}

  // Implements Component. This method is never called, but must be defined
  // so the class is not abstract.
  tensorflow::Status Evaluate(
      SessionState *session_state, ComputeSession *compute_session,
      ComponentTrace *component_trace) const override {
    return tensorflow::Status::OK();
  }

  // Publicizes the base class's method.
  using HeadSelectionComponentBase::ComputeHeads;
};
// Returns a ComponentSpec that works with the head selection component.
ComponentSpec MakeGoodSpec() {
  ComponentSpec spec;
  spec.mutable_component_builder()->set_registered_name(kTestBuilder);
  spec.mutable_backend()->set_registered_name(kTestBackend);
  spec.mutable_transition_system()->set_registered_name("heads");
  spec.mutable_network_unit()->set_registered_name("IdentityNetwork");

  // A single linked feature pointing at the adjacency layer of the previous
  // component.
  LinkedFeatureChannel *channel = spec.add_linked_feature();
  channel->set_source_component(kPreviousComponentName);
  channel->set_source_layer(kAdjacencyLayerName);
  return spec;
}
class HeadSelectionComponentBaseTest : public NetworkTestBase {
 protected:
  // Initializes a head selection component from the |component_spec| and sets
  // |heads| to the extracted head indices. Returns non-OK on error.
  tensorflow::Status Run(const ComponentSpec &component_spec,
                         std::vector<int> *heads) {
    AddComponent(kPreviousComponentName);
    AddPairwiseLayer(kAdjacencyLayerName, 1);
    // Also register a dimension-2 layer for the error-path test.
    AddPairwiseLayer(kBadDimLayerName, 2);
    BasicHeadSelectionComponent component;
    TF_RETURN_IF_ERROR(component.Initialize(component_spec, &variable_store_,
                                            &network_state_manager_,
                                            &extension_manager_));
    network_states_.Reset(&network_state_manager_);
    StartComponent(kNumSteps);
    // Fill the |kRootIndex|'th column of the adjacency matrix with higher
    // scores, so all tokens select it as head. The |kRootIndex|'th token
    // itself is a self-loop, so it becomes a root.
    MutableMatrix<float> adjacency =
        GetPairwiseLayer(kPreviousComponentName, kAdjacencyLayerName);
    for (size_t target = 0; target < kNumSteps; ++target) {
      for (size_t source = 0; source < kNumSteps; ++source) {
        adjacency.row(target)[source] = source == kRootIndex ? 1.0 : 0.0;
      }
    }
    session_state_.extensions.Reset(&extension_manager_);
    *heads = component.ComputeHeads(&session_state_);
    return tensorflow::Status::OK();
  }
};
// Tests that the expected heads are produced for a good spec.
TEST_F(HeadSelectionComponentBaseTest, RunsGoodSpec) {
  std::vector<int> heads;
  TF_ASSERT_OK(Run(MakeGoodSpec(), &heads));

  // Every token selects |kRootIndex| as its head, except the root itself.
  std::vector<int> expected(kNumSteps, kRootIndex);
  expected[kRootIndex] = -1;
  EXPECT_EQ(heads, expected);
}
// Tests that a layer with the wrong dimension is rejected.
TEST_F(HeadSelectionComponentBaseTest, WrongDimension) {
  ComponentSpec spec = MakeGoodSpec();
  spec.mutable_linked_feature(0)->set_source_layer(kBadDimLayerName);

  std::vector<int> heads;
  EXPECT_THAT(Run(spec, &heads),
              test::IsErrorWithSubstr(
                  "Adjacency matrix has dimension 2 but expected 1"));
}
// Tests that the component is always dis-preferred.
TEST_F(HeadSelectionComponentBaseTest, NotPreferred) {
  BasicHeadSelectionComponent component;
  // Even compared against itself, the component yields.
  EXPECT_FALSE(component.PreferredTo(component));
}
// Tests that the good spec is supported.
TEST_F(HeadSelectionComponentBaseTest, SupportsGoodSpec) {
  BasicHeadSelectionComponent component;
  EXPECT_TRUE(component.Supports(MakeGoodSpec(), kTestBuilder));
}
// Tests that various bad specs are rejected.
TEST_F(HeadSelectionComponentBaseTest, RejectsBadSpecs) {
  BasicHeadSelectionComponent component;

  // Wrong builder name.
  ComponentSpec spec = MakeGoodSpec();
  EXPECT_FALSE(component.Supports(spec, "bad"));

  // Wrong backend.
  spec = MakeGoodSpec();
  spec.mutable_backend()->set_registered_name("bad");
  EXPECT_FALSE(component.Supports(spec, kTestBuilder));

  // Wrong transition system.
  spec = MakeGoodSpec();
  spec.mutable_transition_system()->set_registered_name("bad");
  EXPECT_FALSE(component.Supports(spec, kTestBuilder));

  // Wrong network unit.
  spec = MakeGoodSpec();
  spec.mutable_network_unit()->set_registered_name("bad");
  EXPECT_FALSE(component.Supports(spec, kTestBuilder));

  // Fixed features are not allowed.
  spec = MakeGoodSpec();
  spec.add_fixed_feature();
  EXPECT_FALSE(component.Supports(spec, kTestBuilder));

  // Too many linked features.
  spec = MakeGoodSpec();
  spec.add_linked_feature();
  EXPECT_FALSE(component.Supports(spec, kTestBuilder));

  // Too few linked features.
  spec = MakeGoodSpec();
  spec.clear_linked_feature();
  EXPECT_FALSE(component.Supports(spec, kTestBuilder));
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <numeric>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/sequence_linker.h"
#include "dragnn/runtime/transition_system_traits.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Applies an identity function.
class IdentitySequenceLinker : public SequenceLinker {
 public:
  // Implements SequenceLinker.
  bool Supports(const LinkedFeatureChannel &channel,
                const ComponentSpec &component_spec) const override;
  // No-op; the identity mapping has no configuration.
  tensorflow::Status Initialize(const LinkedFeatureChannel &channel,
                                const ComponentSpec &component_spec) override;
  // Sets |links| to [0, 1, ..., source_num_steps - 1]; |input| is unused.
  tensorflow::Status GetLinks(size_t source_num_steps, InputBatchCache *input,
                              std::vector<int32> *links) const override;
};
bool IdentitySequenceLinker::Supports(
const LinkedFeatureChannel &channel,
const ComponentSpec &component_spec) const {
TransitionSystemTraits traits(component_spec);
// Note: Add more "||" clauses as needed.
return (channel.fml() == "input.focus" ||
channel.fml() == "char-input.focus") &&
channel.source_translator() == "identity" && traits.is_sequential;
}
tensorflow::Status IdentitySequenceLinker::Initialize(
    const LinkedFeatureChannel &channel, const ComponentSpec &component_spec) {
  // Nothing to configure; the identity mapping has no parameters.
  return tensorflow::Status::OK();
}
// Produces the identity mapping: step i links to source step i. The |input|
// is ignored because the links do not depend on its contents.
//
// Uses std::iota instead of a hand-rolled counter loop (requires <numeric>).
tensorflow::Status IdentitySequenceLinker::GetLinks(
    size_t source_num_steps, InputBatchCache *input,
    std::vector<int32> *links) const {
  links->resize(source_num_steps);
  std::iota(links->begin(), links->end(), 0);
  return tensorflow::Status::OK();
}

DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(IdentitySequenceLinker);
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/sequence_linker.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Returns a ComponentSpec that the linker will support.
ComponentSpec MakeSupportedSpec() {
  ComponentSpec spec;
  spec.mutable_transition_system()->set_registered_name("shift-only");

  // One linked channel with the FML and translator the linker recognizes.
  LinkedFeatureChannel *link = spec.add_linked_feature();
  link->set_fml("input.focus");
  link->set_source_translator("identity");
  return spec;
}
// Tests that the linker supports appropriate specs.
TEST(IdentitySequenceLinkerTest, Supported) {
  ComponentSpec spec = MakeSupportedSpec();
  LinkedFeatureChannel &link = *spec.mutable_linked_feature(0);

  // Both recognized FML strings should select the identity linker.
  for (const char *fml : {"input.focus", "char-input.focus"}) {
    link.set_fml(fml);
    string selected;
    TF_ASSERT_OK(SequenceLinker::Select(link, spec, &selected));
    EXPECT_EQ(selected, "IdentitySequenceLinker");
  }
}
// Tests that the linker requires the right transition system.
TEST(IdentitySequenceLinkerTest, WrongTransitionSystem) {
  ComponentSpec spec = MakeSupportedSpec();
  spec.mutable_transition_system()->set_registered_name("bad");

  string selected;
  EXPECT_THAT(
      SequenceLinker::Select(spec.linked_feature(0), spec, &selected),
      test::IsErrorWithSubstr("No SequenceLinker supports channel"));
}
// Tests that the linker requires the right FML.
TEST(IdentitySequenceLinkerTest, WrongFml) {
  ComponentSpec spec = MakeSupportedSpec();
  spec.mutable_linked_feature(0)->set_fml("bad");

  string selected;
  EXPECT_THAT(
      SequenceLinker::Select(spec.linked_feature(0), spec, &selected),
      test::IsErrorWithSubstr("No SequenceLinker supports channel"));
}
// Tests that the linker requires the right translator.
TEST(IdentitySequenceLinkerTest, WrongTranslator) {
  ComponentSpec spec = MakeSupportedSpec();
  spec.mutable_linked_feature(0)->set_source_translator("bad");

  string selected;
  EXPECT_THAT(
      SequenceLinker::Select(spec.linked_feature(0), spec, &selected),
      test::IsErrorWithSubstr("No SequenceLinker supports channel"));
}
// Tests that the linker can be initialized and used to extract links.
TEST(IdentitySequenceLinkerTest, InitializeAndGetLinks) {
  const ComponentSpec spec = MakeSupportedSpec();
  std::unique_ptr<SequenceLinker> linker;
  TF_ASSERT_OK(SequenceLinker::New("IdentitySequenceLinker",
                                   spec.linked_feature(0), spec, &linker));

  // Pre-populated contents should be discarded and overwritten.
  std::vector<int32> links = {123, 456, 789};
  InputBatchCache input;
  TF_ASSERT_OK(linker->GetLinks(10, &input, &links));

  const std::vector<int32> expected = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  EXPECT_EQ(links, expected);
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/linked_embeddings.h"
#include <string.h>
#include <algorithm>
#include <utility>
#include "dragnn/protos/data.pb.h"
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/math/arithmetic.h"
#include "tensorflow/core/lib/core/errors.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Returns the name of the weight matrix for the |channel_id|'th linked feature
// channel of the |component_spec|.
string LinkedWeightMatrixVariableName(const ComponentSpec &component_spec,
                                      int channel_id) {
  // Cf. _add_hooks_for_linked_embedding_matrix() in runtime_support.py.
  const string prefix = tensorflow::strings::StrCat(
      component_spec.name(), "/linked_embedding_matrix_", channel_id);
  return tensorflow::strings::StrCat(prefix, "/weights");
}
// As above, but for the out-of-bounds vector.
string LinkedOutOfBoundsVectorVariableName(const ComponentSpec &component_spec,
                                           int channel_id) {
  // Cf. _add_hooks_for_linked_embedding_matrix() in runtime_support.py.
  const string prefix = tensorflow::strings::StrCat(
      component_spec.name(), "/linked_embedding_matrix_", channel_id);
  return tensorflow::strings::StrCat(prefix, "/out_of_bounds");
}
} // namespace
// Validates and configures all linked feature channels of |component_spec|.
// All channel configs are built into a local vector first; member state is
// only modified at the very end, once every channel has validated, so a
// failed Reset() leaves the manager unchanged.
tensorflow::Status LinkedEmbeddingManager::Reset(
    const ComponentSpec &component_spec, VariableStore *variable_store,
    NetworkStateManager *network_state_manager) {
  const int num_channels = component_spec.linked_feature_size();
  std::vector<ChannelConfig> channel_configs(num_channels);
  size_t zeros_dimension = 0;  // required dimension for the shared zero vector
  for (int channel_id = 0; channel_id < num_channels; ++channel_id) {
    const LinkedFeatureChannel &channel_spec =
        component_spec.linked_feature(channel_id);
    ChannelConfig &channel_config = channel_configs[channel_id];
    // Only single-instance channels (size == 1) are supported.
    if (channel_spec.size() < 1) {
      return tensorflow::errors::InvalidArgument(
          "Invalid channel size for channel ", channel_id, ": ",
          channel_spec.ShortDebugString());
    }
    if (channel_spec.size() > 1) {
      return tensorflow::errors::Unimplemented(
          "Multi-instance linked features are not supported for channel ",
          channel_id, ": ", channel_spec.ShortDebugString());
    }
    // Locate the source layer whose activations this channel reads.
    size_t source_dimension = 0;
    TF_RETURN_IF_ERROR(network_state_manager->LookupLayer(
        channel_spec.source_component(), channel_spec.source_layer(),
        &source_dimension, &channel_config.source_handle));
    // A non-negative embedding_dim marks a "transformed" link, which
    // multiplies the source activations by a learned weight matrix.
    channel_config.is_transformed = channel_spec.embedding_dim() >= 0;
    if (!channel_config.is_transformed) {
      // Out-of-bounds direct links may be pointed at |zeros_|, so it must be
      // large enough for any direct link.
      channel_config.dimension = source_dimension;
      zeros_dimension = std::max(zeros_dimension, channel_config.dimension);
      continue;
    }
    // The remainder of this loop initializes transformed links.
    channel_config.dimension = channel_spec.embedding_dim();
    TF_RETURN_IF_ERROR(network_state_manager->AddLocal(
        channel_config.dimension, &channel_config.product_handle));
    const string debug_name = tensorflow::strings::StrCat(
        component_spec.name(), ".", channel_spec.name());
    TF_RETURN_IF_ERROR(channel_config.weight_matrix.Initialize(
        debug_name, LinkedWeightMatrixVariableName(component_spec, channel_id),
        channel_spec.embedding_dim(), variable_store));
    const FlexibleMatrixKernel &weights = channel_config.weight_matrix;
    Vector<float> &out_of_bounds_vector = channel_config.out_of_bounds_vector;
    TF_RETURN_IF_ERROR(variable_store->Lookup(
        LinkedOutOfBoundsVectorVariableName(component_spec, channel_id),
        &out_of_bounds_vector));
    // Check that the weights, out-of-bounds vector, and layer dimensions are
    // mutually consistent.
    if (weights.NumColumns() != source_dimension) {
      return tensorflow::errors::InvalidArgument(
          "Weight matrix does not match source layer in link ", channel_id,
          ": weights=[", weights.NumPaddedRows(), ", ", weights.NumColumns(),
          "] vs layer_dim=", source_dimension);
    }
    if (!weights.MatchesOutputDimension(channel_config.dimension)) {
      return tensorflow::errors::InvalidArgument(
          "Weight matrix shape should be output dimension plus padding. ",
          "Linked channel ID: ", channel_id, ": weights=[",
          weights.NumPaddedRows(), ", ", weights.NumColumns(),
          "] vs output=", channel_config.dimension);
    }
    if (out_of_bounds_vector.size() != channel_config.dimension) {
      return tensorflow::errors::InvalidArgument(
          "Out-of-bounds vector does not match embedding_dim in link ",
          channel_id, ": out_of_bounds=[", out_of_bounds_vector.size(),
          "] vs embedding_dim=", channel_config.dimension);
    }
  }
  // Success; make modifications.
  component_name_ = component_spec.name();
  channel_configs_ = std::move(channel_configs);
  // Allocate and zero the shared buffer used for out-of-bounds direct links.
  zeros_.Resize(zeros_dimension * sizeof(float));
  memset(zeros_.view().data(), 0, zeros_.view().size());
  return tensorflow::Status::OK();
}
tensorflow::Status LinkedEmbeddings::Reset(
    const LinkedEmbeddingManager *manager, const NetworkStates &network_states,
    ComputeSession *compute_session) {
  const int total_channels = manager->channel_configs_.size();
  channels_.resize(total_channels);
  for (int channel_id = 0; channel_id < total_channels; ++channel_id) {
    const std::vector<LinkFeatures> links =
        compute_session->GetTranslatedLinkFeatures(manager->component_name(),
                                                   channel_id);

    // The manager enforces LinkedFeatureChannel.size==1, so translation must
    // yield exactly one linked feature per channel.
    if (links.size() != 1) {
      return tensorflow::errors::Internal(
          "Got ", links.size(), " linked features; expected 1 for channel ",
          channel_id);
    }
    const LinkFeatures &link = links[0];

    // Only trivial batch and beam indices are supported.
    if (link.batch_idx() > 0) {
      return tensorflow::errors::Unimplemented(
          "Batches are not supported for channel ", channel_id);
    }
    if (link.beam_idx() > 0) {
      return tensorflow::errors::Unimplemented(
          "Beams are not supported for channel ", channel_id);
    }
    const int source_beam_size = compute_session->SourceComponentBeamSize(
        manager->component_name(), channel_id);
    if (source_beam_size != 1) {
      return tensorflow::errors::Unimplemented(
          "Source beams are not supported for channel ", channel_id);
    }

    // Mirror the behavior of the TF-based DRAGNN codebase:
    //   1. The ExtractLinkFeatures op in dragnn_op_kernels.cc substitutes -1
    //      for missing step indices, and clips all step indices to a min of
    //      -1.
    //   2. activation_lookup_*() in network_units.py adds +1 to step indices.
    //   3. Layer.create_array() in network_units.py starts each TensorArray
    //      with a zero vector.
    // Therefore, a direct link with a missing or negative step index should
    // receive a zeroed embedding.  Regarding transformed links:
    //   4. NetworkUnitInterface.__init__() in network_units.py extends the
    //      linked embedding matrix by 1 row.
    //   5. pass_through_embedding_matrix() in network_units.py extends each
    //      input activation vector with a 0/1 out-of-bounds indicator.
    // Multiplying the extended matrix with the extended vector produces:
    //   * If in-bounds: The product of the non-extended matrix and vector.
    //   * If out-of-bounds: The last row of the extended matrix.
    Channel &channel = channels_[channel_id];
    const LinkedEmbeddingManager::ChannelConfig &config =
        manager->channel_configs_[channel_id];
    channel.is_out_of_bounds = !link.has_step_idx() || link.step_idx() < 0;

    if (channel.is_out_of_bounds) {
      // Transformed links have a dedicated out-of-bounds embedding; direct
      // links point at a prefix of the shared zero vector.
      //
      // TODO(googleuser): Consider providing is_zero(channel_id) so ops on
      // zero vectors can be elided later on in the pipeline.  This would help
      // if out-of-bounds links are frequent.
      channel.embedding =
          config.is_transformed
              ? config.out_of_bounds_vector
              : Vector<float>(manager->zeros_.view(), config.dimension);
    } else {
      // Point at the activation vector of the translated step index.
      channel.embedding = network_states.GetLayer(config.source_handle)
                              .row(link.step_idx());
      if (config.is_transformed) {
        // Multiply with the weight matrix and point at the product.
        const MutableVector<float> product =
            network_states.GetLocal(config.product_handle);
        config.weight_matrix.MatrixVectorProduct(channel.embedding, product);
        channel.embedding = product;
      }
    }
    DCHECK_EQ(channel.embedding.size(), config.dimension);
  }
  return tensorflow::Status::OK();
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Utils for configuring and extracting linked embeddings.
//
// A linked embedding is a reference to an output layer produced by a source
// component. If the source component and receiving component are the same,
// then the link is recurrent.
//
// A linked embedding can be "direct" or "transformed". A direct link does not
// modify the source activation vectors, and maps an out-of-bounds access to a
// zero vector. A transformed link multiplies the source activation vectors by
// a weight matrix, and maps an out-of-bounds access to a special vector.
#ifndef DRAGNN_RUNTIME_LINKED_EMBEDDINGS_H_
#define DRAGNN_RUNTIME_LINKED_EMBEDDINGS_H_
#include <stddef.h>
#include <string>
#include <vector>
#include "dragnn/core/compute_session.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/flexible_matrix_kernel.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/variable_store.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// A class that manages a set of linked embeddings for some component. The
// embeddings can be extracted using LinkedEmbeddings, defined below.
class LinkedEmbeddingManager {
 public:
  // Creates an empty manager.
  LinkedEmbeddingManager() = default;

  // Resets this to the linked embeddings specified by the |component_spec|.
  // Retrieves transformation variables from the |variable_store|, which must
  // outlive this.  Looks up linked embeddings in the |network_state_manager|,
  // which must be positioned at the current component and must contain any
  // layers intended for recurrent access.  Also adds local operands to the
  // |network_state_manager|.  Channel ordering follows the |component_spec|.
  // On error, returns non-OK and does not modify this.
  tensorflow::Status Reset(const ComponentSpec &component_spec,
                           VariableStore *variable_store,
                           NetworkStateManager *network_state_manager);

  // Accessors.
  const string &component_name() const { return component_name_; }
  size_t num_channels() const { return channel_configs_.size(); }
  size_t embedding_dim(size_t channel_id) const;
  // Each channel yields exactly one embedding, so this equals num_channels().
  size_t num_embeddings() const { return num_channels(); }

 private:
  // These classes read the private channel configuration assembled by Reset().
  friend class LinkedEmbeddings;
  friend class SequenceLinkManager;

  // Configuration for a single linked embedding channel.  Several fields are
  // only used by transformed links.
  struct ChannelConfig {
    // Size of the embedding vectors in this channel.
    size_t dimension = 0;

    // Handle of the source layer containing the linked embedding.
    LayerHandle<float> source_handle;

    // Whether this is a transformed link.  The fields below are only populated
    // and used if this is true.
    bool is_transformed = false;

    // Weight matrix and out-of-bounds embedding vector for transformed links.
    FlexibleMatrixKernel weight_matrix;
    Vector<float> out_of_bounds_vector;

    // Handle of the local vector containing the product of the
    // |weight_matrix| and the source activation vector.
    LocalVectorHandle<float> product_handle;
  };

  // Name of the component receiving the linked embeddings.
  string component_name_;

  // Ordered list of configurations for each channel.
  std::vector<ChannelConfig> channel_configs_;

  // Array of zeros that can be substituted for any embedding vector, in the
  // case that the step index is out of range.  Only used by non-transformed
  // linked embeddings.
  UniqueAlignedArray zeros_;
};
// A set of linked embeddings, configured via the LinkedEmbeddingManager.
class LinkedEmbeddings {
 public:
  // Creates an empty set of embeddings.
  LinkedEmbeddings() = default;

  // Resets this to the embeddings managed by the |manager|.  Translates linked
  // features using the |compute_session| and retrieves embedding vectors from
  // the |network_states|, which must both be positioned at the component whose
  // embeddings are managed by the |manager|.  The |manager| must live until
  // this is destroyed or Reset(), and should not be modified during that time.
  // On error, returns non-OK.
  tensorflow::Status Reset(const LinkedEmbeddingManager *manager,
                           const NetworkStates &network_states,
                           ComputeSession *compute_session);

  // Accessors.  Note that embedding() returns a view that aliases storage
  // owned by the |manager| or the network states passed to Reset().
  size_t num_embeddings() const { return channels_.size(); }
  Vector<float> embedding(size_t channel_id) const;
  bool is_out_of_bounds(size_t channel_id) const;

 private:
  // Data associated with a single linked embedding channel.
  struct Channel {
    // Linked embedding vector for the channel.
    Vector<float> embedding;

    // Whether the embedding is out-of-bounds.
    bool is_out_of_bounds = false;
  };

  // Ordered list of linked embedding channels.
  std::vector<Channel> channels_;
};
// Implementation details below.
// Returns the embedding dimension of the channel at |channel_id|.
inline size_t LinkedEmbeddingManager::embedding_dim(size_t channel_id) const {
  return channel_configs_[channel_id].dimension;
}
// Returns the embedding vector extracted for the channel at |channel_id|.
inline Vector<float> LinkedEmbeddings::embedding(size_t channel_id) const {
  return channels_[channel_id].embedding;
}
// Returns true if the channel at |channel_id| had a missing or negative step
// index on the most recent Reset().
inline bool LinkedEmbeddings::is_out_of_bounds(size_t channel_id) const {
  return channels_[channel_id].is_out_of_bounds;
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_LINKED_EMBEDDINGS_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/linked_embeddings.h"
#include <string>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "syntaxnet/base.h"
#include <gmock/gmock.h>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
using ::testing::_;
using ::testing::Invoke;
using ::testing::Return;
// Dimensions of the layers in the network (see ResetManager() below).
const size_t kPrevious1LayerDim = 16;
const size_t kPrevious2LayerDim = 32;
const size_t kRecurrentLayerDim = 48;

// Dimensions of the transformed links in the network.
const size_t kPrevious2EmbeddingDim = 24;
const size_t kRecurrentEmbeddingDim = 40;

// Number of transition steps to take in each component in the network.
const size_t kNumSteps = 10;

// A working one-channel ComponentSpec.  embedding_dim=-1 denotes a direct
// (untransformed) link.
const char kSingleSpec[] = R"(linked_feature {
  embedding_dim: -1
  source_component: 'source_component_1'
  source_layer: 'previous_1'
  size: 1
})";

// A working multi-channel ComponentSpec.  Channel 0 is a direct link; channels
// 1 and 2 are transformed links, and channel 2 is recurrent (its source is the
// test component itself).
const char kMultiSpec[] = R"(linked_feature {
  embedding_dim: -1
  source_component: 'source_component_1'
  source_layer: 'previous_1'
  size: 1
}
linked_feature {
  embedding_dim: 24
  source_component: 'source_component_2'
  source_layer: 'previous_2'
  size: 1
}
linked_feature {
  embedding_dim: 40
  source_component: 'test_component'
  source_layer: 'recurrent'
  size: 1
})";
class LinkedEmbeddingManagerTest : public NetworkTestBase {
 protected:
  // Creates a LinkedEmbeddingManager and returns the result of Reset()-ing it
  // using the |component_spec_text|.  First builds a four-component pipeline:
  // three source components (the latter two with one layer each) followed by
  // the test component with a "recurrent" layer.
  tensorflow::Status ResetManager(const string &component_spec_text) {
    ComponentSpec component_spec;
    // Crash on a malformed spec; the specs are hard-coded in each test.
    CHECK(TextFormat::ParseFromString(component_spec_text, &component_spec));
    component_spec.set_name(kTestComponentName);

    AddComponent("source_component_0");  // deliberately has no layers
    AddComponent("source_component_1");
    AddLayer("previous_1", kPrevious1LayerDim);
    AddComponent("source_component_2");
    AddLayer("previous_2", kPrevious2LayerDim);
    AddComponent(kTestComponentName);
    AddLayer("recurrent", kRecurrentLayerDim);

    return manager_.Reset(component_spec, &variable_store_,
                          &network_state_manager_);
  }

  // The manager under test.
  LinkedEmbeddingManager manager_;
};
// Tests that LinkedEmbeddingManager is empty by default.
TEST_F(LinkedEmbeddingManagerTest, EmptyByDefault) {
  // A default-constructed manager has no channels and no embeddings.
  EXPECT_EQ(manager_.num_channels(), 0);
  EXPECT_EQ(manager_.num_embeddings(), 0);
}
// Tests that LinkedEmbeddingManager is empty when reset to an empty spec.
TEST_F(LinkedEmbeddingManagerTest, EmptySpec) {
  TF_EXPECT_OK(ResetManager(""));

  // The component name is recorded even when no channels are configured.
  EXPECT_EQ(manager_.component_name(), kTestComponentName);
  EXPECT_EQ(manager_.num_channels(), 0);
  EXPECT_EQ(manager_.num_embeddings(), 0);
}
// Tests that LinkedEmbeddingManager works with a single channel.
TEST_F(LinkedEmbeddingManagerTest, OneChannel) {
  TF_EXPECT_OK(ResetManager(kSingleSpec));

  EXPECT_EQ(manager_.component_name(), kTestComponentName);
  EXPECT_EQ(manager_.num_channels(), 1);
  // A direct link (embedding_dim=-1) takes its dimension from the source
  // layer.
  EXPECT_EQ(manager_.embedding_dim(0), kPrevious1LayerDim);
  EXPECT_EQ(manager_.num_embeddings(), 1);
}
// Tests that LinkedEmbeddingManager works with multiple channels.
TEST_F(LinkedEmbeddingManagerTest, MultipleChannels) {
  // Channels 1 and 2 are transformed links, so their weight matrices and
  // out-of-bounds vectors must exist in the variable store.
  AddLinkedWeightMatrix(1, kPrevious2LayerDim, kPrevious2EmbeddingDim, 0.0);
  AddLinkedWeightMatrix(2, kRecurrentLayerDim, kRecurrentEmbeddingDim, 0.0);
  AddLinkedOutOfBoundsVector(1, kPrevious2EmbeddingDim, 0.0);
  AddLinkedOutOfBoundsVector(2, kRecurrentEmbeddingDim, 0.0);
  TF_EXPECT_OK(ResetManager(kMultiSpec));

  EXPECT_EQ(manager_.component_name(), kTestComponentName);
  EXPECT_EQ(manager_.num_channels(), 3);
  // Channel 0 is direct; channels 1 and 2 use their configured embedding_dim.
  EXPECT_EQ(manager_.embedding_dim(0), kPrevious1LayerDim);
  EXPECT_EQ(manager_.embedding_dim(1), kPrevious2EmbeddingDim);
  EXPECT_EQ(manager_.embedding_dim(2), kRecurrentEmbeddingDim);
  EXPECT_EQ(manager_.num_embeddings(), 3);
}
// Tests that LinkedEmbeddingManager fails when the channel size is 0.
TEST_F(LinkedEmbeddingManagerTest, InvalidChannelSize) {
  // A channel must contain at least one linked feature.
  const string kBadSpec = R"(linked_feature {
    embedding_dim: -1
    source_component: 'source_component_1'
    source_layer: 'previous_1'
    size: 0 # bad
  })";

  EXPECT_THAT(ResetManager(kBadSpec),
              test::IsErrorWithSubstr("Invalid channel size"));
}
// Tests that LinkedEmbeddingManager fails when the channel size is > 1.
TEST_F(LinkedEmbeddingManagerTest, UnsupportedChannelSize) {
  // Channels with more than one linked feature are rejected by the runtime.
  const string kBadSpec = R"(linked_feature {
    embedding_dim: -1
    source_component: 'source_component_1'
    source_layer: 'previous_1'
    size: 2 # bad
  })";

  EXPECT_THAT(ResetManager(kBadSpec),
              test::IsErrorWithSubstr(
                  "Multi-instance linked features are not supported"));
}
// Tests that LinkedEmbeddingManager fails when the source component is unknown.
TEST_F(LinkedEmbeddingManagerTest, UnknownComponent) {
  // The source component must exist in the network (see ResetManager()).
  const string kBadSpec = R"(linked_feature {
    embedding_dim: -1
    source_component: 'missing_component' # bad
    source_layer: 'previous_1'
    size: 1
  })";

  EXPECT_THAT(ResetManager(kBadSpec),
              test::IsErrorWithSubstr("Unknown component"));
}
// Tests that LinkedEmbeddingManager fails when the source layer is unknown.
TEST_F(LinkedEmbeddingManagerTest, UnknownLayer) {
  // source_component_0 exists but has no layers, so the layer lookup fails.
  const string kBadSpec = R"(linked_feature {
    embedding_dim: -1
    source_component: 'source_component_0'
    source_layer: 'missing_layer' # bad
    size: 1
  })";

  EXPECT_THAT(ResetManager(kBadSpec),
              test::IsErrorWithSubstr("Unknown layer"));
}
// Tests that LinkedEmbeddingManager fails for a missing weight matrix.
TEST_F(LinkedEmbeddingManagerTest, MissingWeightMatrix) {
  // Only the weight matrix for channel 2 is missing.
  AddLinkedWeightMatrix(1, kPrevious2LayerDim, kPrevious2EmbeddingDim, 0.0);
  AddLinkedOutOfBoundsVector(1, kPrevious2EmbeddingDim, 0.0);
  AddLinkedOutOfBoundsVector(2, kRecurrentEmbeddingDim, 0.0);

  // Reset() fails when the variable store has no matching variable.
  EXPECT_THAT(ResetManager(kMultiSpec),
              test::IsErrorWithSubstr("Unknown variable"));
}
// Tests that LinkedEmbeddingManager fails for a missing out-of-bounds vector.
TEST_F(LinkedEmbeddingManagerTest, MissingOutOfBoundsVector) {
  // Only the out-of-bounds vector for channel 1 is missing.
  AddLinkedWeightMatrix(1, kPrevious2LayerDim, kPrevious2EmbeddingDim, 0.0);
  AddLinkedWeightMatrix(2, kRecurrentLayerDim, kRecurrentEmbeddingDim, 0.0);
  AddLinkedOutOfBoundsVector(2, kRecurrentEmbeddingDim, 0.0);

  // Reset() fails when the variable store has no matching variable.
  EXPECT_THAT(ResetManager(kMultiSpec),
              test::IsErrorWithSubstr("Unknown variable"));
}
// Tests that LinkedEmbeddingManager fails for a weight matrix with the wrong
// number of rows.
TEST_F(LinkedEmbeddingManagerTest, WeightMatrixRowMismatch) {
  // The channel 1 weight matrix has one row too many.
  AddLinkedWeightMatrix(1, kPrevious2LayerDim + 1, kPrevious2EmbeddingDim, 0.0);
  AddLinkedWeightMatrix(2, kRecurrentLayerDim, kRecurrentEmbeddingDim, 0.0);
  AddLinkedOutOfBoundsVector(1, kPrevious2EmbeddingDim, 0.0);
  AddLinkedOutOfBoundsVector(2, kRecurrentEmbeddingDim, 0.0);

  EXPECT_THAT(ResetManager(kMultiSpec),
              test::IsErrorWithSubstr(
                  "Weight matrix does not match source layer in link 1"));
}
// Tests that LinkedEmbeddingManager fails for a weight matrix with the wrong
// number of columns.
TEST_F(LinkedEmbeddingManagerTest, WeightMatrixColumnMismatch) {
  // The channel 2 weight matrix has one column too few.
  AddLinkedWeightMatrix(1, kPrevious2LayerDim, kPrevious2EmbeddingDim, 0.0);
  AddLinkedWeightMatrix(2, kRecurrentLayerDim, kRecurrentEmbeddingDim - 1, 0.0);
  AddLinkedOutOfBoundsVector(1, kPrevious2EmbeddingDim, 0.0);
  AddLinkedOutOfBoundsVector(2, kRecurrentEmbeddingDim, 0.0);

  EXPECT_THAT(ResetManager(kMultiSpec),
              test::IsErrorWithSubstr(
                  "Weight matrix shape should be output dimension plus "
                  "padding"));
}
// Tests that LinkedEmbeddingManager fails for an out-of-bounds vector with the
// wrong size.
TEST_F(LinkedEmbeddingManagerTest, OutOfBoundsVectorSizeMismatch) {
  // The channel 1 out-of-bounds vector is one element too long.
  AddLinkedWeightMatrix(1, kPrevious2LayerDim, kPrevious2EmbeddingDim, 0.0);
  AddLinkedWeightMatrix(2, kRecurrentLayerDim, kRecurrentEmbeddingDim, 0.0);
  AddLinkedOutOfBoundsVector(1, kPrevious2EmbeddingDim + 1, 0.0);
  AddLinkedOutOfBoundsVector(2, kRecurrentEmbeddingDim, 0.0);

  EXPECT_THAT(
      ResetManager(kMultiSpec),
      test::IsErrorWithSubstr(
          "Out-of-bounds vector does not match embedding_dim in link 1"));
}
// Values to fill each layer with, indexed in the order the layers are filled
// in ResetLinkedEmbeddings(): previous_1, previous_2, recurrent.
const float kLayerValues[] = {1.0, 2.0, 3.0};
class LinkedEmbeddingsTest : public LinkedEmbeddingManagerTest {
 protected:
  // Resets the |linked_embeddings_| using the |manager_|, |network_states_|,
  // and |compute_session_|, and returns the resulting status.
  tensorflow::Status ResetLinkedEmbeddings() {
    network_states_.Reset(&network_state_manager_);

    // Fill components with steps, in the same order they were added in
    // ResetManager().
    StartComponent(kNumSteps);  // source_component_0
    StartComponent(kNumSteps);  // source_component_1
    StartComponent(kNumSteps);  // source_component_2
    StartComponent(kNumSteps);  // current component

    // Fill layers with constant values (see kLayerValues above).
    FillLayer("source_component_1", "previous_1", kLayerValues[0]);
    FillLayer("source_component_2", "previous_2", kLayerValues[1]);
    FillLayer(kTestComponentName, "recurrent", kLayerValues[2]);

    return linked_embeddings_.Reset(&manager_, network_states_,
                                    &compute_session_);
  }

  // The embeddings under test.
  LinkedEmbeddings linked_embeddings_;
};
// Tests that LinkedEmbeddings is empty by default.
TEST_F(LinkedEmbeddingsTest, EmptyByDefault) {
  // A default-constructed LinkedEmbeddings holds no channels.
  EXPECT_EQ(linked_embeddings_.num_embeddings(), 0);
}
// Tests that LinkedEmbeddings is empty when reset by an empty manager.
TEST_F(LinkedEmbeddingsTest, EmptyManager) {
  TF_ASSERT_OK(ResetManager(""));
  TF_EXPECT_OK(ResetLinkedEmbeddings());

  // No channels were configured, so none are extracted.
  EXPECT_EQ(linked_embeddings_.num_embeddings(), 0);
}
// Tests that LinkedEmbeddings fails when no linked features are extracted.
TEST_F(LinkedEmbeddingsTest, OneChannelNoFeatures) {
  TF_ASSERT_OK(ResetManager(kSingleSpec));

  // Simulate a translation that returns no linked features at all.
  EXPECT_CALL(compute_session_, GetTranslatedLinkFeatures(_, _))
      .WillOnce(Invoke(ExtractLinks(0, {})));
  EXPECT_CALL(compute_session_, SourceComponentBeamSize(_, _))
      .WillRepeatedly(Return(1));

  EXPECT_THAT(ResetLinkedEmbeddings(),
              test::IsErrorWithSubstr("Got 0 linked features; expected 1"));
}
// Tests that LinkedEmbeddings works when exactly one linked feature is
// extracted.
TEST_F(LinkedEmbeddingsTest, OneChannelOneFeature) {
  TF_ASSERT_OK(ResetManager(kSingleSpec));

  EXPECT_CALL(compute_session_, GetTranslatedLinkFeatures(_, _))
      .WillOnce(Invoke(ExtractLinks(0, {"step_idx: 5"})));
  EXPECT_CALL(compute_session_, SourceComponentBeamSize(_, _))
      .WillRepeatedly(Return(1));

  TF_ASSERT_OK(ResetLinkedEmbeddings());
  ASSERT_EQ(linked_embeddings_.num_embeddings(), 1);
  // Direct link: the embedding is the "previous_1" layer row, filled with
  // kLayerValues[0] == 1.0.
  ExpectVector(linked_embeddings_.embedding(0), kPrevious1LayerDim, 1.0);
  EXPECT_FALSE(linked_embeddings_.is_out_of_bounds(0));
}
// Tests that LinkedEmbeddings fails when more than one linked feature is
// extracted.
TEST_F(LinkedEmbeddingsTest, OneChannelManyFeatures) {
  TF_ASSERT_OK(ResetManager(kSingleSpec));

  // Three features for a size-1 channel is an internal inconsistency.
  EXPECT_CALL(compute_session_, GetTranslatedLinkFeatures(_, _))
      .WillOnce(Invoke(
          ExtractLinks(0, {"step_idx: 5", "step_idx: 6", "step_idx: 7"})));
  EXPECT_CALL(compute_session_, SourceComponentBeamSize(_, _))
      .WillRepeatedly(Return(1));

  EXPECT_THAT(ResetLinkedEmbeddings(),
              test::IsErrorWithSubstr("Got 3 linked features; expected 1"));
}
// Tests that LinkedEmbeddings fails if the linked feature has a batch index.
TEST_F(LinkedEmbeddingsTest, BatchesUnsupported) {
  TF_ASSERT_OK(ResetManager(kSingleSpec));

  // A non-zero batch_idx triggers an Unimplemented error.
  EXPECT_CALL(compute_session_, GetTranslatedLinkFeatures(_, _))
      .WillOnce(Invoke(ExtractLinks(0, {"step_idx: 5 batch_idx: 1"})));
  EXPECT_CALL(compute_session_, SourceComponentBeamSize(_, _))
      .WillRepeatedly(Return(1));

  EXPECT_THAT(ResetLinkedEmbeddings(),
              test::IsErrorWithSubstr("Batches are not supported"));
}
// Tests that LinkedEmbeddings fails if the linked feature has a beam index.
TEST_F(LinkedEmbeddingsTest, BeamsUnsupported) {
  TF_ASSERT_OK(ResetManager(kSingleSpec));

  // A non-zero beam_idx triggers an Unimplemented error.
  EXPECT_CALL(compute_session_, GetTranslatedLinkFeatures(_, _))
      .WillOnce(Invoke(ExtractLinks(0, {"step_idx: 5 beam_idx: 1"})));
  EXPECT_CALL(compute_session_, SourceComponentBeamSize(_, _))
      .WillRepeatedly(Return(1));

  EXPECT_THAT(ResetLinkedEmbeddings(),
              test::IsErrorWithSubstr("Beams are not supported"));
}
// Tests that LinkedEmbeddings fails if the source component of the link has
// beam size > 1.
TEST_F(LinkedEmbeddingsTest, OneChannelWithSourceBeam) {
  TF_ASSERT_OK(ResetManager(kSingleSpec));

  EXPECT_CALL(compute_session_, GetTranslatedLinkFeatures(_, _))
      .WillOnce(Invoke(ExtractLinks(0, {"step_idx: 5"})));
  // A source component beam size other than 1 is rejected.
  EXPECT_CALL(compute_session_, SourceComponentBeamSize(_, _))
      .WillOnce(Return(2));

  EXPECT_THAT(ResetLinkedEmbeddings(),
              test::IsErrorWithSubstr("Source beams are not supported"));
}
// Tests that LinkedEmbeddings produces zeros when the extracted linked feature
// has no step index.
TEST_F(LinkedEmbeddingsTest, OneChannelNoStep) {
  TF_ASSERT_OK(ResetManager(kSingleSpec));

  // An empty LinkFeatures proto has no step_idx at all.
  EXPECT_CALL(compute_session_, GetTranslatedLinkFeatures(_, _))
      .WillOnce(Invoke(ExtractLinks(0, {""})));
  EXPECT_CALL(compute_session_, SourceComponentBeamSize(_, _))
      .WillRepeatedly(Return(1));

  TF_ASSERT_OK(ResetLinkedEmbeddings());
  ASSERT_EQ(linked_embeddings_.num_embeddings(), 1);
  // Direct out-of-bounds links resolve to the zero vector.
  ExpectVector(linked_embeddings_.embedding(0), kPrevious1LayerDim, 0.0);
  EXPECT_TRUE(linked_embeddings_.is_out_of_bounds(0));
}
// Tests that LinkedEmbeddings produces zeros when the extracted linked feature
// has step index -1.
TEST_F(LinkedEmbeddingsTest, OneChannelNegativeOneStep) {
  TF_ASSERT_OK(ResetManager(kSingleSpec));

  EXPECT_CALL(compute_session_, GetTranslatedLinkFeatures(_, _))
      .WillOnce(Invoke(ExtractLinks(0, {"step_idx: -1"})));
  EXPECT_CALL(compute_session_, SourceComponentBeamSize(_, _))
      .WillRepeatedly(Return(1));

  TF_ASSERT_OK(ResetLinkedEmbeddings());
  ASSERT_EQ(linked_embeddings_.num_embeddings(), 1);
  // A negative step index is out-of-bounds and maps to the zero vector.
  ExpectVector(linked_embeddings_.embedding(0), kPrevious1LayerDim, 0.0);
  EXPECT_TRUE(linked_embeddings_.is_out_of_bounds(0));
}
// Tests that LinkedEmbeddings produces zeros when the extracted linked feature
// has a large negative step index.
TEST_F(LinkedEmbeddingsTest, OneChannelLargeNegativeStep) {
  TF_ASSERT_OK(ResetManager(kSingleSpec));

  EXPECT_CALL(compute_session_, GetTranslatedLinkFeatures(_, _))
      .WillOnce(Invoke(ExtractLinks(0, {"step_idx: -100"})));
  EXPECT_CALL(compute_session_, SourceComponentBeamSize(_, _))
      .WillRepeatedly(Return(1));

  TF_ASSERT_OK(ResetLinkedEmbeddings());
  ASSERT_EQ(linked_embeddings_.num_embeddings(), 1);
  // Any negative step index, not just -1, maps to the zero vector.
  ExpectVector(linked_embeddings_.embedding(0), kPrevious1LayerDim, 0.0);
  EXPECT_TRUE(linked_embeddings_.is_out_of_bounds(0));
}
// Tests that LinkedEmbeddings works with multiple linked channels.
TEST_F(LinkedEmbeddingsTest, ManyChannels) {
  // Non-trivial weights so the transformed channels produce non-zero products.
  AddLinkedWeightMatrix(1, kPrevious2LayerDim, kPrevious2EmbeddingDim, 0.5);
  AddLinkedWeightMatrix(2, kRecurrentLayerDim, kRecurrentEmbeddingDim, 1.5);
  AddLinkedOutOfBoundsVector(1, kPrevious2EmbeddingDim, 5.5);
  AddLinkedOutOfBoundsVector(2, kRecurrentEmbeddingDim, 7.75);

  const size_t kEmbeddingDims[] = {kPrevious1LayerDim,      //
                                   kPrevious2EmbeddingDim,  //
                                   kRecurrentEmbeddingDim};
  // Each transformed channel is a matrix-vector product of constant-valued
  // operands, so every output element equals value * source_dim * weight.
  const float kExpected[] = {kLayerValues[0],                              //
                             kLayerValues[1] * kPrevious2LayerDim * 0.5f,  //
                             kLayerValues[2] * kRecurrentLayerDim * 1.5f};

  TF_ASSERT_OK(ResetManager(kMultiSpec));

  // One in-bounds feature per channel, extracted in channel order.
  EXPECT_CALL(compute_session_, GetTranslatedLinkFeatures(_, _))
      .WillOnce(Invoke(ExtractLinks(0, {"step_idx: 5"})))
      .WillOnce(Invoke(ExtractLinks(1, {"step_idx: 6"})))
      .WillOnce(Invoke(ExtractLinks(2, {"step_idx: 7"})));
  EXPECT_CALL(compute_session_, SourceComponentBeamSize(_, _))
      .Times(3)
      .WillRepeatedly(Return(1));

  TF_ASSERT_OK(ResetLinkedEmbeddings());
  ASSERT_EQ(linked_embeddings_.num_embeddings(), 3);
  for (int channel_id = 0; channel_id < linked_embeddings_.num_embeddings();
       ++channel_id) {
    ExpectVector(linked_embeddings_.embedding(channel_id),
                 kEmbeddingDims[channel_id], kExpected[channel_id]);
    EXPECT_FALSE(linked_embeddings_.is_out_of_bounds(channel_id));
  }
}
// Tests that LinkedEmbeddings produces the relevant out-of-bounds embeddings
// when multiple linked channels have invalid step indices.
TEST_F(LinkedEmbeddingsTest, ManyChannelsOutOfBounds) {
  AddLinkedWeightMatrix(1, kPrevious2LayerDim, kPrevious2EmbeddingDim, 0.5);
  AddLinkedWeightMatrix(2, kRecurrentLayerDim, kRecurrentEmbeddingDim, 1.5);
  AddLinkedOutOfBoundsVector(1, kPrevious2EmbeddingDim, 5.5);
  AddLinkedOutOfBoundsVector(2, kRecurrentEmbeddingDim, 7.75);

  const size_t kEmbeddingDims[] = {kPrevious1LayerDim,      //
                                   kPrevious2EmbeddingDim,  //
                                   kRecurrentEmbeddingDim};
  // The direct channel falls back to zeros; the transformed channels fall
  // back to their out-of-bounds vectors (5.5 and 7.75).
  const float kExpected[] = {0.0f, 5.5f, 7.75f};

  TF_ASSERT_OK(ResetManager(kMultiSpec));

  // Every channel reports a negative (out-of-bounds) step index.
  EXPECT_CALL(compute_session_, GetTranslatedLinkFeatures(_, _))
      .WillOnce(Invoke(ExtractLinks(0, {"step_idx: -1"})))
      .WillOnce(Invoke(ExtractLinks(1, {"step_idx: -10"})))
      .WillOnce(Invoke(ExtractLinks(2, {"step_idx: -999"})));
  EXPECT_CALL(compute_session_, SourceComponentBeamSize(_, _))
      .Times(3)
      .WillRepeatedly(Return(1));

  TF_ASSERT_OK(ResetLinkedEmbeddings());
  ASSERT_EQ(linked_embeddings_.num_embeddings(), 3);
  for (int channel_id = 0; channel_id < linked_embeddings_.num_embeddings();
       ++channel_id) {
    ExpectVector(linked_embeddings_.embedding(channel_id),
                 kEmbeddingDims[channel_id], kExpected[channel_id]);
    EXPECT_TRUE(linked_embeddings_.is_out_of_bounds(channel_id));
  }
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
# Contains functions related to implementing the LSTM cell. Split out into a
# folder because we will probably add different test harnesses, data, etc.

package(
    default_visibility = ["//visibility:public"],
)

load(
    "@org_tensorflow//tensorflow:tensorflow.bzl",
    "if_linux_x86_64",
)

load(
    "//dragnn/runtime:multiarch.bzl",
    "dragnn_cc_multiarch_library",
    "dragnn_cc_multiarch_test",
    "dragnn_cc_multiarch_binary",
)

# Aggressive floating-point copts, applied only on x86-64 Linux builds
# (if_linux_x86_64 presumably expands to an empty list elsewhere).
FAST_MATH_COPTS = if_linux_x86_64([
    # Note: Without masking, -O3 is significantly faster.
    "-O3",
    "-msse4.2",
    "-ffast-math",
])

# The LSTM cell computation kernel, built once per target architecture.
dragnn_cc_multiarch_library(
    name = "cell_function",
    srcs = ["cell_function.cc"],
    hdrs = ["cell_function.h"],
    copts = FAST_MATH_COPTS,
    opts_self = True,
    deps = [
        "//dragnn/runtime/math:avx_activation_functions",
        "//dragnn/runtime/math:avx_vector_array",
        "//dragnn/runtime/math:sgemvv",
        "//dragnn/runtime/math:types",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Shared helpers for the cell-function test and benchmark below.
dragnn_cc_multiarch_library(
    name = "test_helpers",
    testonly = 1,
    hdrs = ["test_helpers.h"],
    deps = [
        ":cell_function",
        "//dragnn/runtime/math:float16_types",
        "//dragnn/runtime/math:sgemvv",
        "//dragnn/runtime/test:helpers",
    ],
)

dragnn_cc_multiarch_test(
    name = "cell_function_test",
    srcs = ["cell_function_test.cc"],
    deps = [
        ":cell_function",
        ":test_helpers",
        "//dragnn/core/test:generic",
        "//dragnn/runtime/math:arithmetic",
        "//dragnn/runtime/math:transformations",
        "//dragnn/runtime/test:helpers",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

dragnn_cc_multiarch_binary(
    name = "cell_function_benchmark",
    testonly = 1,
    srcs = ["cell_function_benchmark.cc"],
    deps = [
        ":cell_function",
        ":test_helpers",
        "//dragnn/runtime/math:transformations",
        "//dragnn/runtime/test:helpers",
        "//syntaxnet:base",
    ],
)
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/lstm_cell/cell_function.h"
#if defined(__SSE2__)
#include <xmmintrin.h>
#endif
#include "dragnn/runtime/math/avx_activation_functions.h"
#include "tensorflow/core/lib/core/errors.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Prefetches the bytes of the |vector| into cache using the T1 locality hint,
// one cache line at a time.  No-op when SSE2 is unavailable.
template <class T>
void PrefetchVector(Vector<T> vector) {
#if defined(__SSE2__)
  // Number of elements per 64-byte x86 cache line.
  constexpr size_t kPrefetchStride = 64 / sizeof(T);
  // Use size_t for the index to match vector.size() and avoid a
  // signed/unsigned comparison (the original used int).
  for (size_t i = 0; i < vector.size(); i += kPrefetchStride) {
    _mm_prefetch(vector.data() + i, _MM_HINT_T1);
  }
#endif
}
// Calls the single-vector instance of SGEMV with output masking.  (See SGEMVV
// documentation for |lookahead_1| and |lookahead_2| semantics.)
template <int lookahead_1, int lookahead_2, class MatrixType>
void CellMatrixVector(const MatrixType &matrix, Vector<float> input,
                      Vector<float> initial, MutableVector<float> output) {
  // Wrap the single input/initial/output vectors as batches of size 1.
  SgemvInputBatch<1> inputs{{input.data()}, {initial.data()}};
  SgemvOutputBatch<1> outputs{{output.data()}};
  // The unmasked kernel is only used when the output size is a multiple of
  // the cell batch size; otherwise the masked variant handles the remainder.
  const bool use_optimized =
      output.size() % LstmCellFunction<>::kBatchSize == 0;
  if (use_optimized) {
    matrix.template MatrixMultiVectorProduct<1, lookahead_1, lookahead_2>(
        inputs, &outputs);
  } else {
    matrix.template MaskedMatrixMultiVectorProduct<1, lookahead_1, lookahead_2>(
        inputs, output.size(), &outputs);
  }
}
// Calls the single-vector instance of SGEMV with output masking, adding to an
// existing vector (partial sum).  (See SGEMVV documentation for |lookahead_1|
// and |lookahead_2| semantics.)
template <int lookahead_1, int lookahead_2, typename MatrixType>
void CellMatrixVector(const MatrixType &matrix, Vector<float> input,
                      MutableVector<float> initial_and_output) {
  // The accumulator serves as both the initial value and the output.
  CellMatrixVector<lookahead_1, lookahead_2>(
      matrix, input, Vector<float>(initial_and_output), initial_and_output);
}
// Internal helper function for applying an n-ary function element-wise to
// vectors.  We could make it more user-friendly by using a special type
// generator for `indices`, but by taking it explicitly the implementation is
// simpler.  Also, public API helpers are easier to interact with.
//
// |indices| select which entries of |inputs| are passed (as AVX vectors) to
// each invocation of |fcn|; the result is written to |output|.
template <int batch_size, class Function, int... indices>
void ApplyVariadic(const Function &fcn, int size,
                   Vector<float> inputs[sizeof...(indices)],
                   MutableVector<float> output) {
  for (int start = 0; start < size; start += batch_size) {
    // Number of whole AVX vectors remaining; passed to Load()/Store(),
    // apparently to bound accesses in the final partial batch (see
    // AvxFloatVecArray for the exact semantics).
    const int load_store_max_idx = (size - start) / kAvxWidth;
    AvxFloatVecArray<batch_size / kAvxWidth> arrays[sizeof...(indices)];
    for (int i = 0; i < sizeof...(indices); ++i) {
      // NOTE: This calls .data() to skip debug size checks; it is generally
      // OK to prefetch a bit too far ahead.
      // NOTE(review): unlike PrefetchVector() above, this _mm_prefetch is not
      // guarded by __SSE2__ -- presumably this file is only built for SSE/AVX
      // targets; confirm against the build flags.
      _mm_prefetch(&inputs[i].data()[start + batch_size], _MM_HINT_T0);
      arrays[i].Load(&inputs[i][start], load_store_max_idx);
    }
    for (int i = 0; i < batch_size / kAvxWidth; i++) {
      // We store the result to a random input cell. The choice of the first is
      // actually inconsequential; all we're going to do is write it out later.
      arrays[0].vectors[i] = fcn(arrays[indices].vectors[i]...);
    }
    arrays[0].Store(&output[start], load_store_max_idx);
  }
}
// Apply a unary function on one vector, modifying its contents.
template <int batch_size, class Function>
void ApplyUnary(const Function &fcn, MutableVector<float> vector) {
Vector<float> inputs[] = {Vector<float>(vector)};
ApplyVariadic<batch_size, Function, 0>(fcn, vector.size(), inputs, vector);
}
// Apply a binary function on two vectors, storing the result in a (possibly
// separate) output.  All vectors should have the same size as |result|.
template <int batch_size, class Function>
void ApplyBinary(const Function &fcn, Vector<float> arg_1, Vector<float> arg_2,
                 MutableVector<float> result) {
  Vector<float> inputs[] = {arg_1, arg_2};
  ApplyVariadic<batch_size, Function, 0, 1>(fcn, result.size(), inputs, result);
}
// Apply a trinary function on three vectors, storing the result in a (possibly
// separate) output.  All vectors should have the same size as |result|.
template <int batch_size, class Function>
void ApplyTrinary(const Function &fcn, Vector<float> arg_1, Vector<float> arg_2,
                  Vector<float> arg_3, MutableVector<float> result) {
  Vector<float> inputs[] = {arg_1, arg_2, arg_3};
  ApplyVariadic<batch_size, Function, 0, 1, 2>(fcn, result.size(), inputs,
                                               result);
}
// Initial-step cell state: c_0 = i_0 * tanh(partial sum).  There is no
// forget-gate term because there is no previous cell state.
AvxFloatVec InitialCellStateFunction(AvxFloatVec cell_input,
                                     AvxFloatVec cell_state_partial_sum) {
  const AvxFloatVec tanh_sum(activations::Tanh(cell_state_partial_sum));
  return AvxFloatVec(cell_input * tanh_sum);
}
// Recurrent cell state update with a coupled forget gate (f_t = 1 - i_t):
// c_t = (1 - i_t) * c_{t-1} + i_t * tanh(partial sum).
AvxFloatVec CellStateFunction(AvxFloatVec cell_input,
                              AvxFloatVec last_cell_state,
                              AvxFloatVec cell_state_partial_sum) {
  const AvxFloatVec input_times_tanh(
      cell_input * activations::Tanh(cell_state_partial_sum));
  return (AvxFloatVec::Const(1.0) - cell_input) * last_cell_state +
         input_times_tanh;
}
// Hidden state: h_t = o_t * tanh(c_t).
AvxFloatVec HiddenStateFunction(AvxFloatVec cell_output,
                                AvxFloatVec cell_state) {
  const AvxFloatVec tanh_state(activations::Tanh(cell_state));
  return AvxFloatVec(cell_output * tanh_state);
}
} // namespace
// Returns InvalidArgument from the enclosing function when the two size
// expressions are unequal.  Both argument expressions are stringized into the
// error message; the unit tests match these messages verbatim, so the exact
// expressions written at call sites are part of the observable behavior.
#define DRAGNN_RETURN_IF_NOT_EQUAL(actual_size, expected_size) \
  if ((actual_size) != (expected_size)) { \
    return tensorflow::errors::InvalidArgument( \
        "Vector/matrix size " #actual_size " (", (actual_size), \
        ") does not " \
        "match expected size " #expected_size " (", \
        (expected_size), ")"); \
  }
// Validates all parameter dimensions against |hidden_size|, then stores the
// parameters in the corresponding members.  Nothing is committed on error.
template <typename MatrixElementType>
tensorflow::Status LstmCellFunction<MatrixElementType>::Initialize(
    int hidden_size, Vector<float> cell_input_state_output_bias,
    SgemvMatrix<kBatchSize, MatrixElementType> input_to_cell_input_state_output,
    SgemvMatrix<kBatchSize, MatrixElementType>
        last_hidden_to_cell_input_state_output,
    SgemvMatrix<kBatchSize, MatrixElementType> last_cell_state_to_cell_input,
    SgemvMatrix<kBatchSize, MatrixElementType> cell_state_to_cell_output) {
  // The AVX code paths operate on whole AVX vectors of hidden units.
  if (hidden_size % kAvxWidth != 0) {
    return tensorflow::errors::InvalidArgument(
        "Expected hidden size (", hidden_size,
        ") to be a multiple of the AVX width (", kAvxWidth, ")");
  }
  // SGEMV matrices store a whole number of kBatchSize-row blocks, so expected
  // row counts must be rounded up the same way.
  auto pad_rows = [](size_t size) {
    return kBatchSize * ((size + kBatchSize - 1) / kBatchSize);
  };
  // The "ico" bias and matrices stack the cell input (i), cell state (c), and
  // cell output (o) components, hence the 3 * hidden_size rows.  NOTE: the
  // macro stringizes its argument expressions into the error message, and
  // tests match those messages verbatim -- do not rewrite the expressions.
  DRAGNN_RETURN_IF_NOT_EQUAL(cell_input_state_output_bias.size(),
                             3 * hidden_size);
  // The input matrix's column count (the input dimension) is not known here;
  // it is validated later, in RunInputComputations().
  DRAGNN_RETURN_IF_NOT_EQUAL(
      input_to_cell_input_state_output.matrix().num_rows(),
      pad_rows(3 * hidden_size));
  DRAGNN_RETURN_IF_NOT_EQUAL(
      last_hidden_to_cell_input_state_output.matrix().num_rows(),
      pad_rows(3 * hidden_size));
  DRAGNN_RETURN_IF_NOT_EQUAL(
      last_hidden_to_cell_input_state_output.matrix().num_columns(),
      hidden_size);
  DRAGNN_RETURN_IF_NOT_EQUAL(last_cell_state_to_cell_input.matrix().num_rows(),
                             pad_rows(hidden_size));
  DRAGNN_RETURN_IF_NOT_EQUAL(
      last_cell_state_to_cell_input.matrix().num_columns(), hidden_size);
  DRAGNN_RETURN_IF_NOT_EQUAL(cell_state_to_cell_output.matrix().num_rows(),
                             pad_rows(hidden_size));
  DRAGNN_RETURN_IF_NOT_EQUAL(cell_state_to_cell_output.matrix().num_columns(),
                             hidden_size);
  // All checks passed; commit the configuration.
  hidden_size_ = hidden_size;
  cell_input_state_output_bias_ = cell_input_state_output_bias;
  input_to_cell_input_state_output_ = input_to_cell_input_state_output;
  last_hidden_to_cell_input_state_output_ =
      last_hidden_to_cell_input_state_output;
  last_cell_state_to_cell_input_ = last_cell_state_to_cell_input;
  cell_state_to_cell_output_ = cell_state_to_cell_output;
  return tensorflow::Status::OK();
}
// Computes, for every row i, cell_input_temps.row(i) =
// bias + [x2i;x2c;x2o] * inputs.row(i).  Rows are processed in pairs so the
// SGEMVV kernel can share one pass over the matrix between two input vectors.
template <typename MatrixElementType>
tensorflow::Status LstmCellFunction<MatrixElementType>::RunInputComputations(
    const Matrix<float> inputs,
    const MutableMatrix<float> cell_input_temps) const {
  DRAGNN_RETURN_IF_NOT_EQUAL(inputs.num_rows(), cell_input_temps.num_rows());
  DRAGNN_RETURN_IF_NOT_EQUAL(
      inputs.num_columns(),
      input_to_cell_input_state_output_.matrix().num_columns());
  DRAGNN_RETURN_IF_NOT_EQUAL(cell_input_temps.num_columns(), 3 * hidden_size_);
  // When the output size is an exact multiple of kBatchSize the unmasked
  // kernel can be used; otherwise fall back to the masked variant, which is
  // given the true output size (3 * hidden_size_) explicitly.
  const bool use_optimized = (3 * hidden_size_) % kBatchSize == 0;
  // Pair each input with its neighbor, and run SGEMVV.
  SgemvInputBatch<2> sgemvv_inputs;
  SgemvOutputBatch<2> sgemvv_outputs;
  for (int i = 0; i + 1 < inputs.num_rows(); i += 2) {
    for (int op = 0; op < 2; ++op) {
      sgemvv_inputs.input[op] = inputs.row(i + op).data();
      // The bias serves as the initial value of each output accumulator.
      sgemvv_inputs.initial[op] = cell_input_state_output_bias_.data();
      sgemvv_outputs.output[op] = cell_input_temps.row(i + op).data();
    }
    if (use_optimized) {
      input_to_cell_input_state_output_
          .template MatrixMultiVectorProduct<2, 8, 8>(sgemvv_inputs,
                                                      &sgemvv_outputs);
    } else {
      input_to_cell_input_state_output_
          .template MaskedMatrixMultiVectorProduct<2, 8, 8>(
              sgemvv_inputs, 3 * hidden_size_, &sgemvv_outputs);
    }
  }
  // Odd-sized inputs need an additional SGEMV operation.
  if (inputs.num_rows() % 2 != 0) {
    const int i = inputs.num_rows() - 1;
    SgemvInputBatch<1> sgemvv_inputs;
    SgemvOutputBatch<1> sgemvv_outputs;
    sgemvv_inputs.input[0] = inputs.row(i).data();
    sgemvv_inputs.initial[0] = cell_input_state_output_bias_.data();
    sgemvv_outputs.output[0] = cell_input_temps.row(i).data();
    if (use_optimized) {
      input_to_cell_input_state_output_
          .template MatrixMultiVectorProduct<1, 8, 8>(sgemvv_inputs,
                                                      &sgemvv_outputs);
    } else {
      input_to_cell_input_state_output_
          .template MaskedMatrixMultiVectorProduct<1, 8, 8>(
              sgemvv_inputs, 3 * hidden_size_, &sgemvv_outputs);
    }
  }
  return tensorflow::Status::OK();
}
// Runs one recurrent step.  |cell_input_temp| holds the precomputed "ico"
// partial sums from RunInputComputations and is used as scratch: its three
// hidden_size_-sized thirds are the cell input (i), cell state partial sum
// (c), and cell output partial sum (o), respectively.
template <typename MatrixElementType>
tensorflow::Status LstmCellFunction<MatrixElementType>::RunRecurrentComputation(
    bool is_initial, Vector<float> last_hidden, Vector<float> last_cell_state,
    MutableVector<float> cell_input_temp, MutableVector<float> cell_state,
    MutableVector<float> cell_output, MutableVector<float> next_hidden) const {
  // Check input sizes.
  if (!is_initial) {
    // The recurrent inputs are only read after the first step, so they are
    // only validated then.
    DRAGNN_RETURN_IF_NOT_EQUAL(last_hidden.size(), hidden_size_);
    DRAGNN_RETURN_IF_NOT_EQUAL(last_cell_state.size(), hidden_size_);
  }
  DRAGNN_RETURN_IF_NOT_EQUAL(cell_input_temp.size(), 3 * hidden_size_);
  DRAGNN_RETURN_IF_NOT_EQUAL(cell_state.size(), hidden_size_);
  DRAGNN_RETURN_IF_NOT_EQUAL(cell_output.size(), hidden_size_);
  DRAGNN_RETURN_IF_NOT_EQUAL(next_hidden.size(), hidden_size_);
// The size checks above are the last uses of the macro in this file.
#undef DRAGNN_RETURN_IF_NOT_EQUAL
  // Views of the three hidden_size_-sized thirds of |cell_input_temp|.
  MutableVector<float> cell_input =
      cell_input_temp.Subsequence(0, hidden_size_);
  Vector<float> cell_state_partial_sum(
      cell_input_temp.Subsequence(hidden_size_, hidden_size_));
  Vector<float> cell_output_partial_sum(
      cell_input_temp.Subsequence(2 * hidden_size_, hidden_size_));
  if (!is_initial) {
    PrefetchVector(last_cell_state);
    // Add the recurrent contributions: the h_{t-1} term into all three thirds
    // and the c_{t-1} term into the cell input third.  NOTE(review): the
    // integer template arguments look like kernel tuning parameters -- confirm
    // their meaning against CellMatrixVector before changing them.
    CellMatrixVector<16, 0>(last_hidden_to_cell_input_state_output_,
                            last_hidden, cell_input_temp);
    CellMatrixVector<1, 0>(last_cell_state_to_cell_input_, last_cell_state,
                           cell_input);
  }
  // Squash the cell input partial sum into the gate value i_t.
  ApplyUnary<24>(activations::Sigmoid, cell_input);
  // Computes cell state,
  //
  // $c_t = f_t \cdot c_{t-1} + i_t \cdot tanh([x2c] x_t + [h2c] h_{t-1} + b_c)$
  //
  // where $f_t = 1 - i_t$.
  if (is_initial) {
    ApplyBinary<32>(InitialCellStateFunction, Vector<float>(cell_input),
                    cell_state_partial_sum, cell_state);
  } else {
    ApplyTrinary<16>(CellStateFunction, Vector<float>(cell_input),
                     last_cell_state, cell_state_partial_sum, cell_state);
  }
  // Computes cell output,
  //
  // $o_t = \sigma([x2o] x_t + [h2o] h_{t-1} + [c2o] c_t + b_o)$
  //
  // where all but the $c_t$ component of the affine transformation have already
  // been computed by the composite "ico" matrices above.
  CellMatrixVector<0, 0>(cell_state_to_cell_output_, Vector<float>(cell_state),
                         cell_output_partial_sum, cell_output);
  ApplyUnary<24>(activations::Sigmoid, cell_output);
  // Computes the hidden state,
  //
  // $h_t = o_t \cdot tanh(c_t)$
  ApplyBinary<16>(HiddenStateFunction, Vector<float>(cell_output),
                  Vector<float>(cell_state), next_hidden);
  return tensorflow::Status::OK();
}
// Estimates the floating-point operations per call to Run().
template <typename MatrixElementType>
double LstmCellFunction<MatrixElementType>::FlopsPerRun(bool is_initial) const {
  // A matrix-vector product over an R x C matrix costs roughly 2 * R * C flops
  // (one multiply and one add per matrix element).
  const auto matrix_flops = [](const MatrixType &sgemv) {
    return 2.0 * sgemv.matrix().num_rows() * sgemv.matrix().num_columns();
  };
  // Matrix products performed on every run.
  double total = matrix_flops(input_to_cell_input_state_output_) +
                 matrix_flops(cell_state_to_cell_output_);
  // The recurrent matrix products are skipped on the initial run.
  if (!is_initial) {
    total += matrix_flops(last_hidden_to_cell_input_state_output_);
    total += matrix_flops(last_cell_state_to_cell_input_);
  }
  // Element-wise activation calculations.
  total += (26 +  // i_t sigmoid
            26 +  // c_t tanh (23) plus 3 more
            26 +  // o_t sigmoid
            24    // h_t tanh and multiplication
            ) *
           hidden_size_;
  return total;
}
// Instantiate the class for floats and TruncatedFloat16's.  These explicit
// instantiations allow the template member definitions to live in this .cc
// file rather than the header.
template class LstmCellFunction<float>;
template class LstmCellFunction<TruncatedFloat16>;
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_LSTM_CELL_CELL_FUNCTION_H_
#define DRAGNN_RUNTIME_LSTM_CELL_CELL_FUNCTION_H_
#include "dragnn/runtime/math/avx_vector_array.h"
#include "dragnn/runtime/math/sgemvv.h"
#include "dragnn/runtime/math/types.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Interface for either type of LSTM cell function. Initialization is
// type-dependent, so not included in the shared interface.
class LstmCellFunctionBase {
 public:
  virtual ~LstmCellFunctionBase() {}

  // Runs the LSTM cell. |is_initial| indicates whether this is the first run.
  // |input| is the embedded feature vector (sometimes denoted "x"),
  // |last_hidden| is the last hidden state, denoted h_{t-1} (null/invalid when
  // |is_initial| is True), and similarly |last_cell_state| is the previous cell
  // state, denoted c_{t-1}.
  //
  // The caller must allocate the temporary |cell_input_temp|, which must be
  // 3 * hidden_size; the first |hidden_size| values will be the cell input
  // vector (which is typically not used externally). |cell_state|,
  // |cell_output|, and |next_hidden| must be |hidden_size|-length vectors.
  //
  // Returns InvalidArgument errors if any of the vector sizes are not expected.
  //
  // Convenience wrapper: performs the single-token input computation and then
  // the recurrent computation.
  tensorflow::Status Run(bool is_initial, Vector<float> input,
                         Vector<float> last_hidden,
                         Vector<float> last_cell_state,
                         MutableVector<float> cell_input_temp,
                         MutableVector<float> cell_state,
                         MutableVector<float> cell_output,
                         MutableVector<float> next_hidden) const {
    TF_RETURN_IF_ERROR(RunInputComputations(
        Matrix<float>(input), MutableMatrix<float>(cell_input_temp)));
    return RunRecurrentComputation(is_initial, last_hidden, last_cell_state,
                                   cell_input_temp, cell_state, cell_output,
                                   next_hidden);
  }

  // Runs the LSTM cell input computations.
  //
  // |inputs| contains vectors of embedded feature vectors (sometimes denoted
  // "x"). The caller must allocate the temporary |cell_input_temps|, each of
  // which must be 3 * hidden_size.
  //
  // Returns InvalidArgument errors if any of the vector sizes are not expected.
  virtual tensorflow::Status RunInputComputations(
      Matrix<float> inputs, MutableMatrix<float> cell_input_temps) const = 0;

  // Runs the recurrent part of the LSTM cell.
  //
  // |is_initial| indicates whether this is the first run. The temporary
  // |cell_input_temp| must be from RunInputComputation(), |last_hidden| is the
  // last hidden state, denoted h_{t-1} (null/invalid when |is_initial| is
  // True), and similarly |last_cell_state| is the previous cell state, denoted
  // c_{t-1}.
  //
  // |cell_state|, |cell_output|, and |next_hidden| must be |hidden_size|-length
  // vectors.
  //
  // Returns InvalidArgument errors if any of the vector sizes are not expected.
  virtual tensorflow::Status RunRecurrentComputation(
      bool is_initial, Vector<float> last_hidden, Vector<float> last_cell_state,
      MutableVector<float> cell_input_temp, MutableVector<float> cell_state,
      MutableVector<float> cell_output,
      MutableVector<float> next_hidden) const = 0;

  // Returns the number of floating-point operations necessary for one run. This
  // is typically dominated by matrix-vector-multiply operations, which use 2 *
  // width * height floating point operations.
  virtual double FlopsPerRun(bool is_initial) const = 0;
};
// Helper class which computes the LSTM function. This is a separate class from
// the network unit so that its performance can be tested and tuned separately.
template <typename MatrixElementType = float>
class LstmCellFunction : public LstmCellFunctionBase {
 public:
  // Batch size for SGEMV matrices. It's probably OK to use one batch size,
  // because we concatenate [x2i, x2c, x2o], etc. matrices so there is less
  // inefficiency from batching.
  static constexpr int kBatchSize = 48;

  // Public type alias for the underlying matrix type.
  using MatrixType = SgemvMatrix<kBatchSize, MatrixElementType>;

  // Creates an unconfigured cell; call Initialize() before using it.
  LstmCellFunction() = default;

  // Instantiates a LSTM cell function.
  //
  // Pass the following vectors and matrices,
  //
  //  * |cell_input_state_output_bias| - Concatenated bias terms for cell input
  //    (typically denoted `i`), cell state (denoted `c`), and cell output
  //    (denoted `o`).
  //  * |input_to_cell_input_state_output| - A matrix which will compute partial
  //    sums of cell input, state, and output expressions, given the input
  //    vector `x`. This is the concatenation of [x2i], [x2c], and [x2o]
  //    matrices in the Python network builder code.
  //  * |last_hidden_to_cell_input_state_output| - Likewise, computes partial
  //    sums given the last hidden state.
  //  * |last_cell_state_to_cell_input| - Used to compute partial sum of cell
  //    input, given *previous* cell state.
  //  * |cell_state_to_cell_output| - Used to compute partial sum of cell
  //    output, given *current* cell state.
  //
  // Returns an InvalidArgument error if hidden_size is not a multiple of the
  // AVX width (currently 8). This is used to reduce copying slightly, but is
  // not an essential optimization.
  tensorflow::Status Initialize(
      int hidden_size, Vector<float> cell_input_state_output_bias,
      MatrixType input_to_cell_input_state_output,
      MatrixType last_hidden_to_cell_input_state_output,
      MatrixType last_cell_state_to_cell_input,
      MatrixType cell_state_to_cell_output);

  // Implements LstmCellFunctionBase.
  tensorflow::Status RunInputComputations(
      Matrix<float> inputs,
      MutableMatrix<float> cell_input_temps) const override;
  tensorflow::Status RunRecurrentComputation(
      bool is_initial, Vector<float> last_hidden, Vector<float> last_cell_state,
      MutableVector<float> cell_input_temp, MutableVector<float> cell_state,
      MutableVector<float> cell_output,
      MutableVector<float> next_hidden) const override;
  double FlopsPerRun(bool is_initial) const override;

 private:
  // Hidden layer size.
  int hidden_size_;

  // Concatenated bias terms for cell input (typically denoted `i`), cell state
  // (denoted `c`), and cell output (denoted `o`).
  Vector<float> cell_input_state_output_bias_;

  // A matrix which will compute partial sums of cell input, state, and output
  // expressions, given the input vector `x`. This is the concatenation of
  // [x2i], [x2c], and [x2o] matrices in the Python network builder code.
  MatrixType input_to_cell_input_state_output_;

  // Likewise, computes partial sums given the last hidden state.
  MatrixType last_hidden_to_cell_input_state_output_;

  // Used to compute partial sum of cell input, given *previous* cell state.
  MatrixType last_cell_state_to_cell_input_;

  // Used to compute partial sum of cell output, given *current* cell state.
  MatrixType cell_state_to_cell_output_;
};
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_LSTM_CELL_CELL_FUNCTION_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/lstm_cell/cell_function.h"
#include <cmath>
#include <chrono>
#include <iostream>
#include <random>
#include <tuple>
#include "dragnn/runtime/lstm_cell/test_helpers.h"
#include "dragnn/runtime/math/transformations.h"
#include "dragnn/runtime/test/helpers.h"
#include "syntaxnet/base.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace benchmark {
// Convenience aliases, since we always use the same batch size.
using CellMatrix = SgemvMatrix<LstmCellFunction<>::kBatchSize>;
// This class allocates matrices and vectors contiguously, in the order they
// were requested. It estimates the storage necessary from the beginning, and
// CHECK-fails if this is insufficient. Ergo it should only be used for
// benchmarking.
class CoherentStorage : public VectorMatrixStorage {
 public:
  CoherentStorage() {
    constexpr int kMaxHiddenSize = 256;
    // This should be enough, though could be improved by factoring in input
    // size. Please adjust this class if it is not.
    array_.Resize(10 * sizeof(float) *
                  ComputeAlignedAreaSize(kMaxHiddenSize, kMaxHiddenSize));
  }

  // Returns a |size|-element float vector carved out of the shared array.
  MutableVector<float> RandomVector(int size) override {
    auto view = GetNextView(size);
    return MutableVector<float>(view, size);
  }

 protected:
  // Returns a |rows| x |columns| blocked matrix (rows padded up to a multiple
  // of |batch_size|), carved out of the shared array and filled with random
  // values.
  MutableBlockedMatrix<float, BlockedMatrixFormat::kRowBlockedColumnMajor>
  RandomBlockedMatrix(int rows, int columns, int batch_size) override {
    int rows_padded = batch_size * ((rows + batch_size - 1) / batch_size);
    int num_views = rows_padded * columns / batch_size;
    auto view = GetNextView(num_views * batch_size);
    MutableAlignedArea area;
    TF_CHECK_OK(area.Reset(view, num_views, batch_size * sizeof(float)));
    // Set random values. It doesn't matter that the rows/cols aren't what we
    // output.
    InitRandomMatrix(MutableMatrix<float>(area));
    // Construct SGEMV matrix types.
    MutableBlockedMatrix<float, BlockedMatrixFormat::kRowBlockedColumnMajor>
        blocked;
    TF_CHECK_OK(blocked.Reset(area, rows_padded, columns));
    return blocked;
  }

 private:
  // Gets the next view, where |size| is the number of floats desired.
  MutableAlignedView GetNextView(size_t size) {
    size_t size_bytes = PadToAlignment(size * sizeof(float));
    MutableAlignedView view;
    TF_CHECK_OK(view.Reset(&array_.view().data()[next_offset_], size_bytes));
    next_offset_ += size_bytes;
    CHECK_LE(next_offset_, array_.view().size());
    return view;
  }

  // Backing storage for all vectors and matrices.
  UniqueAlignedArray array_;

  // Next byte offset into |array_| to return.  size_t (rather than int) to
  // match the byte counts it accumulates, avoiding the signed/unsigned
  // comparison in CHECK_LE and potential overflow for multi-GB arrays.
  size_t next_offset_ = 0;
};
// Benchmarks repeated single-token runs of the LSTM cell with the given
// |hidden_size| and |input_dimension|.  |is_initial| selects the initial-step
// code path (no recurrent matrix products).  Prints kflops per run and the
// average achieved GFLOPS to stderr.
template <class StorageClass = VectorMatrixStorage>
void LstmCellBenchmark(int hidden_size, int input_dimension, bool is_initial) {
  // RAII storage for vectors and matrices.
  StorageClass storage;

  // Helper function. Because StorageClass is template, we need to call
  // templated member functions with the `template` keyword as well, which gets
  // verbose.
  auto random_matrix = [&storage](int rows, int columns) {
    return storage.template RandomMatrix<LstmCellFunction<>::kBatchSize>(
        rows, columns);
  };

  // Parameters for the LSTM cell, and for one run. We allocate them together
  // so that it's easy to experiment with more coherent initialization schemes.
  MutableVector<float> cell_input_state_output_bias =
      storage.RandomVector(3 * hidden_size);
  CellMatrix input_to_cell_input_state_output =
      random_matrix(3 * hidden_size, input_dimension);
  Vector<float> input(storage.RandomVector(input_dimension));
  MutableVector<float> cell_input_temp = storage.RandomVector(3 * hidden_size);
  CellMatrix last_hidden_to_cell_input_state_output =
      random_matrix(3 * hidden_size, hidden_size);
  Vector<float> last_hidden(storage.RandomVector(hidden_size));
  CellMatrix last_cell_state_to_cell_input =
      random_matrix(hidden_size, hidden_size);
  Vector<float> last_cell_state(storage.RandomVector(hidden_size));
  MutableVector<float> cell_state = storage.RandomVector(hidden_size);
  CellMatrix cell_state_to_cell_output =
      random_matrix(hidden_size, hidden_size);
  MutableVector<float> cell_output = storage.RandomVector(hidden_size);
  MutableVector<float> next_hidden = storage.RandomVector(hidden_size);

  // TODO(googleuser): Benchmark with different matrix element types.
  LstmCellFunction<float> cell;
  TF_CHECK_OK(cell.Initialize(
      hidden_size, Vector<float>(cell_input_state_output_bias),
      input_to_cell_input_state_output, last_hidden_to_cell_input_state_output,
      last_cell_state_to_cell_input, cell_state_to_cell_output));
  double flops_per_run = cell.FlopsPerRun(is_initial);

  // Scale the iteration count so each configuration performs roughly 10 GFLOPs
  // of total work.  Computed before starting the clock so it is not timed.
  const int kIterations = static_cast<int>(10e9 / flops_per_run);

  // Use steady_clock rather than system_clock: the system clock is not
  // monotonic (e.g. NTP adjustments can make it jump), which can corrupt
  // elapsed-time measurements.
  auto start_time = std::chrono::steady_clock::now();
  for (int i = 0; i < kIterations; ++i) {
    TF_CHECK_OK(cell.Run(is_initial, input, last_hidden, last_cell_state,
                         cell_input_temp, cell_state, cell_output,
                         next_hidden));
  }
  auto end_time = std::chrono::steady_clock::now();
  std::chrono::duration<double> elapsed_seconds = end_time - start_time;
  double elapsed = elapsed_seconds.count();
  double flops = flops_per_run * kIterations;
  std::cerr << "Cell with hidden=" << hidden_size
            << ", input_dimension=" << input_dimension
            << ", is_initial=" << is_initial
            << " kflops/run=" << std::round(flops_per_run / 1e3)
            << ", average GFLOPS=" << flops / 1e9 / elapsed << std::endl;
}
// Selects which portion of the LSTM computation a benchmark exercises.
enum class Subcomputation { kAll, kInputOnly, kRecurrentOnly };
// Benchmarks multi-token (sentence-style) runs of the LSTM cell: a batched
// input computation followed by a sequential recurrence over
// |tokens_per_sentence| tokens.  |computation| restricts the benchmark to the
// input or recurrent portion (the flop count is adjusted accordingly).  Prints
// kflops per run and the average achieved GFLOPS to stderr.
template <class StorageClass = VectorMatrixStorage,
          Subcomputation computation = Subcomputation::kAll>
void LstmCellMultiTokenBenchmark(int hidden_size, int input_dimension,
                                 int tokens_per_sentence) {
  std::cerr << "Document benchmark with hidden=" << hidden_size
            << ", input_dimension=" << input_dimension
            << ", tokens_per_sentence=" << tokens_per_sentence;

  // RAII storage for vectors and matrices.
  StorageClass storage;
  MatrixParameters parameters;
  parameters.Init(hidden_size, input_dimension, &storage);

  // Parameters for one run of the LSTM cell.
  UniqueMatrix<float> inputs(tokens_per_sentence, input_dimension);
  UniqueMatrix<float> cell_input_temps(tokens_per_sentence, 3 * hidden_size);
  UniqueMatrix<float> hiddens(tokens_per_sentence, hidden_size);
  InitRandomMatrix(*inputs);
  InitRandomMatrix(*cell_input_temps);
  InitRandomMatrix(*hiddens);
  MutableVector<float> cell_state = storage.RandomVector(hidden_size);
  MutableVector<float> cell_output = storage.RandomVector(hidden_size);

  // TODO(googleuser): Benchmark with different matrix element types.
  LstmCellFunction<float> cell;
  TF_CHECK_OK(parameters.InitializeCell(&cell));

  // There is 1 initial state and n-1 non-initial states.
  double input_flops =
      tokens_per_sentence * 2.0 * (3 * hidden_size) * input_dimension;
  double flops_per_run = cell.FlopsPerRun(true) +
                         (tokens_per_sentence - 1) * cell.FlopsPerRun(false);
  if (computation == Subcomputation::kInputOnly) {
    flops_per_run = input_flops;
  } else if (computation == Subcomputation::kRecurrentOnly) {
    flops_per_run -= input_flops;
  }

  // Scale the iteration count so each configuration performs roughly 10 GFLOPs
  // of total work.  Computed before starting the clock so it is not timed.
  const int kIterations = static_cast<int>(10e9 / flops_per_run);

  // Use steady_clock rather than system_clock: the system clock is not
  // monotonic (e.g. NTP adjustments can make it jump), which can corrupt
  // elapsed-time measurements.
  auto start_time = std::chrono::steady_clock::now();
  for (int iter = 0; iter < kIterations; ++iter) {
    // SGEMVV input to [cell input, cell state, cell output] computation.
    if (computation != Subcomputation::kRecurrentOnly) {
      TF_CHECK_OK(
          cell.RunInputComputations(Matrix<float>(*inputs), *cell_input_temps));
    }

    // Run recurrent parts of the network.
    if (computation != Subcomputation::kInputOnly) {
      for (int i = 0; i < tokens_per_sentence; ++i) {
        Vector<float> last_cell_state;
        Vector<float> last_hidden;
        if (i != 0) {
          last_cell_state = Vector<float>(cell_state);
          last_hidden = Vector<float>(hiddens->row(i - 1));
        }
        TF_CHECK_OK(cell.RunRecurrentComputation(
            i == 0, last_hidden, last_cell_state, cell_input_temps->row(i),
            cell_state, cell_output, hiddens->row(i)));
      }
    }
  }
  auto end_time = std::chrono::steady_clock::now();
  std::chrono::duration<double> elapsed_seconds = end_time - start_time;
  double elapsed = elapsed_seconds.count();
  double flops = flops_per_run * kIterations;
  std::cerr << " kflops/run=" << std::round(flops_per_run / 1e3)
            << ", average GFLOPS=" << flops / 1e9 / elapsed << std::endl;
}
} // namespace benchmark
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
using syntaxnet::dragnn::runtime::VectorMatrixStorage;
using syntaxnet::dragnn::runtime::benchmark::CoherentStorage;
using syntaxnet::dragnn::runtime::benchmark::LstmCellBenchmark;
using syntaxnet::dragnn::runtime::benchmark::LstmCellMultiTokenBenchmark;
using syntaxnet::dragnn::runtime::benchmark::Subcomputation;
// Entry point: runs the LSTM cell benchmark suite, reporting to stderr.  The
// call sequence is expressed as data-driven loops over the size parameters.
int main(int argc, char **argv) {
  const int kSingleTokenHiddenSizes[] = {64, 96, 128, 256};
  for (const int hidden : kSingleTokenHiddenSizes) {
    LstmCellBenchmark(hidden, 32, false);
  }

  std::cerr << std::endl << "With coherent memory:" << std::endl;
  for (const int hidden : kSingleTokenHiddenSizes) {
    LstmCellBenchmark<CoherentStorage>(hidden, 32, false);
  }

  // These are used for tuning coefficients in cell_function.cc.
  std::cerr << std::endl;
  LstmCellMultiTokenBenchmark(48, 32, 10);
  LstmCellMultiTokenBenchmark(64, 32, 5);
  LstmCellMultiTokenBenchmark(64, 32, 10);
  const int kTokenCounts[] = {2, 5, 10, 20};
  for (const int tokens : kTokenCounts) {
    LstmCellMultiTokenBenchmark(96, 96, tokens);
  }
  for (const int tokens : kTokenCounts) {
    LstmCellMultiTokenBenchmark(128, 32, tokens);
  }
  LstmCellMultiTokenBenchmark(128, 128, 10);
  LstmCellMultiTokenBenchmark(144, 32, 10);
  LstmCellMultiTokenBenchmark(256, 32, 10);

  const int kSubcomputationTokenCounts[] = {2, 10, 20};
  std::cerr << std::endl
            << "Input computation only (similar to sgemvv_test):" << std::endl;
  for (const int tokens : kSubcomputationTokenCounts) {
    LstmCellMultiTokenBenchmark<VectorMatrixStorage,
                                Subcomputation::kInputOnly>(96, 96, tokens);
  }

  std::cerr << std::endl << "Recurrent computation only:" << std::endl;
  for (const int tokens : kSubcomputationTokenCounts) {
    LstmCellMultiTokenBenchmark<VectorMatrixStorage,
                                Subcomputation::kRecurrentOnly>(96, 96, tokens);
  }

  std::cerr << std::endl << "With coherent memory:" << std::endl;
  const int kCoherentHiddenSizes[] = {48, 64, 96, 128, 144};
  for (const int hidden : kCoherentHiddenSizes) {
    LstmCellMultiTokenBenchmark<CoherentStorage>(hidden, 32, 10);
  }
  return 0;
}
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/lstm_cell/cell_function.h"
#include <cmath>
#include <chrono>
#include <iostream>
#include <random>
#include <tuple>
#include "dragnn/core/test/generic.h"
#include "dragnn/runtime/lstm_cell/test_helpers.h"
#include "dragnn/runtime/math/arithmetic.h"
#include "dragnn/runtime/math/transformations.h"
#include "dragnn/runtime/test/helpers.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Convenience aliases, since we always use the same batch size.
constexpr int kBatchSize = LstmCellFunction<>::kBatchSize;
using CellMatrix = SgemvMatrix<kBatchSize>;
// Un-optimized version of the LSTM cell. Practically the same API except the
// constructor takes size arguments.  Serves as a reference implementation:
// each step is computed with plain matrix products and <cmath> scalar
// activations so that optimized results can be compared against it.
class UnoptimizedCellFunction {
 public:
  // Allocates the (unpadded) parameter matrices for the given sizes.
  UnoptimizedCellFunction(int hidden_size, int input_size)
      : hidden_size_(hidden_size),
        input_to_cell_input_state_output_(3 * hidden_size, input_size),
        last_hidden_to_cell_input_state_output_(3 * hidden_size, hidden_size),
        last_cell_state_to_cell_input_(hidden_size, hidden_size),
        cell_state_to_cell_output_(hidden_size, hidden_size) {}

  // Copies the bias and the (row-padded) SGEMV matrices into this instance's
  // unpadded matrices.  |hidden_size| is unused here; the sizes were fixed at
  // construction.  Always returns OK.
  tensorflow::Status Initialize(
      int hidden_size, Vector<float> cell_input_state_output_bias,
      CellMatrix input_to_cell_input_state_output,
      CellMatrix last_hidden_to_cell_input_state_output,
      CellMatrix last_cell_state_to_cell_input,
      CellMatrix cell_state_to_cell_output) {
    cell_input_state_output_bias_ = cell_input_state_output_bias;

    // Copies a padded SGEMV matrix to a non-padded regular matrix.
    auto copy_matrix_to_unpadded = [&](CellMatrix input,
                                       MutableMatrix<float> output) {
      CopyMatrixPrefix(input.matrix(), output.num_rows(), output.num_columns(),
                       &output);
    };
    copy_matrix_to_unpadded(input_to_cell_input_state_output,
                            *input_to_cell_input_state_output_);
    copy_matrix_to_unpadded(last_hidden_to_cell_input_state_output,
                            *last_hidden_to_cell_input_state_output_);
    copy_matrix_to_unpadded(last_cell_state_to_cell_input,
                            *last_cell_state_to_cell_input_);
    copy_matrix_to_unpadded(cell_state_to_cell_output,
                            *cell_state_to_cell_output_);
    return tensorflow::Status::OK();
  }

  // Mirrors LstmCellFunctionBase::Run with straightforward scalar code; see
  // cell_function.h for the meaning of each argument.  Unlike the optimized
  // version, no size validation is performed, and OK is always returned.
  tensorflow::Status Run(bool is_initial, Vector<float> input,
                         Vector<float> last_hidden,
                         Vector<float> last_cell_state,
                         MutableVector<float> cell_input_temp,
                         MutableVector<float> cell_state,
                         MutableVector<float> cell_output,
                         MutableVector<float> next_hidden) {
    // First third of |cell_input_temp| is the cell input ("i") component.
    MutableVector<float> cell_input =
        cell_input_temp.Subsequence(0, hidden_size_);

    // cell_input_temp = bias + [x2i;x2c;x2o] * x_t.
    MultiplyMatrixAndVectorWithBias(
        Matrix<float>(*input_to_cell_input_state_output_),
        cell_input_state_output_bias_, input, cell_input_temp);
    if (!is_initial) {
      // Accumulates the recurrent terms by passing the current partial sum as
      // the "bias" argument of each product.
      MultiplyMatrixAndVectorWithBias(
          Matrix<float>(*last_hidden_to_cell_input_state_output_),
          Vector<float>(cell_input_temp), last_hidden, cell_input_temp);
      MultiplyMatrixAndVectorWithBias(
          Matrix<float>(*last_cell_state_to_cell_input_),
          Vector<float>(cell_input), last_cell_state, cell_input);
    }

    // Apply sigmoid (using cmath).
    for (int i = 0; i < hidden_size_; ++i) {
      cell_input[i] = 1.0 / (1.0 + exp(-cell_input[i]));
    }

    // Cell state.
    for (int i = 0; i < hidden_size_; ++i) {
      if (is_initial) {
        cell_state[i] = cell_input[i] * tanh(cell_input_temp[hidden_size_ + i]);
      } else {
        // Coupled forget gate: f_t = 1 - i_t.
        float forget = 1.0f - cell_input[i];
        // Recall cell_input_temp[hidden_size_ + i] is the i'th value of
        // the partial sum [x2c] * x_t + [h2c] * h_{t-1} + b_c.
        cell_state[i] =
            (forget * last_cell_state[i]) +
            (cell_input[i] * tanh(cell_input_temp[hidden_size_ + i]));
      }
    }

    // Cell output.
    auto cell_output_partial_sum =
        cell_input_temp.Subsequence(2 * hidden_size_, hidden_size_);
    MultiplyMatrixAndVectorWithBias(Matrix<float>(*cell_state_to_cell_output_),
                                    Vector<float>(cell_output_partial_sum),
                                    Vector<float>(cell_state), cell_output);
    for (int i = 0; i < hidden_size_; ++i) {
      cell_output[i] = 1.0 / (1.0 + exp(-cell_output[i]));
    }

    // Hidden state.
    for (int i = 0; i < hidden_size_; ++i) {
      next_hidden[i] = cell_output[i] * tanh(cell_state[i]);
    }
    return tensorflow::Status::OK();
  }

 private:
  // Hidden layer size.
  int hidden_size_;

  // Concatenated [i; c; o] bias vector.
  Vector<float> cell_input_state_output_bias_;

  // Unpadded copies of the cell parameter matrices; see cell_function.h for
  // the role of each.
  UniqueMatrix<float> input_to_cell_input_state_output_;
  UniqueMatrix<float> last_hidden_to_cell_input_state_output_;
  UniqueMatrix<float> last_cell_state_to_cell_input_;
  UniqueMatrix<float> cell_state_to_cell_output_;
};
// Verifies that LstmCellFunction::Initialize() rejects a dimension mismatch in
// each of its vector/matrix arguments with a descriptive error.  The matrices
// are SGEMV-blocked, so row counts are checked against pad_rows(...) — the row
// count padded up to the next multiple of the batch size (e.g. 144 for 128).
TEST(CellFunctionTest, TestInitializeErrors) {
  int hidden_size = 128;
  int input_dimension = 32;
  // RAII storage for vectors and matrices.
  VectorMatrixStorage storage;
  // LSTM cell.
  Vector<float> cell_input_state_output_bias(
      storage.RandomVector(3 * hidden_size));
  CellMatrix input_to_cell_input_state_output =
      storage.RandomMatrix<kBatchSize>(3 * hidden_size, input_dimension);
  CellMatrix last_hidden_to_cell_input_state_output =
      storage.RandomMatrix<kBatchSize>(3 * hidden_size, hidden_size);
  CellMatrix last_cell_state_to_cell_input =
      storage.RandomMatrix<kBatchSize>(hidden_size, hidden_size);
  CellMatrix cell_state_to_cell_output =
      storage.RandomMatrix<kBatchSize>(hidden_size, hidden_size);
  LstmCellFunction<float> cell;
  // Bias vector of size hidden_size instead of 3 * hidden_size.
  EXPECT_THAT(cell.Initialize(
                  hidden_size, Vector<float>(storage.RandomVector(hidden_size)),
                  input_to_cell_input_state_output,
                  last_hidden_to_cell_input_state_output,
                  last_cell_state_to_cell_input, cell_state_to_cell_output),
              test::IsErrorWithSubstr(
                  "Vector/matrix size cell_input_state_output_bias.size() (128)"
                  " does not match expected size 3 * "
                  "hidden_size (384)"));
  // Input weight matrix with too few rows (hidden_size vs 3 * hidden_size).
  EXPECT_THAT(
      cell.Initialize(
          hidden_size, cell_input_state_output_bias,
          storage.RandomMatrix<kBatchSize>(hidden_size, input_dimension),
          last_hidden_to_cell_input_state_output, last_cell_state_to_cell_input,
          cell_state_to_cell_output),
      test::IsErrorWithSubstr("Vector/matrix size "
                              "input_to_cell_input_state_output.matrix().num_"
                              "rows() "
                              "(144) does not match expected size pad_rows(3 * "
                              "hidden_size) (384)"));
  // Recurrent weight matrix with too few rows.
  EXPECT_THAT(cell.Initialize(
                  hidden_size, cell_input_state_output_bias,
                  input_to_cell_input_state_output,
                  storage.RandomMatrix<kBatchSize>(hidden_size, hidden_size),
                  last_cell_state_to_cell_input, cell_state_to_cell_output),
              test::IsErrorWithSubstr(
                  "Vector/matrix size "
                  "last_hidden_to_cell_input_state_output.matrix().num_rows() "
                  "(144) does "
                  "not match expected size pad_rows(3 * hidden_size) (384)"));
  // Recurrent weight matrix with too many columns.
  EXPECT_THAT(
      cell.Initialize(
          hidden_size, cell_input_state_output_bias,
          input_to_cell_input_state_output,
          storage.RandomMatrix<kBatchSize>(3 * hidden_size, 2 * hidden_size),
          last_cell_state_to_cell_input, cell_state_to_cell_output),
      test::IsErrorWithSubstr("Vector/matrix size "
                              "last_hidden_to_cell_input_state_output.matrix()."
                              "num_columns() (256) does not "
                              "match expected size hidden_size (128)"));
  // Cell-state-to-input matrix with too many rows.
  EXPECT_THAT(
      cell.Initialize(
          hidden_size, cell_input_state_output_bias,
          input_to_cell_input_state_output,
          last_hidden_to_cell_input_state_output,
          storage.RandomMatrix<kBatchSize>(2 * hidden_size, hidden_size),
          cell_state_to_cell_output),
      test::IsErrorWithSubstr(
          "Vector/matrix size "
          "last_cell_state_to_cell_input.matrix().num_rows() (288) does "
          "not match expected size pad_rows(hidden_size) (144)"));
  // Cell-state-to-input matrix with too many columns.
  EXPECT_THAT(
      cell.Initialize(
          hidden_size, cell_input_state_output_bias,
          input_to_cell_input_state_output,
          last_hidden_to_cell_input_state_output,
          storage.RandomMatrix<kBatchSize>(hidden_size, 2 * hidden_size),
          cell_state_to_cell_output),
      test::IsErrorWithSubstr("Vector/matrix size "
                              "last_cell_state_to_cell_input.matrix().num_"
                              "columns() (256) does "
                              "not match expected size hidden_size (128)"));
  // Cell-state-to-output matrix with too many rows.
  EXPECT_THAT(
      cell.Initialize(
          hidden_size, cell_input_state_output_bias,
          input_to_cell_input_state_output,
          last_hidden_to_cell_input_state_output, last_cell_state_to_cell_input,
          storage.RandomMatrix<kBatchSize>(2 * hidden_size, hidden_size)),
      test::IsErrorWithSubstr(
          "Vector/matrix size cell_state_to_cell_output.matrix().num_rows() "
          "(288) does not match expected size "
          "pad_rows(hidden_size) (144)"));
  // Cell-state-to-output matrix with too many columns.
  EXPECT_THAT(
      cell.Initialize(
          hidden_size, cell_input_state_output_bias,
          input_to_cell_input_state_output,
          last_hidden_to_cell_input_state_output, last_cell_state_to_cell_input,
          storage.RandomMatrix<kBatchSize>(hidden_size, 2 * hidden_size)),
      test::IsErrorWithSubstr(
          "Vector/matrix size "
          "cell_state_to_cell_output.matrix().num_columns() (256) does not "
          "match expected size hidden_size (128)"));
}
// Verifies that LstmCellFunction::Run() rejects a dimension mismatch in each
// of its per-step arguments with a descriptive error, after a successful
// Initialize().
TEST(CellFunctionTest, TestRunErrors) {
  int hidden_size = 128;
  int input_dimension = 32;
  // RAII storage for vectors and matrices.
  VectorMatrixStorage storage;
  // LSTM cell.
  Vector<float> cell_input_state_output_bias(
      storage.RandomVector(3 * hidden_size));
  CellMatrix input_to_cell_input_state_output =
      storage.RandomMatrix<kBatchSize>(3 * hidden_size, input_dimension);
  CellMatrix last_hidden_to_cell_input_state_output =
      storage.RandomMatrix<kBatchSize>(3 * hidden_size, hidden_size);
  CellMatrix last_cell_state_to_cell_input =
      storage.RandomMatrix<kBatchSize>(hidden_size, hidden_size);
  CellMatrix cell_state_to_cell_output =
      storage.RandomMatrix<kBatchSize>(hidden_size, hidden_size);
  // Per-run inputs.
  Vector<float> input(storage.RandomVector(input_dimension));
  Vector<float> last_hidden(storage.RandomVector(hidden_size));
  Vector<float> last_cell_state(storage.RandomVector(hidden_size));
  MutableVector<float> cell_input_temp = storage.RandomVector(3 * hidden_size);
  MutableVector<float> cell_state = storage.RandomVector(hidden_size);
  MutableVector<float> cell_output = storage.RandomVector(hidden_size);
  MutableVector<float> next_hidden = storage.RandomVector(hidden_size);
  LstmCellFunction<float> cell;
  TF_EXPECT_OK(cell.Initialize(
      hidden_size, cell_input_state_output_bias,
      input_to_cell_input_state_output, last_hidden_to_cell_input_state_output,
      last_cell_state_to_cell_input, cell_state_to_cell_output));
  // Input vector of half the expected dimension.
  EXPECT_THAT(
      cell.Run(true, Vector<float>(storage.RandomVector(input_dimension / 2)),
               last_hidden, last_cell_state, cell_input_temp, cell_state,
               cell_output, next_hidden),
      test::IsErrorWithSubstr("Vector/matrix size inputs.num_columns() (16) "
                              "does not match expected size "
                              "input_to_cell_input_state_output_.matrix().num_"
                              "columns() (32)"));
  // Oversized |last_hidden|; only checked on non-initial steps.
  EXPECT_THAT(cell.Run(false, input,
                       Vector<float>(storage.RandomVector(2 * hidden_size)),
                       last_cell_state, cell_input_temp, cell_state,
                       cell_output, next_hidden),
              test::IsErrorWithSubstr("Vector/matrix size last_hidden.size() "
                                      "(256) does not match expected size "
                                      "hidden_size_ (128)"));
  // Oversized |last_cell_state|; only checked on non-initial steps.
  EXPECT_THAT(cell.Run(false, input, last_hidden,
                       Vector<float>(storage.RandomVector(2 * hidden_size)),
                       cell_input_temp, cell_state, cell_output, next_hidden),
              test::IsErrorWithSubstr(
                  "Vector/matrix size last_cell_state.size() (256) does not "
                  "match expected size hidden_size_ (128)"));
  // Undersized |cell_input_temp| scratch vector (needs 3 * hidden_size).
  EXPECT_THAT(cell.Run(true, input, last_hidden, last_cell_state,
                       storage.RandomVector(hidden_size), cell_state,
                       cell_output, next_hidden),
              test::IsErrorWithSubstr(
                  "Vector/matrix size cell_input_temps.num_columns() (128) "
                  "does not match expected size 3 * hidden_size_ (384)"));
  // Oversized |cell_state| output.
  EXPECT_THAT(
      cell.Run(true, input, last_hidden, last_cell_state, cell_input_temp,
               storage.RandomVector(2 * hidden_size), cell_output, next_hidden),
      test::IsErrorWithSubstr("Vector/matrix size cell_state.size() (256) does "
                              "not match expected size hidden_size_ (128)"));
  // Oversized |cell_output| output.
  EXPECT_THAT(
      cell.Run(true, input, last_hidden, last_cell_state, cell_input_temp,
               cell_state, storage.RandomVector(2 * hidden_size), next_hidden),
      test::IsErrorWithSubstr("Vector/matrix size cell_output.size() (256) "
                              "does not match expected size hidden_size_ "
                              "(128)"));
  // Oversized |next_hidden| output.
  EXPECT_THAT(
      cell.Run(true, input, last_hidden, last_cell_state, cell_input_temp,
               cell_state, cell_output, storage.RandomVector(2 * hidden_size)),
      test::IsErrorWithSubstr("Vector/matrix size next_hidden.size() (256) "
                              "does not match expected size hidden_size_ "
                              "(128)"));
}
// Test harness, with parameters hidden_size, input_dimension, and is_initial.
// Instantiated below with both SGEMV-aligned and unaligned dimensions.
class CellFuzzTest
    : public ::testing::TestWithParam<std::tuple<int, int, bool>> {};
// Fuzz-compares the optimized LstmCellFunction (float32 and bfloat16) against
// the naive UnoptimizedCellFunction on 100 random parameter/input draws, and
// checks that the cell-state output may alias the last-cell-state input.
TEST_P(CellFuzzTest, TestMatchesNaiveAlgorithm) {
  int hidden_size;
  int input_dimension;
  bool is_initial;
  std::tie(hidden_size, input_dimension, is_initial) = GetParam();
  for (int iter = 0; iter < 100; ++iter) {
    // RAII storage for vectors and matrices.
    VectorMatrixStorage storage;
    // Parameters for the LSTM cell, and for one run. We allocate them together
    // so that it's easy to experiment with more coherent initialization
    // schemes.
    MatrixParameters parameters;
    parameters.Init(hidden_size, input_dimension, &storage);
    // Per-run inputs.  |last_cell_state_mutable| is kept so the aliasing check
    // at the end can pass it as the (mutable) cell-state output.
    Vector<float> input(storage.RandomVector(input_dimension));
    Vector<float> last_hidden(storage.RandomVector(hidden_size));
    MutableVector<float> last_cell_state_mutable =
        storage.RandomVector(hidden_size);
    Vector<float> last_cell_state(last_cell_state_mutable);
    MutableVector<float> cell_input_temp =
        storage.RandomVector(3 * hidden_size);
    MutableVector<float> cell_state = storage.RandomVector(hidden_size);
    MutableVector<float> cell_output = storage.RandomVector(hidden_size);
    MutableVector<float> next_hidden = storage.RandomVector(hidden_size);
    // Outputs for un-optimized algorithm.
    MutableVector<float> expected_cell_input_temp =
        storage.RandomVector(3 * hidden_size);
    MutableVector<float> expected_cell_state =
        storage.RandomVector(hidden_size);
    MutableVector<float> expected_cell_output =
        storage.RandomVector(hidden_size);
    MutableVector<float> expected_next_hidden =
        storage.RandomVector(hidden_size);
    UnoptimizedCellFunction unoptimized(hidden_size, input_dimension);
    TF_EXPECT_OK(parameters.InitializeCell(&unoptimized));
    TF_EXPECT_OK(unoptimized.Run(is_initial, input, last_hidden,
                                 last_cell_state, expected_cell_input_temp,
                                 expected_cell_state, expected_cell_output,
                                 expected_next_hidden));
    LstmCellFunction<float> cell;
    TF_EXPECT_OK(parameters.InitializeCell(&cell));
    TF_EXPECT_OK(cell.Run(is_initial, input, last_hidden, last_cell_state,
                          cell_input_temp, cell_state, cell_output,
                          next_hidden));
    // Both this and `bfloat16_tol` below could trip EXPECTs because we are
    // using random values. Adjust judiciously.
    float tol = 1e-6 * hidden_size;
    float bfloat16_tol = 7e-3 * hidden_size;
    // Compare the first values of the cell input state.
    for (int i = 0; i < hidden_size; ++i) {
      EXPECT_NEAR(cell_input_temp[i], expected_cell_input_temp[i], tol);
    }
    // Compare the cell state, cell output, and hidden vectors.
    for (int i = 0; i < hidden_size; ++i) {
      EXPECT_NEAR(cell_state[i], expected_cell_state[i], tol) << " i=" << i;
      EXPECT_NEAR(cell_output[i], expected_cell_output[i], tol) << " i=" << i;
      EXPECT_NEAR(next_hidden[i], expected_next_hidden[i], tol) << " i=" << i;
    }
    // Test float16 version.  Same expectations, but with a looser tolerance.
    LstmCellFunction<TruncatedFloat16> bfloat16_cell;
    TF_EXPECT_OK(parameters.InitializeHalfFloatCell(&storage, &bfloat16_cell));
    TF_EXPECT_OK(bfloat16_cell.Run(is_initial, input, last_hidden,
                                   last_cell_state, cell_input_temp, cell_state,
                                   cell_output, next_hidden));
    for (int i = 0; i < hidden_size; ++i) {
      EXPECT_NEAR(cell_input_temp[i], expected_cell_input_temp[i],
                  bfloat16_tol);
      EXPECT_NEAR(cell_state[i], expected_cell_state[i], bfloat16_tol);
      EXPECT_NEAR(cell_output[i], expected_cell_output[i], bfloat16_tol);
      EXPECT_NEAR(next_hidden[i], expected_next_hidden[i], bfloat16_tol);
    }
    // Check that it is OK if the cell state vector is consumed (overwritten).
    TF_EXPECT_OK(cell.Run(is_initial, input, last_hidden, last_cell_state,
                          cell_input_temp, last_cell_state_mutable, cell_output,
                          next_hidden));
    for (int i = 0; i < hidden_size; ++i) {
      EXPECT_NEAR(last_cell_state_mutable[i], expected_cell_state[i], tol)
          << " i=" << i;
      EXPECT_NEAR(cell_output[i], expected_cell_output[i], tol) << " i=" << i;
      EXPECT_NEAR(next_hidden[i], expected_next_hidden[i], tol) << " i=" << i;
    }
  }
}
// Sweeps (hidden_size, input_dimension, is_initial) over SGEMV-aligned sizes
// (32) and unaligned sizes (17, 173), with and without recurrent inputs.
INSTANTIATE_TEST_CASE_P(CellFuzzTestInstance, CellFuzzTest,
                        ::testing::Values(std::make_tuple(8, 32, true),
                                          std::make_tuple(8, 32, false),
                                          std::make_tuple(8, 17, true),
                                          std::make_tuple(8, 17, false),
                                          std::make_tuple(96, 32, true),
                                          std::make_tuple(96, 32, false),
                                          std::make_tuple(128, 32, true),
                                          std::make_tuple(128, 32, false),
                                          std::make_tuple(128, 173, true),
                                          std::make_tuple(128, 173, false)));
// Test harness, with parameters hidden_size, input_dimension.  Exercises the
// bulk input computation, which has no recurrent dependencies.
class CellInputFuzzTest
    : public ::testing::TestWithParam<std::tuple<int, int>> {};
// Checks that RunInputComputations() over a 2-row input batch matches two
// naive single-step runs with is_initial = true (i.e. without recurrent
// terms), for both float32 and bfloat16 cells.
TEST_P(CellInputFuzzTest, TestBulkInputMatches) {
  int hidden_size;
  int input_dimension;
  // The bulk input computation excludes recurrent terms, which corresponds to
  // the initial step of the naive algorithm.
  bool is_initial = true;
  std::tie(hidden_size, input_dimension) = GetParam();
  // RAII storage for vectors and matrices.
  VectorMatrixStorage storage;
  // Parameters for the LSTM cell, and for one run. We allocate them together
  // so that it's easy to experiment with more coherent initialization
  // schemes.
  MatrixParameters parameters;
  parameters.Init(hidden_size, input_dimension, &storage);
  // Per-run inputs.  Two rows: one per simulated step.
  UniqueMatrix<float> inputs(2, input_dimension);
  UniqueMatrix<float> cell_input_temps(2, 3 * hidden_size);
  InitRandomMatrix(*inputs);
  InitRandomMatrix(*cell_input_temps);
  // Extra parameters for the naive algorithm, which should run everything.
  // The recurrent inputs stay default-constructed (null) since is_initial.
  Vector<float> last_hidden;
  Vector<float> last_cell_state;
  std::vector<MutableVector<float>> expected_cell_input_temps = {
      storage.RandomVector(3 * hidden_size),
      storage.RandomVector(3 * hidden_size)};
  MutableVector<float> expected_cell_state = storage.RandomVector(hidden_size);
  MutableVector<float> expected_cell_output = storage.RandomVector(hidden_size);
  MutableVector<float> expected_next_hidden = storage.RandomVector(hidden_size);
  UnoptimizedCellFunction unoptimized(hidden_size, input_dimension);
  TF_EXPECT_OK(parameters.InitializeCell(&unoptimized));
  TF_EXPECT_OK(unoptimized.Run(
      is_initial, Vector<float>(inputs->row(0)), last_hidden, last_cell_state,
      expected_cell_input_temps[0], expected_cell_state, expected_cell_output,
      expected_next_hidden));
  TF_EXPECT_OK(unoptimized.Run(
      is_initial, Vector<float>(inputs->row(1)), last_hidden, last_cell_state,
      expected_cell_input_temps[1], expected_cell_state, expected_cell_output,
      expected_next_hidden));
  LstmCellFunction<float> cell;
  TF_EXPECT_OK(parameters.InitializeCell(&cell));
  TF_EXPECT_OK(
      cell.RunInputComputations(Matrix<float>(*inputs), *cell_input_temps));
  // Both this and `bfloat16_tol` below could trip EXPECTs because we are using
  // random values. Adjust judiciously.
  float tol = 1e-7 * hidden_size;
  float bfloat16_tol = 5e-3 * hidden_size;
  // Compare the first values of the cell input state. If we pass
  // RunInputComputation results through the sigmoid function, we should get the
  // same result as calling unoptimized.Run() with is_initial = true.
  for (int i = 0; i < hidden_size; ++i) {
    auto sigmoid = [](float input) { return 1.0 / (1.0 + exp(-input)); };
    EXPECT_NEAR(sigmoid(cell_input_temps->row(0)[i]),
                expected_cell_input_temps[0][i], tol);
    EXPECT_NEAR(sigmoid(cell_input_temps->row(1)[i]),
                expected_cell_input_temps[1][i], tol);
  }
  // Test float16 version.  Same expectations, but with a looser tolerance.
  LstmCellFunction<TruncatedFloat16> bfloat16_cell;
  TF_EXPECT_OK(parameters.InitializeHalfFloatCell(&storage, &bfloat16_cell));
  TF_EXPECT_OK(bfloat16_cell.RunInputComputations(Matrix<float>(*inputs),
                                                  *cell_input_temps));
  for (int i = 0; i < hidden_size; ++i) {
    auto sigmoid = [](float input) { return 1.0 / (1.0 + exp(-input)); };
    EXPECT_NEAR(sigmoid(cell_input_temps->row(0)[i]),
                expected_cell_input_temps[0][i], bfloat16_tol);
    EXPECT_NEAR(sigmoid(cell_input_temps->row(1)[i]),
                expected_cell_input_temps[1][i], bfloat16_tol);
  }
}
// Sweeps (hidden_size, input_dimension) over aligned (32) and unaligned
// (17, 173) input dimensions.
INSTANTIATE_TEST_CASE_P(CellInputFuzzTestInstance, CellInputFuzzTest,
                        ::testing::Values(std::make_tuple(8, 32),
                                          std::make_tuple(8, 17),
                                          std::make_tuple(96, 32),
                                          std::make_tuple(128, 32),
                                          std::make_tuple(128, 173)));
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_LSTM_CELL_TEST_HELPERS_H_
#define DRAGNN_RUNTIME_LSTM_CELL_TEST_HELPERS_H_
#include "dragnn/runtime/lstm_cell/cell_function.h"
#include "dragnn/runtime/math/float16_types.h"
#include "dragnn/runtime/math/sgemvv.h"
#include "dragnn/runtime/test/helpers.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Contains storage for multiple arrays during the test. This one is
// simple/naive: it just allocates new objects, wherever malloc places them.
// A more advanced version is in cell_function_benchmark.cc, but doesn't seem
// to have much benefit yet.
class VectorMatrixStorage {
 public:
  // Nothing to set up or tear down beyond the member containers; defaulted
  // instead of empty bodies (Rule of Zero-ish; dtor stays virtual so
  // subclasses destroy cleanly through a base pointer).
  VectorMatrixStorage() = default;
  virtual ~VectorMatrixStorage() = default;

  // Allocates a vector of |size| floats, fills it with random values, and
  // returns a view of it.  The backing storage lives as long as this object.
  virtual MutableVector<float> RandomVector(int size) {
    vectors_.emplace_back(size);
    InitRandomVector(*vectors_.back());
    return *vectors_.back();
  }

  // Allocates a SGEMV matrix and fills it with random values. Subclasses can
  // implement RandomBlockedMatrix(), which doesn't rely on a template
  // parameter.
  template <int batch_size>
  SgemvMatrix<batch_size> RandomMatrix(int rows, int columns) {
    auto blocked = RandomBlockedMatrix(rows, columns, batch_size);
    SgemvMatrix<batch_size> sgemv_matrix;
    TF_CHECK_OK(sgemv_matrix.Initialize(blocked.AsConst()));
    return sgemv_matrix;
  }

  // Allocates a bfloat16 version of a matrix.  The converted backing storage
  // is also owned by this object.
  template <int batch_size>
  SgemvMatrix<batch_size, TruncatedFloat16> ConvertToHalfFloat(
      const SgemvMatrix<batch_size> &matrix) {
    auto blocked = ConvertBlockedMatrix(matrix.matrix());
    SgemvMatrix<batch_size, TruncatedFloat16> sgemv_matrix;
    TF_CHECK_OK(sgemv_matrix.Initialize(blocked.AsConst()));
    return sgemv_matrix;
  }

 protected:
  // Allocates backing storage and returns a randomly-initialized row-blocked,
  // column-major matrix.  Defined out-of-line below.
  virtual MutableBlockedMatrix<float,
                               BlockedMatrixFormat::kRowBlockedColumnMajor>
  RandomBlockedMatrix(int rows, int columns, int batch_size);

  // Converts |uncompressed| to TruncatedFloat16, allocating the converted
  // storage in this object.  Defined out-of-line below.
  virtual MutableBlockedMatrix<TruncatedFloat16,
                               BlockedMatrixFormat::kRowBlockedColumnMajor>
  ConvertBlockedMatrix(
      const BlockedMatrix<float, BlockedMatrixFormat::kRowBlockedColumnMajor>
          &uncompressed);

 private:
  // Owned backing storage for everything handed out above.
  std::vector<UniqueVector<float>> vectors_;
  std::vector<UniqueMatrix<float>> matrices_;
  std::vector<UniqueMatrix<TruncatedFloat16>> converted_matrices_;
};
// Pulls out matrix parameters, makes them usable for multiple LSTM cell
// implementations (namely, unoptimized and normal).
struct MatrixParameters {
  // Convenience aliases, since we always use the same batch size.
  static constexpr int kBatchSize = LstmCellFunction<>::kBatchSize;
  using CellMatrix = typename LstmCellFunction<float>::MatrixType;

  // Fills all fields with random parameters of the appropriate shapes,
  // allocating from |storage|.  Defined out-of-line below.
  void Init(int hidden_size, int input_dimension, VectorMatrixStorage *storage);

  // Initializes the |cell| (any type with a compatible Initialize() signature)
  // from the float32 parameters held here.
  template <class CellFunction>
  tensorflow::Status InitializeCell(CellFunction *cell) {
    return cell->Initialize(
        hidden_size, Vector<float>(cell_input_state_output_bias),
        input_to_cell_input_state_output,
        last_hidden_to_cell_input_state_output, last_cell_state_to_cell_input,
        cell_state_to_cell_output);
  }

  // As InitializeCell(), but converts the weight matrices to half-floats
  // first, allocating the converted copies from |storage|.
  template <class CellFunction>
  tensorflow::Status InitializeHalfFloatCell(VectorMatrixStorage *storage,
                                             CellFunction *cell) {
    return cell->Initialize(
        hidden_size, Vector<float>(cell_input_state_output_bias),
        storage->ConvertToHalfFloat(input_to_cell_input_state_output),
        storage->ConvertToHalfFloat(last_hidden_to_cell_input_state_output),
        storage->ConvertToHalfFloat(last_cell_state_to_cell_input),
        storage->ConvertToHalfFloat(cell_state_to_cell_output));
  }

  // Cell dimensions and weights; the bias and the first two matrices span the
  // concatenated [input; state; output] blocks (3 * hidden_size rows).
  int hidden_size;
  MutableVector<float> cell_input_state_output_bias;
  CellMatrix input_to_cell_input_state_output;
  CellMatrix last_hidden_to_cell_input_state_output;
  CellMatrix last_cell_state_to_cell_input;
  CellMatrix cell_state_to_cell_output;
};
// Implementation details.
// Allocates backing storage for a [rows x columns] blocked matrix, padding the
// row count up to a multiple of |batch_size|, and fills it with random values.
inline MutableBlockedMatrix<float, BlockedMatrixFormat::kRowBlockedColumnMajor>
VectorMatrixStorage::RandomBlockedMatrix(int rows, int columns,
                                         int batch_size) {
  // Round |rows| up to the next multiple of |batch_size|.
  const int padded_rows = batch_size * ((rows + batch_size - 1) / batch_size);
  const int num_views = padded_rows * columns / batch_size;
  matrices_.emplace_back(num_views, batch_size);
  UniqueMatrix<float> &backing = matrices_.back();

  // Randomize the backing storage directly; its nominal rows/cols need not
  // match the blocked view returned below.
  InitRandomMatrix(*backing);

  // Wrap the storage in a blocked-matrix view with the padded geometry.
  MutableBlockedMatrix<float, BlockedMatrixFormat::kRowBlockedColumnMajor>
      result;
  TF_CHECK_OK(result.Reset(backing.area(), padded_rows, columns));
  return result;
}
// Converts one block vector of floats into TruncatedFloat16 values, reading
// the source in the permuted order produced by FastUnpackPermutation() so the
// fast unpacking路 routine recovers the original layout.  Sizes must match and
// be a multiple of 16.
inline void ConvertRow(Vector<float> input,
                       MutableVector<TruncatedFloat16> output) {
  CHECK_EQ(input.size() % 16, 0);
  CHECK_EQ(input.size(), output.size());
  for (int index = 0; index < input.size(); ++index) {
    const int block_base = (index / 16) * 16;
    const int source_index = block_base + FastUnpackPermutation(index % 16);
    output[index] = TruncatedFloat16::DebugFromFloat(input[source_index]);
  }
}
// Converts |uncompressed| to a TruncatedFloat16 blocked matrix with the same
// geometry, allocating the converted storage in this object.
inline MutableBlockedMatrix<TruncatedFloat16,
                            BlockedMatrixFormat::kRowBlockedColumnMajor>
VectorMatrixStorage::ConvertBlockedMatrix(
    const BlockedMatrix<float, BlockedMatrixFormat::kRowBlockedColumnMajor>
        &uncompressed) {
  // Allocate backing storage with the same blocked geometry as the input.
  converted_matrices_.emplace_back(uncompressed.num_vectors(),
                                   uncompressed.block_size());
  UniqueMatrix<TruncatedFloat16> &backing = converted_matrices_.back();

  MutableBlockedMatrix<TruncatedFloat16,
                       BlockedMatrixFormat::kRowBlockedColumnMajor>
      result;
  TF_CHECK_OK(result.Reset(backing.area(), uncompressed.num_rows(),
                           uncompressed.num_columns()));

  // Convert one block vector at a time.
  for (int vector_index = 0; vector_index < uncompressed.num_vectors();
       ++vector_index) {
    ConvertRow(uncompressed.vector(vector_index), result.vector(vector_index));
  }
  return result;
}
// Fills all parameter fields with random values of the appropriate shapes.
// The bias and the input/recurrent weight matrices cover the concatenated
// input/state/output blocks, hence the 3 * hidden_size leading dimension.
inline void MatrixParameters::Init(int hidden_size, int input_dimension,
                                   VectorMatrixStorage *storage) {
  // The parameter shadows the member, so assign through |this|.
  this->hidden_size = hidden_size;
  cell_input_state_output_bias = storage->RandomVector(3 * hidden_size);
  input_to_cell_input_state_output =
      storage->RandomMatrix<kBatchSize>(3 * hidden_size, input_dimension);
  last_hidden_to_cell_input_state_output =
      storage->RandomMatrix<kBatchSize>(3 * hidden_size, hidden_size);
  last_cell_state_to_cell_input =
      storage->RandomMatrix<kBatchSize>(hidden_size, hidden_size);
  cell_state_to_cell_output =
      storage->RandomMatrix<kBatchSize>(hidden_size, hidden_size);
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_LSTM_CELL_TEST_HELPERS_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/lstm_cell/cell_function.h"
#include "dragnn/runtime/lstm_network_kernel.h"
#include "dragnn/runtime/math/avx_activation_functions.h"
#include "dragnn/runtime/network_unit.h"
#include "dragnn/runtime/network_unit_base.h"
#include "dragnn/runtime/transition_system_traits.h"
#include "dragnn/runtime/variable_store.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// A network unit that evaluates a LSTM.
//
// NOTE: For efficiency, unlike the Python API, lstm_h and lstm_c are not
// exposed; any subsequent components should reference 'layer_0'. This seems to
// be the case for all current DRAGNN models.
class LSTMNetwork : public NetworkUnitBase {
 public:
  // Implements NetworkUnit.
  tensorflow::Status Initialize(const ComponentSpec &component_spec,
                                VariableStore *variable_store,
                                NetworkStateManager *network_state_manager,
                                ExtensionManager *extension_manager) override;
  // Delegates to the kernel, which knows whether a logits layer exists.
  string GetLogitsName() const override { return kernel_.GetLogitsName(); }
  tensorflow::Status Evaluate(size_t step_index, SessionState *session_state,
                              ComputeSession *compute_session) const override;

 private:
  // Kernel that implements the LSTM.  Constructed in non-bulk mode: steps are
  // evaluated one at a time via Evaluate().
  LSTMNetworkKernel kernel_{/*bulk=*/false};
};
// Initializes the LSTM kernel from the |component_spec|, then runs the shared
// base-class initialization with concatenated inputs enabled.  On error,
// returns non-OK.
tensorflow::Status LSTMNetwork::Initialize(
    const ComponentSpec &component_spec, VariableStore *variable_store,
    NetworkStateManager *network_state_manager,
    ExtensionManager *extension_manager) {
  TF_RETURN_IF_ERROR(kernel_.Initialize(component_spec, variable_store,
                                        network_state_manager,
                                        extension_manager));
  return InitializeBase(/*use_concatenated_input=*/true, component_spec,
                        variable_store, network_state_manager,
                        extension_manager);
}
// Gathers the concatenated input for this step via the base class, then
// applies the LSTM kernel.  On error, returns non-OK.
tensorflow::Status LSTMNetwork::Evaluate(
    size_t step_index, SessionState *session_state,
    ComputeSession *compute_session) const {
  Vector<float> concatenated_input;
  TF_RETURN_IF_ERROR(
      EvaluateBase(session_state, compute_session, &concatenated_input));
  return kernel_.Apply(step_index, concatenated_input, session_state);
}
// Registers this implementation with the runtime's network unit registry.
DRAGNN_RUNTIME_REGISTER_NETWORK_UNIT(LSTMNetwork);
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/lstm_network_kernel.h"
#include <vector>
#include "dragnn/runtime/attributes.h"
#include "dragnn/runtime/math/avx_activation_functions.h"
#include "dragnn/runtime/transition_system_traits.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Attributes used by the LSTM network.  Parsed from the network unit's
// parameter map in the ComponentSpec.
struct LSTMNetworkAttributes : public Attributes {
  // Hidden layer sizes; e.g. "96". LSTMNetwork only supports a single hidden
  // layer size.
  Mandatory<size_t> hidden_layer_sizes{"hidden_layer_sizes", this};

  // Whether to omit the "logits" layer.
  Optional<bool> omit_logits{"omit_logits", false, this};

  // Whether to use truncated floating-point weight matrices. This incurs very
  // large errors in the actual matrix multiplication, but the LSTM architecture
  // seems to be mostly resilient (99.99% similar performance on the tagger).
  Optional<bool> use_bfloat16_matrices{"use_bfloat16_matrices", false, this};

  // Training-only attributes, ignored in the runtime.
  Ignored dropout_keep_prob{"dropout_keep_prob", this};
  Ignored dropout_per_sequence{"dropout_per_sequence", this};
  Ignored dropout_all_layers{"dropout_all_layers", this};
  Ignored initialize_bias_zero{"initialize_bias_zero", this};
  Ignored initialize_softmax_zero{"initialize_softmax_zero", this};
  Ignored initialize_hidden_orthogonal{"initialize_hidden_orthogonal", this};
};
// Initializes a LstmCellFunction, using the names that are emitted by
// network_units.py's LSTMNetwork class.  Forward declaration; defined at the
// bottom of this file.
template <typename MatrixElementType>
tensorflow::Status InitializeLstmCellFunction(
    const ComponentSpec &component_spec, VariableStore *variable_store,
    LstmCellFunction<MatrixElementType> *cell_function);
} // namespace
// Returns the name of the logits layer, or an empty string when this component
// produces no logits (deterministic transition system or omit_logits set).
string LSTMNetworkKernel::GetLogitsName() const {
  if (!has_logits_) return "";
  return FeedForwardNetworkLayer::kLogitsName;
}
// Parses the LSTM attributes from the |component_spec|, builds the cell
// function (float32 or bfloat16), the optional softmax, and all network state
// layers.  On error, returns non-OK.
tensorflow::Status LSTMNetworkKernel::Initialize(
    const ComponentSpec &component_spec, VariableStore *variable_store,
    NetworkStateManager *network_state_manager,
    ExtensionManager *extension_manager) {
  // Parse network configuration.
  LSTMNetworkAttributes attributes;
  TF_RETURN_IF_ERROR(
      attributes.Reset(component_spec.network_unit().parameters()));
  // Logits are produced only for non-deterministic transition systems, and
  // can additionally be disabled via the "omit_logits" attribute.
  has_logits_ = !TransitionSystemTraits(component_spec).is_deterministic &&
                !attributes.omit_logits();
  const int hidden_dimension = attributes.hidden_layer_sizes();
  // Initialize the LSTM cell.  A typed local pointer is kept so the typed
  // InitializeLstmCellFunction() template can be called; ownership transfers
  // to |cell_function_| via reset() immediately after allocation.
  if (attributes.use_bfloat16_matrices()) {
    LstmCellFunction<TruncatedFloat16> *bfloat16_cell_function =
        new LstmCellFunction<TruncatedFloat16>();
    cell_function_.reset(bfloat16_cell_function);
    TF_RETURN_IF_ERROR(InitializeLstmCellFunction(
        component_spec, variable_store, bfloat16_cell_function));
  } else {
    LstmCellFunction<float> *float32_cell_function =
        new LstmCellFunction<float>();
    cell_function_.reset(float32_cell_function);
    TF_RETURN_IF_ERROR(InitializeLstmCellFunction(
        component_spec, variable_store, float32_cell_function));
  }
  // Add a softmax to compute logits, if necessary.
  if (has_logits_) {
    TF_RETURN_IF_ERROR(softmax_layer_.InitializeSoftmax(
        component_spec, variable_store, network_state_manager));
  }
  // Internal state layers.  The 3x-wide cell input temporaries are stored as
  // a matrix (one row per step) in bulk mode, and as a single vector in
  // step-wise mode.
  TF_RETURN_IF_ERROR(
      network_state_manager->AddLocal(hidden_dimension, &cell_state_));
  TF_RETURN_IF_ERROR(
      network_state_manager->AddLocal(hidden_dimension, &cell_output_));
  if (bulk_) {
    TF_RETURN_IF_ERROR(network_state_manager->AddLocal(
        3 * hidden_dimension, &cell_input_matrix_));
  } else {
    TF_RETURN_IF_ERROR(network_state_manager->AddLocal(
        3 * hidden_dimension, &cell_input_vector_));
  }
  // Layers exposed to the system.
  TF_RETURN_IF_ERROR(network_state_manager->AddLayer(
      "layer_0", hidden_dimension, &hidden_));
  TF_RETURN_IF_ERROR(
      network_state_manager->AddLayerAlias("last_layer", "layer_0"));
  return tensorflow::Status::OK();
}
// Runs one LSTM step (step-wise mode only) on |input|, writing the new hidden
// state into row |step_index| of the "layer_0" layer and, when enabled, the
// logits for this step.  On error, returns non-OK.
tensorflow::Status LSTMNetworkKernel::Apply(size_t step_index,
                                            Vector<float> input,
                                            SessionState *session_state) const {
  DCHECK(!bulk_);
  const NetworkStates &network_states = session_state->network_states;
  const bool is_initial = step_index == 0;
  MutableVector<float> cell_state = network_states.GetLocal(cell_state_);
  MutableVector<float> cell_output = network_states.GetLocal(cell_output_);
  MutableMatrix<float> hidden_all_steps = network_states.GetLayer(hidden_);
  MutableVector<float> next_hidden = hidden_all_steps.row(step_index);
  // c_{t-1} and h_t vectors. These will be null if not applicable, so incorrect
  // code will immediately segfault.
  Vector<float> last_cell_state;
  Vector<float> last_hidden;
  if (!is_initial) {
    last_cell_state = cell_state;
    last_hidden = hidden_all_steps.row(step_index - 1);
  }
  // Run the cell function.  |cell_state| serves as both c_{t-1} input (via the
  // |last_cell_state| view above) and c_t output; the cell supports this
  // aliasing (see the "consumed" check in the cell fuzz test).
  MutableVector<float> cell_input = network_states.GetLocal(cell_input_vector_);
  TF_RETURN_IF_ERROR(cell_function_->Run(is_initial, input, last_hidden,
                                         last_cell_state, cell_input,
                                         cell_state, cell_output, next_hidden));
  // Compute logits, if present.
  if (has_logits_) {
    softmax_layer_.Apply(Vector<float>(next_hidden), network_states,
                         step_index);
  }
  return tensorflow::Status::OK();
}
// Runs the LSTM over all steps at once (bulk mode only): the input-dependent
// partial sums for every step are computed first in one batched pass, then the
// recurrent computation is applied step by step.  On error, returns non-OK.
tensorflow::Status LSTMNetworkKernel::Apply(Matrix<float> all_inputs,
                                            SessionState *session_state) const {
  DCHECK(bulk_);
  const NetworkStates &network_states = session_state->network_states;
  const size_t num_steps = all_inputs.num_rows();
  MutableMatrix<float> all_cell_input_temps =
      network_states.GetLocal(cell_input_matrix_);
  MutableVector<float> cell_state = network_states.GetLocal(cell_state_);
  MutableVector<float> cell_output = network_states.GetLocal(cell_output_);
  MutableMatrix<float> all_hiddens = network_states.GetLayer(hidden_);
  // SGEMVV input computation.
  TF_RETURN_IF_ERROR(
      cell_function_->RunInputComputations(all_inputs, all_cell_input_temps));
  // Run recurrent parts of the network.
  for (size_t i = 0; i < num_steps; ++i) {
    const bool is_initial = i == 0;
    // Null views on the initial step; the cell skips the recurrent terms.
    Vector<float> last_cell_state;
    Vector<float> last_hidden;
    if (!is_initial) {
      last_cell_state = cell_state;
      last_hidden = all_hiddens.row(i - 1);
    }
    TF_RETURN_IF_ERROR(cell_function_->RunRecurrentComputation(
        is_initial, last_hidden, last_cell_state, all_cell_input_temps.row(i),
        cell_state, cell_output, all_hiddens.row(i)));
  }
  // Compute logits for all steps at once, if present.
  if (has_logits_) {
    softmax_layer_.Apply(Matrix<float>(all_hiddens), network_states);
  }
  return tensorflow::Status::OK();
}
namespace {
// Returns a variable suffix for the |ElementType|.  Used to select between
// the full-precision and bfloat16 versions of a blocked weight variable.
template <typename ElementType>
string MatrixElementTypeSuffix();

// Full-precision float matrices use the un-suffixed variable name.
template <>
string MatrixElementTypeSuffix<float>() {
  return "";
}

// Truncated-precision matrices are stored under a "/bfloat16" suffix.
template <>
string MatrixElementTypeSuffix<TruncatedFloat16>() {
  return "/bfloat16";
}
// Shared logic for initializing SGEMV matrices.
template <int block_size, typename ElementType>
tensorflow::Status InitializeSgemv(
const string &weights_name, VariableStore *variable_store,
SgemvMatrix<block_size, ElementType> *sgemv_matrix) {
BlockedMatrix<ElementType> blocked_transpose;
TF_RETURN_IF_ERROR(variable_store->Lookup(
tensorflow::strings::StrCat(weights_name, "/matrix/blocked", block_size,
MatrixElementTypeSuffix<ElementType>()),
&blocked_transpose));
auto blocked = blocked_transpose.Transpose();
auto result = sgemv_matrix->Initialize(blocked);
if (result.ok()) {
LOG(INFO) << "Matrix of size " << blocked.num_rows() << " x "
<< blocked.num_columns() << " for layer " << weights_name
<< " will be computed with SGEMV<block_size=" << block_size
<< ">";
} else {
// This should (almost?) never happen, because sgemv_matrix->Initialize()
// only fails on bad block sizes, and we request the same block size from
// the variable store.
LOG(ERROR) << "Error formatting SGEMV matrix: " << result.error_message()
<< " - matrix size " << blocked.num_rows() << " x "
<< blocked.num_columns() << " for layer " << weights_name;
}
return result;
}
// Initalizes a LstmCellFunction, using the names that are emitted by
// network_units.py's LSTMNetwork class.
template <typename MatrixElementType>
tensorflow::Status InitializeLstmCellFunction(
const ComponentSpec &component_spec, VariableStore *variable_store,
LstmCellFunction<MatrixElementType> *cell_function) {
LSTMNetworkAttributes attributes;
TF_RETURN_IF_ERROR(
attributes.Reset(component_spec.network_unit().parameters()));
constexpr int kBatchSize = LstmCellFunction<>::kBatchSize;
int hidden_dimension = attributes.hidden_layer_sizes();
auto get_sgemv = [&](const string &name_suffix,
SgemvMatrix<kBatchSize, MatrixElementType> *matrix) {
string name =
tensorflow::strings::StrCat(component_spec.name(), name_suffix);
return InitializeSgemv(name, variable_store, matrix);
};
SgemvMatrix<kBatchSize, MatrixElementType> input_to_cell_input_state_output,
last_hidden_to_cell_input_state_output, last_cell_state_to_cell_input,
cell_state_to_cell_output;
TF_RETURN_IF_ERROR(get_sgemv("/x_to_ico", &input_to_cell_input_state_output));
TF_RETURN_IF_ERROR(
get_sgemv("/h_to_ico", &last_hidden_to_cell_input_state_output));
TF_RETURN_IF_ERROR(get_sgemv("/c2i", &last_cell_state_to_cell_input));
TF_RETURN_IF_ERROR(get_sgemv("/c2o", &cell_state_to_cell_output));
string ico_bias_name =
tensorflow::strings::StrCat(component_spec.name(), "/", "ico_bias");
Vector<float> ico_bias;
TF_RETURN_IF_ERROR(variable_store->Lookup(ico_bias_name, &ico_bias));
return cell_function->Initialize(
hidden_dimension, ico_bias, input_to_cell_input_state_output,
last_hidden_to_cell_input_state_output, last_cell_state_to_cell_input,
cell_state_to_cell_output);
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_LSTM_NETWORK_KERNEL_H_
#define DRAGNN_RUNTIME_LSTM_NETWORK_KERNEL_H_
#include <memory>
#include <string>
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/feed_forward_network_layer.h"
#include "dragnn/runtime/lstm_cell/cell_function.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Kernel that evaluates an LSTM network.
class LSTMNetworkKernel {
public:
// Creates a kernel for bulk or non-bulk computations.
explicit LSTMNetworkKernel(bool bulk) : bulk_(bulk) {}
// Initializes this to the configuration in the |component_spec|. Retrieves
// pre-trained variables from the |variable_store|, which must outlive this.
// Adds layers and local operands to the |network_state_manager|, which must
// be positioned at the current component. Requests SessionState extensions
// from the |extension_manager|. On error, returns non-OK.
tensorflow::Status Initialize(const ComponentSpec &component_spec,
VariableStore *variable_store,
NetworkStateManager *network_state_manager,
ExtensionManager *extension_manager);
// Returns the name of the logits layer, or an empty string if none.
string GetLogitsName() const;
// Applies this to the |input| activations for the |step_index|'th step using
// the |session_state|. Requires that this was created in non-bulk mode. On
// error, returns non-OK.
tensorflow::Status Apply(size_t step_index, Vector<float> input,
SessionState *session_state) const;
// As above, but for matrices. Requires that this was created in bulk mode.
tensorflow::Status Apply(Matrix<float> all_inputs,
SessionState *session_state) const;
private:
// Whether this is a bulk or non-bulk kernel.
const bool bulk_;
// Whether this has a logits layer.
bool has_logits_ = false;
// Main cell function, which is an instance of either LstmCellFunction<float>
// or LstmCellFunctionBase<TruncatedFloat16>.
std::unique_ptr<LstmCellFunctionBase> cell_function_;
// LSTM cell state and output.
LocalVectorHandle<float> cell_state_;
LocalVectorHandle<float> cell_output_;
// LSTM cell input. Only used if |bulk_| is false.
LocalVectorHandle<float> cell_input_vector_;
// LSTM cell input. Only used if |bulk_| is true.
LocalMatrixHandle<float> cell_input_matrix_;
// Hidden outputs.
LayerHandle<float> hidden_;
// The softmax is an affine transformation of the hidden state.
FeedForwardNetworkLayer softmax_layer_;
};
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_LSTM_NETWORK_KERNEL_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/lstm_network_kernel.h"
#include <stddef.h>
#include <string>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/runtime.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/lstm_cell/cell_function.h"
#include "dragnn/runtime/test/helpers.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
constexpr size_t kNumSteps = 20;
constexpr size_t kNumActions = 10;
// Testing rig, parameterized on a bool that indicates whether the kernel is
// created in bulk mode.
class LSTMNetworkKernelTest : public NetworkTestBase,
public ::testing::WithParamInterface<bool> {
protected:
// Returns true if the |kernel_| was created in bulk mode.
bool bulk() const { return GetParam(); }
// Adds a blocked weight matrix with the |name| with the given dimensions and
// |fill_value|. If |is_flexible_matrix| is true, the variable is set up for
// use by the FlexibleMatrixKernel.
void AddWeights(const string &name, size_t input_dim, size_t output_dim,
float fill_value, bool is_flexible_matrix = false) {
constexpr int kBatchSize = LstmCellFunction<>::kBatchSize;
size_t output_padded =
kBatchSize * ((output_dim + kBatchSize - 1) / kBatchSize);
size_t num_views = (output_padded / kBatchSize) * input_dim;
string var_name = tensorflow::strings::StrCat(
kTestComponentName, "/", name,
is_flexible_matrix ? FlexibleMatrixKernel::kSuffix
: "/matrix/blocked48");
const std::vector<float> block(kBatchSize, fill_value);
const std::vector<std::vector<float>> blocks(num_views, block);
variable_store_.AddOrDie(
var_name, blocks, VariableSpec::FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX);
variable_store_.SetBlockedDimensionOverride(
var_name, {input_dim, output_padded, kBatchSize});
}
// Adds a bias vector with the |name_suffix| with the given dimensions and
// |fill_value|.
void AddBiases(const string &name, size_t dimension, float fill_value) {
const string biases_name =
tensorflow::strings::StrCat(kTestComponentName, "/", name);
AddVectorVariable(biases_name, dimension, fill_value);
}
// Initializes the |kernel_| from the |component_spec_text|. On error,
// returns non-OK.
tensorflow::Status Initialize(const string &component_spec_text) {
ComponentSpec component_spec;
CHECK(TextFormat::ParseFromString(component_spec_text, &component_spec));
component_spec.set_name(kTestComponentName);
// Since LSTMNetworkKernel uses the concatenated input, it is insensitive
// to the particular fixed or linked embedding inputs. For simplicity, the
// tests use a trivial network structure and a single fixed embedding.
AddComponent(kTestComponentName);
TF_RETURN_IF_ERROR(kernel_.Initialize(component_spec, &variable_store_,
&network_state_manager_,
&extension_manager_));
network_states_.Reset(&network_state_manager_);
StartComponent(kNumSteps);
session_state_.extensions.Reset(&extension_manager_);
return tensorflow::Status::OK();
}
// Applies the |kernel_| to the |inputs|.
void Apply(const std::vector<std::vector<float>> &inputs) {
UniqueMatrix<float> input_matrix(inputs);
if (bulk()) {
TF_ASSERT_OK(
kernel_.Apply(Matrix<float>(*input_matrix), &session_state_));
} else {
for (size_t step_index = 0; step_index < kNumSteps; ++step_index) {
TF_ASSERT_OK(kernel_.Apply(step_index,
Vector<float>(input_matrix->row(step_index)),
&session_state_));
}
}
}
// Returns the logits matrix.
Matrix<float> GetLogits() const {
return Matrix<float>(GetLayer(kTestComponentName, "logits"));
}
LSTMNetworkKernel kernel_{bulk()};
};
INSTANTIATE_TEST_CASE_P(BulkMode, LSTMNetworkKernelTest, ::testing::Bool());
// Tests that the LSTMNetworkKernel does not produce logits when omit_logits is
// true, even if there are actions.
TEST_P(LSTMNetworkKernelTest, NoLogitsOrSoftmaxWhenOmitLogitsTrue) {
constexpr size_t input_dim = 32;
constexpr int kHiddenDim = LstmCellFunction<>::kBatchSize;
const string kSpec = R"(fixed_feature {
vocabulary_size: 50
embedding_dim: 32
size: 1
}
network_unit {
parameters {
key: 'hidden_layer_sizes'
value: '48'
}
parameters {
key: 'omit_logits'
value: 'true'
}
}
num_actions: 10)";
constexpr float kEmbedding = 1.25;
constexpr float kWeight = 1.5;
// No "softmax" weights or biases.
AddWeights("x_to_ico", input_dim, 3 * kHiddenDim, kWeight);
AddWeights("h_to_ico", kHiddenDim, 3 * kHiddenDim, kWeight);
AddWeights("c2i", kHiddenDim, kHiddenDim, kWeight);
AddWeights("c2o", kHiddenDim, kHiddenDim, kWeight);
AddBiases("ico_bias", 3 * kHiddenDim, kWeight);
TF_ASSERT_OK(Initialize(kSpec));
// No specified logits layer.
EXPECT_TRUE(kernel_.GetLogitsName().empty());
const std::vector<float> row(input_dim, kEmbedding);
const std::vector<std::vector<float>> rows(kNumSteps, row);
Apply(rows);
// No "logits" layer.
size_t unused_dimension = 0;
LayerHandle<float> unused_handle;
EXPECT_THAT(
network_state_manager_.LookupLayer(kTestComponentName, "logits",
&unused_dimension, &unused_handle),
test::IsErrorWithSubstr(
"Unknown layer 'logits' in component 'test_component'"));
}
TEST_P(LSTMNetworkKernelTest, NormalOperationSmallHidden) {
constexpr size_t input_dim = 32;
constexpr int kHiddenDim = 8;
const string kSpec = R"(fixed_feature {
vocabulary_size: 50
embedding_dim: 32
size: 1
}
network_unit {
parameters {
key: 'hidden_layer_sizes'
value: '8'
}
}
num_actions: 10)";
constexpr float kEmbedding = 1.25;
constexpr float kWeight = 1.5;
// Same as above, with "softmax" weights and biases.
AddWeights("x_to_ico", input_dim, 3 * kHiddenDim, kWeight);
AddWeights("h_to_ico", kHiddenDim, 3 * kHiddenDim, kWeight);
AddWeights("c2i", kHiddenDim, kHiddenDim, kWeight);
AddWeights("c2o", kHiddenDim, kHiddenDim, kWeight);
AddWeights("weights_softmax", kHiddenDim, kNumActions, kWeight,
/*is_flexible_matrix=*/true);
AddBiases("ico_bias", 3 * kHiddenDim, kWeight);
AddBiases("bias_softmax", kNumActions, kWeight);
TF_EXPECT_OK(Initialize(kSpec));
// Logits should exist.
EXPECT_EQ(kernel_.GetLogitsName(), "logits");
const std::vector<float> row(input_dim, kEmbedding);
const std::vector<std::vector<float>> rows(kNumSteps, row);
Apply(rows);
// Logits dimension matches "num_actions" above. We don't test the values very
// precisely here, and feel free to update if the cell function changes. Most
// value tests should be in lstm_cell/cell_function_test.cc.
Matrix<float> logits = GetLogits();
EXPECT_EQ(logits.num_rows(), kNumSteps);
EXPECT_EQ(logits.num_columns(), kNumActions);
EXPECT_NEAR(logits.row(0)[0], 10.6391, 0.1);
for (int row = 0; row < logits.num_rows(); ++row) {
for (const float value : logits.row(row)) {
EXPECT_EQ(value, logits.row(0)[0])
<< "With uniform weights, all logits should be equal.";
}
}
}
TEST_P(LSTMNetworkKernelTest, ErrorWithTooSmallHidden) {
constexpr size_t input_dim = 32;
constexpr int kHiddenDim = 4;
const string kSpec = R"(fixed_feature {
vocabulary_size: 50
embedding_dim: 32
size: 1
}
network_unit {
parameters {
key: 'hidden_layer_sizes'
value: '4'
}
}
num_actions: 0)";
constexpr float kEmbedding = 1.25;
constexpr float kWeight = 1.5;
AddFixedEmbeddingMatrix(0, 50, input_dim, kEmbedding);
// Same as above, with "softmax" weights and biases.
AddWeights("x_to_ico", input_dim, 3 * kHiddenDim, kWeight);
AddWeights("h_to_ico", kHiddenDim, 3 * kHiddenDim, kWeight);
AddWeights("c2i", kHiddenDim, kHiddenDim, kWeight);
AddWeights("c2o", kHiddenDim, kHiddenDim, kWeight);
AddBiases("ico_bias", 3 * kHiddenDim, kWeight);
EXPECT_THAT(
Initialize(kSpec),
test::IsErrorWithSubstr(
"Expected hidden size (4) to be a multiple of the AVX width (8)"));
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment