Commit edea2b67 authored by Terry Koo's avatar Terry Koo
Browse files

Remove runtime because reasons.

parent a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <string>
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/bulk_network_unit.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/lstm_network_kernel.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// A network unit that evaluates an LSTM.
//
// This is a thin adapter: all real work is delegated to the LSTMNetworkKernel
// member, which is constructed in bulk mode (see |kernel_| below).
class BulkLSTMNetwork : public BulkNetworkUnit {
 public:
  // Implements BulkNetworkUnit.
  tensorflow::Status Initialize(const ComponentSpec &component_spec,
                                VariableStore *variable_store,
                                NetworkStateManager *network_state_manager,
                                ExtensionManager *extension_manager) override {
    // Forwards all configuration to the kernel unchanged.
    return kernel_.Initialize(component_spec, variable_store,
                              network_state_manager, extension_manager);
  }

  // Accepts any input dimension unconditionally; presumably the kernel
  // enforces dimension compatibility when applied — TODO confirm.
  tensorflow::Status ValidateInputDimension(size_t dimension) const override {
    return tensorflow::Status::OK();
  }

  // Returns the logits layer name chosen by the kernel.
  string GetLogitsName() const override { return kernel_.GetLogitsName(); }

  // Runs the LSTM over all rows of |inputs| in one call.
  tensorflow::Status Evaluate(Matrix<float> inputs,
                              SessionState *session_state) const override {
    return kernel_.Apply(inputs, session_state);
  }

 private:
  // Kernel that implements the LSTM.
  LSTMNetworkKernel kernel_{/*bulk=*/true};
};

DRAGNN_RUNTIME_REGISTER_BULK_NETWORK_UNIT(BulkLSTMNetwork);
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/runtime.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/bulk_network_unit.h"
#include "dragnn/runtime/flexible_matrix_kernel.h"
#include "dragnn/runtime/lstm_cell/cell_function.h"
#include "dragnn/runtime/test/helpers.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
constexpr size_t kNumSteps = 20;
constexpr size_t kNumActions = 10;
constexpr size_t kInputDim = 32;
constexpr size_t kHiddenDim = 8;
class BulkLSTMNetworkTest : public NetworkTestBase {
 protected:
  // Adds a blocked weight matrix with the |name| with the given dimensions and
  // |fill_value|. If |is_flexible_matrix| is true, the variable is set up for
  // use by the FlexibleMatrixKernel.
  void AddWeights(const string &name, size_t input_dim, size_t output_dim,
                  float fill_value, bool is_flexible_matrix = false) {
    constexpr int kBatchSize = LstmCellFunction<>::kBatchSize;

    // Round the output dimension up to a whole number of batch-sized blocks.
    size_t output_padded =
        kBatchSize * ((output_dim + kBatchSize - 1) / kBatchSize);
    size_t num_views = (output_padded / kBatchSize) * input_dim;

    // Derive the blocked-matrix suffix from kBatchSize instead of hard-coding
    // "blocked48", so the variable name stays consistent with the dimension
    // override below if the cell batch size ever changes.
    const string var_name = tensorflow::strings::StrCat(
        kTestComponentName, "/", name,
        is_flexible_matrix
            ? string(FlexibleMatrixKernel::kSuffix)
            : tensorflow::strings::StrCat("/matrix/blocked", kBatchSize));
    const std::vector<float> block(kBatchSize, fill_value);
    const std::vector<std::vector<float>> blocks(num_views, block);
    variable_store_.AddOrDie(
        var_name, blocks, VariableSpec::FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX);
    variable_store_.SetBlockedDimensionOverride(
        var_name, {input_dim, output_padded, kBatchSize});
  }

  // Adds a bias vector with the |name| with the given |dimension| and
  // |fill_value|.
  void AddBiases(const string &name, size_t dimension, float fill_value) {
    const string biases_name =
        tensorflow::strings::StrCat(kTestComponentName, "/", name);
    AddVectorVariable(biases_name, dimension, fill_value);
  }

  // Initializes the |bulk_network_unit_| from the |component_spec_text|. On
  // error, returns non-OK.
  tensorflow::Status Initialize(const string &component_spec_text) {
    ComponentSpec component_spec;
    CHECK(TextFormat::ParseFromString(component_spec_text, &component_spec));
    component_spec.set_name(kTestComponentName);

    AddComponent(kTestComponentName);
    TF_RETURN_IF_ERROR(
        BulkNetworkUnit::CreateOrError("BulkLSTMNetwork", &bulk_network_unit_));
    TF_RETURN_IF_ERROR(bulk_network_unit_->Initialize(
        component_spec, &variable_store_, &network_state_manager_,
        &extension_manager_));
    TF_RETURN_IF_ERROR(bulk_network_unit_->ValidateInputDimension(kInputDim));

    network_states_.Reset(&network_state_manager_);
    StartComponent(kNumSteps);
    session_state_.extensions.Reset(&extension_manager_);
    return tensorflow::Status::OK();
  }

  // Evaluates the |bulk_network_unit_| on the |inputs|.
  void Apply(const std::vector<std::vector<float>> &inputs) {
    UniqueMatrix<float> input_matrix(inputs);
    TF_ASSERT_OK(bulk_network_unit_->Evaluate(Matrix<float>(*input_matrix),
                                              &session_state_));
  }

  // Returns the logits matrix.
  Matrix<float> GetLogits() const {
    return Matrix<float>(GetLayer(kTestComponentName, "logits"));
  }

  std::unique_ptr<BulkNetworkUnit> bulk_network_unit_;
};
// Tests that the bulk LSTM network produces logits of the expected shape, and
// that uniform weights and inputs yield uniform logits.
TEST_F(BulkLSTMNetworkTest, NormalOperation) {
  const string kSpec = R"(fixed_feature {
                            vocabulary_size: 50
                            embedding_dim: 32
                            size: 1
                          }
                          network_unit {
                            parameters {
                              key: 'hidden_layer_sizes'
                              value: '8'
                            }
                          }
                          num_actions: 10)";
  constexpr float kEmbedding = 1.25;
  constexpr float kWeight = 1.5;

  // LSTM cell weights and biases, plus "softmax" weights and biases.
  AddWeights("x_to_ico", kInputDim, 3 * kHiddenDim, kWeight);
  AddWeights("h_to_ico", kHiddenDim, 3 * kHiddenDim, kWeight);
  AddWeights("c2i", kHiddenDim, kHiddenDim, kWeight);
  AddWeights("c2o", kHiddenDim, kHiddenDim, kWeight);
  AddWeights("weights_softmax", kHiddenDim, kNumActions, kWeight,
             /*is_flexible_matrix=*/true);
  AddBiases("ico_bias", 3 * kHiddenDim, kWeight);
  AddBiases("bias_softmax", kNumActions, kWeight);

  TF_EXPECT_OK(Initialize(kSpec));

  // Logits should exist.
  EXPECT_EQ(bulk_network_unit_->GetLogitsName(), "logits");

  const std::vector<float> input_row(kInputDim, kEmbedding);
  const std::vector<std::vector<float>> input_rows(kNumSteps, input_row);
  Apply(input_rows);

  // Logits dimension matches "num_actions" above. We don't test the values
  // very precisely here, and feel free to update if the cell function changes.
  // Most value tests should be in lstm_cell/cell_function_test.cc.
  Matrix<float> logits = GetLogits();
  EXPECT_EQ(logits.num_rows(), kNumSteps);
  EXPECT_EQ(logits.num_columns(), kNumActions);
  EXPECT_NEAR(logits.row(0)[0], 10.6391, 0.1);

  // Use size_t to match num_rows() and avoid the signed/unsigned comparison
  // (the original "int row" also shadowed a local vector of the same name).
  for (size_t row_index = 0; row_index < logits.num_rows(); ++row_index) {
    for (const float value : logits.row(row_index)) {
      EXPECT_EQ(value, logits.row(0)[0])
          << "With uniform weights, all logits should be equal.";
    }
  }
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/bulk_network_unit.h"
#include <vector>
#include "dragnn/runtime/network_unit.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
string BulkNetworkUnit::GetClassName(
    const ComponentSpec &component_spec) {
  // The name in the |component_spec| targets the Python registry (e.g.,
  // "some.module.FooNetwork") and cannot be used with the C++ registry
  // directly.  NetworkUnit::GetClassName() reduces it to the C++ registered
  // name ("FooNetwork"); prepending "Bulk" selects the bulk variant.
  const string non_bulk_name = NetworkUnit::GetClassName(component_spec);
  return tensorflow::strings::StrCat("Bulk", non_bulk_name);
}
} // namespace runtime
} // namespace dragnn
REGISTER_SYNTAXNET_CLASS_REGISTRY("DRAGNN Runtime Bulk Network Unit",
dragnn::runtime::BulkNetworkUnit);
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_BULK_NETWORK_UNIT_H_
#define DRAGNN_RUNTIME_BULK_NETWORK_UNIT_H_
#include <stddef.h>
#include <functional>
#include <memory>
#include <string>
#include "dragnn/core/compute_session.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "syntaxnet/registry.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Interface for network units for bulk inference.
//
// TODO(googleuser): The current approach assumes that fixed and
// linked embeddings are computed and concatenated outside the network unit,
// which is simple and composable. However, it could be more efficient to,
// e.g., pass the fixed and linked embeddings individually or compute them
// internally. That would elide the concatenation and could increase cache
// coherency.
class BulkNetworkUnit : public RegisterableClass<BulkNetworkUnit> {
 public:
  // Not copyable: instances are owned and created through the registry.
  BulkNetworkUnit(const BulkNetworkUnit &that) = delete;
  BulkNetworkUnit &operator=(const BulkNetworkUnit &that) = delete;
  virtual ~BulkNetworkUnit() = default;

  // Returns the bulk network unit class name specified in the |component_spec|.
  static string GetClassName(const ComponentSpec &component_spec);

  // Initializes this to the configuration in the |component_spec|. Retrieves
  // pre-trained variables from the |variable_store|, which must outlive this.
  // Adds layers and local operands to the |network_state_manager|, which must
  // be positioned at the current component. Requests SessionState extensions
  // from the |extension_manager|. On error, returns non-OK.
  virtual tensorflow::Status Initialize(
      const ComponentSpec &component_spec, VariableStore *variable_store,
      NetworkStateManager *network_state_manager,
      ExtensionManager *extension_manager) = 0;

  // Returns OK iff this is compatible with the input |dimension|.
  virtual tensorflow::Status ValidateInputDimension(size_t dimension) const = 0;

  // Returns the name of the layer that contains classification logits, or an
  // empty string if this does not produce logits. Requires that Initialize()
  // was called.
  virtual string GetLogitsName() const = 0;

  // Evaluates this network on the bulk |inputs|, using intermediate operands
  // and output layers in the |session_state|. On error, returns non-OK.
  virtual tensorflow::Status Evaluate(Matrix<float> inputs,
                                      SessionState *session_state) const = 0;

 protected:
  // Only subclasses (created via the registry) may construct this.
  BulkNetworkUnit() = default;

 private:
  // Helps prevent use of the Create() method; use CreateOrError() instead.
  using RegisterableClass<BulkNetworkUnit>::Create;
};
} // namespace runtime
} // namespace dragnn
DECLARE_SYNTAXNET_CLASS_REGISTRY("DRAGNN Runtime Bulk Network Unit",
dragnn::runtime::BulkNetworkUnit);
} // namespace syntaxnet
// Registers a subclass using its class name as a string.
#define DRAGNN_RUNTIME_REGISTER_BULK_NETWORK_UNIT(subclass) \
REGISTER_SYNTAXNET_CLASS_COMPONENT( \
::syntaxnet::dragnn::runtime::BulkNetworkUnit, #subclass, subclass)
#endif // DRAGNN_RUNTIME_BULK_NETWORK_UNIT_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/bulk_network_unit.h"
#include <memory>
#include <string>
#include "dragnn/core/compute_session.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Expects that the two pointers have the same address. Used below to verify
// that a dynamic_cast succeeded (i.e., the object has the expected dynamic
// type) without changing the pointer value.
void ExpectSameAddress(const void *pointer1, const void *pointer2) {
  EXPECT_EQ(pointer1, pointer2);
}
// A trivial implementation for tests. All methods are no-ops returning OK;
// only GetLogitsName() has observable behavior.
class BulkFooNetwork : public BulkNetworkUnit {
 public:
  // Implements BulkNetworkUnit.
  tensorflow::Status Initialize(const ComponentSpec &component_spec,
                                VariableStore *variable_store,
                                NetworkStateManager *network_state_manager,
                                ExtensionManager *extension_manager) override {
    return tensorflow::Status::OK();
  }
  tensorflow::Status ValidateInputDimension(size_t dimension) const override {
    return tensorflow::Status::OK();
  }
  // Fixed logits name, checked by the CreateOrError test below.
  string GetLogitsName() const override { return "foo_logits"; }
  tensorflow::Status Evaluate(Matrix<float> inputs,
                              SessionState *session_state) const override {
    return tensorflow::Status::OK();
  }
};
DRAGNN_RUNTIME_REGISTER_BULK_NETWORK_UNIT(BulkFooNetwork);
// Tests that BulkNetworkUnit::GetClassName() strips any Python module path
// and prepends "Bulk".
TEST(BulkNetworkUnitTest, GetClassName) {
  const string registered_names[] = {"FooNetwork", "module.FooNetwork",
                                     "some.long.path.to.module.FooNetwork"};
  for (const string &registered_name : registered_names) {
    ComponentSpec component_spec;
    component_spec.mutable_network_unit()->set_registered_name(registered_name);
    EXPECT_EQ(BulkNetworkUnit::GetClassName(component_spec), "BulkFooNetwork");
  }
}
// Tests that BulkNetworkUnits can be created via the registry.
TEST(BulkNetworkUnitTest, CreateOrError) {
  std::unique_ptr<BulkNetworkUnit> foo;
  TF_ASSERT_OK(BulkNetworkUnit::CreateOrError("BulkFooNetwork", &foo));
  ASSERT_TRUE(foo != nullptr);

  // The created instance should really be a BulkFooNetwork.
  BulkNetworkUnit *const base = foo.get();
  ExpectSameAddress(dynamic_cast<BulkFooNetwork *>(base), base);
  EXPECT_EQ(foo->GetLogitsName(), "foo_logits");
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <string>
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component_transformation.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Transformer that strips dropout settings from every fixed feature channel.
class ClearDropoutComponentTransformer : public ComponentTransformer {
 public:
  // Implements ComponentTransformer. Dropout is cleared unconditionally,
  // regardless of the |component_type|.
  tensorflow::Status Transform(const string &component_type,
                               ComponentSpec *component_spec) override {
    auto *channels = component_spec->mutable_fixed_feature();
    for (int i = 0; i < channels->size(); ++i) {
      FixedFeatureChannel *const channel = channels->Mutable(i);
      channel->clear_dropout_id();
      channel->clear_dropout_keep_probability();
    }
    return tensorflow::Status::OK();
  }
};
DRAGNN_RUNTIME_REGISTER_COMPONENT_TRANSFORMER(ClearDropoutComponentTransformer);
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component_transformation.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Tests that a spec with no dropout features passes through unchanged.
TEST(ClearDropoutComponentTransformerTest, DoesNotModifyIfNoDropout) {
  ComponentSpec component_spec;
  component_spec.mutable_component_builder()->set_registered_name("foo");
  component_spec.add_fixed_feature()->set_name("words");

  // Snapshot the spec before transformation; it should remain untouched.
  const ComponentSpec expected_spec(component_spec);
  TF_ASSERT_OK(ComponentTransformer::ApplyAll(&component_spec));
  EXPECT_THAT(component_spec, test::EqualsProto(expected_spec));
}
// Tests that dropout settings are removed from a spec that has them.
TEST(ClearDropoutComponentTransformerTest, ClearsDropout) {
  // Build the dropout-free version first; this is the expected result.
  ComponentSpec expected_spec;
  expected_spec.mutable_component_builder()->set_registered_name("foo");
  expected_spec.add_fixed_feature()->set_name("words");

  // The input spec is the same, plus dropout settings on the channel.
  ComponentSpec component_spec = expected_spec;
  FixedFeatureChannel *channel = component_spec.mutable_fixed_feature(0);
  channel->set_dropout_id(100);
  channel->add_dropout_keep_probability(1.0);
  channel->add_dropout_keep_probability(0.5);
  channel->add_dropout_keep_probability(0.1);

  TF_ASSERT_OK(ComponentTransformer::ApplyAll(&component_spec));
  EXPECT_THAT(component_spec, test::EqualsProto(expected_spec));
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/component.h"
#include <memory>
#include <utility>
#include <vector>
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
string GetNormalizedComponentBuilderName(const ComponentSpec &component_spec) {
  // The Python registration API uses (relative) module paths such as
  // "some.module.FooComponent"; only the final path segment — the subclass
  // name — is meaningful here.  Note: the CHECK below must stay, since an
  // empty builder name is a fatal error (covered by a death test).
  const std::vector<string> segments = tensorflow::str_util::Split(
      component_spec.component_builder().registered_name(), ".");
  CHECK_GT(segments.size(), 0) << "No builder name for component spec: "
                               << component_spec.ShortDebugString();
  tensorflow::StringPiece subclass = segments.back();

  // Python "FooComponentBuilder" classes build TF graphs; the corresponding
  // runtime classes are plain "FooComponent", so drop a "Builder" suffix.
  tensorflow::str_util::ConsumeSuffix(&subclass, "Builder");
  return subclass.ToString();
}
tensorflow::Status Component::Select(const ComponentSpec &spec,
                                     string *result) {
  const string normalized_builder_name =
      GetNormalizedComponentBuilderName(spec);

  // Iterate through all registered components, constructing them and querying
  // their Supports() methods.
  std::unique_ptr<Component> current_best;
  string current_best_name;
  for (const Registry::Registrar *component = registry()->components;
       component != nullptr; component = component->next()) {
    // component->object() is a function pointer to the subclass' constructor.
    std::unique_ptr<Component> next(component->object()());
    string next_name(component->name());
    if (!next->Supports(spec, normalized_builder_name)) {
      continue;
    }

    // First supported component.
    if (current_best == nullptr) {
      current_best = std::move(next);
      current_best_name = next_name;
      continue;
    }

    // The two must agree on which takes precedence.
    if (next->PreferredTo(*current_best)) {
      if (current_best->PreferredTo(*next)) {
        // Both claim precedence over the other: ambiguous, so fail.
        return tensorflow::errors::FailedPrecondition(
            "Classes '", current_best_name, "' and '", next_name,
            "' both think they should be preferred to each-other. Please "
            "add logic to their PreferredTo() methods to avoid this.");
      }
      // |next| wins cleanly; it becomes the new best candidate.
      current_best = std::move(next);
      current_best_name = next_name;
    } else if (!current_best->PreferredTo(*next)) {
      // Neither claims precedence: also ambiguous, so fail.
      return tensorflow::errors::FailedPrecondition(
          "Classes '", current_best_name, "' and '", next_name,
          "' both think they should be dis-preferred to each-other. Please "
          "add logic to their PreferredTo() methods to avoid this.");
    }
  }

  if (current_best == nullptr) {
    // Error-free search, but no registered component supports the spec.
    return tensorflow::errors::NotFound(
        "Could not find a best spec for component '", spec.name(),
        "' with normalized builder name '", normalized_builder_name, "'");
  } else {
    *result = std::move(current_best_name);
    return tensorflow::Status::OK();
  }
}
} // namespace runtime
} // namespace dragnn
REGISTER_SYNTAXNET_CLASS_REGISTRY("DRAGNN Runtime Component",
dragnn::runtime::Component);
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_COMPONENT_H_
#define DRAGNN_RUNTIME_COMPONENT_H_
#include <string>
#include "dragnn/core/compute_session.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "syntaxnet/registry.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Helper method, currently only used by myelination.cc.
string GetNormalizedComponentBuilderName(const ComponentSpec &component_spec);
// Interface for components. Subclasses register themselves via
// DRAGNN_RUNTIME_REGISTER_COMPONENT and are selected per-spec by Select().
class Component : public RegisterableClass<Component> {
 public:
  // Not copyable: instances are owned and created through the registry.
  Component(const Component &that) = delete;
  Component &operator=(const Component &that) = delete;
  virtual ~Component() = default;

  // Initializes this to the configuration in the |component_spec|. Retrieves
  // pre-trained variables from the |variable_store|, which must outlive this.
  // Adds layers and local operands to the |network_state_manager|, which must
  // be positioned at the current component. Requests SessionState extensions
  // from the |extension_manager|. On error, returns non-OK.
  virtual tensorflow::Status Initialize(
      const ComponentSpec &component_spec, VariableStore *variable_store,
      NetworkStateManager *network_state_manager,
      ExtensionManager *extension_manager) = 0;

  // Evaluates this on the |session_state| and |compute_session|, which must
  // both be positioned at the current component. If |component_trace| is
  // non-null, overwrites it with extracted traces. On error, returns non-OK.
  virtual tensorflow::Status Evaluate(
      SessionState *session_state, ComputeSession *compute_session,
      ComponentTrace *component_trace) const = 0;

  // Returns the best component for a spec, searching through all registered
  // subclasses. This allows specialized implementations to be used.
  //
  // Sets |result| on success, otherwise returns an error message if a single
  // best matching component could not be found. Returned statuses include:
  //  * OK: If a single best matching component was found.
  //  * FAILED_PRECONDITION: If an error occurred during the search.
  //  * NOT_FOUND: If the search was error-free, but no matches were found.
  static tensorflow::Status Select(const ComponentSpec &spec, string *result);

 protected:
  // Only subclasses (created via the registry) may construct this.
  Component() = default;

  // Whether this component supports a given spec. |spec| is the full component
  // spec and |normalized_builder_name| is the component builder name, with
  // Python modules and the suffix "Builder" stripped.
  virtual bool Supports(const ComponentSpec &spec,
                        const string &normalized_builder_name) const = 0;

  // Whether to prefer this component to another. (Both components must say that
  // they support the spec.)
  //
  // Components must agree on whether they are more or less specialized than
  // another component. Feel free to expose methods for subclasses to identify
  // themselves; currently, we only have unoptimized implementations (which say
  // they are never preferred) and optimized implementations (which say they are
  // always preferred).
  virtual bool PreferredTo(const Component &other) const = 0;

 private:
  // Helps prevent use of the Create() method; use CreateOrError() instead.
  using RegisterableClass<Component>::Create;
};
} // namespace runtime
} // namespace dragnn
DECLARE_SYNTAXNET_CLASS_REGISTRY("DRAGNN Runtime Component",
dragnn::runtime::Component);
} // namespace syntaxnet
// Registers a subclass using its class name as a string.
#define DRAGNN_RUNTIME_REGISTER_COMPONENT(subclass) \
REGISTER_SYNTAXNET_CLASS_COMPONENT(::syntaxnet::dragnn::runtime::Component, \
#subclass, subclass)
#endif // DRAGNN_RUNTIME_COMPONENT_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/component.h"
#include <memory>
#include <string>
#include "dragnn/core/compute_session.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Expects that the two pointers have the same address. Intended for verifying
// that a debug-mode dynamic_cast yields the same object (see
// SpecializedFooComponent below).
void ExpectSameAddress(const void *pointer1, const void *pointer2) {
  EXPECT_EQ(pointer1, pointer2);
}
// A trivial implementation for tests. Supports specs whose normalized builder
// name is "FooComponent" and never claims preference.
class FooComponent : public Component {
 public:
  // Implements Component. Initialize() and Evaluate() are no-ops.
  tensorflow::Status Initialize(const ComponentSpec &component_spec,
                                VariableStore *variable_store,
                                NetworkStateManager *network_state_manager,
                                ExtensionManager *extension_manager) override {
    return tensorflow::Status::OK();
  }
  tensorflow::Status Evaluate(SessionState *session_state,
                              ComputeSession *compute_session,
                              ComponentTrace *component_trace) const override {
    return tensorflow::Status::OK();
  }
  bool Supports(const ComponentSpec &spec,
                const string &normalized_builder_name) const override {
    return normalized_builder_name == "FooComponent";
  }
  // Never preferred, like the unoptimized runtime implementations.
  bool PreferredTo(const Component &other) const override { return false; }
};
DRAGNN_RUNTIME_REGISTER_COMPONENT(FooComponent);
// Class that always says it's preferred. Registering two behaviorally
// identical subclasses triggers the "both preferred" ambiguity error in
// Component::Select() (see the ErrorWithBothPreferred test).
class ImTheBest1 : public FooComponent {
 public:
  bool Supports(const ComponentSpec &spec,
                const string &normalized_builder_name) const override {
    return normalized_builder_name == "ImTheBest";
  }
  bool PreferredTo(const Component &other) const override { return true; }
};
// Identical behavior; registered separately so two candidates match.
class ImTheBest2 : public ImTheBest1 {};
DRAGNN_RUNTIME_REGISTER_COMPONENT(ImTheBest1);
DRAGNN_RUNTIME_REGISTER_COMPONENT(ImTheBest2);
// Class that always says it's dispreferred. Registering two behaviorally
// identical subclasses triggers the "both dis-preferred" ambiguity error in
// Component::Select().
class ImTheWorst1 : public FooComponent {
 public:
  bool Supports(const ComponentSpec &spec,
                const string &normalized_builder_name) const override {
    return normalized_builder_name == "ImTheWorst";
  }
  bool PreferredTo(const Component &other) const override { return false; }
};
// Identical behavior; registered separately so two candidates match.
class ImTheWorst2 : public ImTheWorst1 {};
DRAGNN_RUNTIME_REGISTER_COMPONENT(ImTheWorst1);
DRAGNN_RUNTIME_REGISTER_COMPONENT(ImTheWorst2);
// Specialized foo implementation. We use debug-mode down-casting to check that
// the correct sub-class was instantiated.
class SpecializedFooComponent : public Component {
 public:
  // Implements Component. Initialize() and Evaluate() are no-ops.
  tensorflow::Status Initialize(const ComponentSpec &component_spec,
                                VariableStore *variable_store,
                                NetworkStateManager *network_state_manager,
                                ExtensionManager *extension_manager) override {
    return tensorflow::Status::OK();
  }
  tensorflow::Status Evaluate(SessionState *session_state,
                              ComputeSession *compute_session,
                              ComponentTrace *component_trace) const override {
    return tensorflow::Status::OK();
  }
  // Narrower than FooComponent: also requires exactly one action.
  bool Supports(const ComponentSpec &spec,
                const string &normalized_builder_name) const override {
    return normalized_builder_name == "FooComponent" && spec.num_actions() == 1;
  }
  // Always preferred, like the optimized runtime implementations.
  bool PreferredTo(const Component &other) const override { return true; }
};
DRAGNN_RUNTIME_REGISTER_COMPONENT(SpecializedFooComponent);
// Tests that resolving the builder name of a spec with no builder configured
// dies with a descriptive message.
TEST(ComponentTest, NameResolutionError) {
  ComponentSpec component_spec;  // deliberately left without a builder name
  EXPECT_DEATH(GetNormalizedComponentBuilderName(component_spec),
               "No builder name for component spec");
}
// Tests that Python-esque module specifiers for builders are normalized
// appropriately.
TEST(ComponentTest, VariantsOfComponentBuilderNameResolve) {
  // Every variant below should normalize to plain "FooComponent": leading
  // module paths are stripped and a "Builder" suffix is dropped.
  for (const char *variant : {"FooComponent",
                              "FooComponentBuilder",
                              "module.FooComponent",
                              "module.FooComponentBuilder",
                              "some.long.path.to.module.FooComponent",
                              "some.long.path.to.module.FooComponentBuilder"}) {
    ComponentSpec component_spec;
    component_spec.mutable_component_builder()->set_registered_name(variant);

    string selected;
    TF_ASSERT_OK(Component::Select(component_spec, &selected));
    EXPECT_EQ(selected, "FooComponent");
  }
}
// Tests that Select() fails when two matching classes each claim preference
// over the other (ImTheBest1 and ImTheBest2 above).
TEST(ComponentTest, ErrorWithBothPreferred) {
  ComponentSpec component_spec;
  component_spec.mutable_component_builder()->set_registered_name("ImTheBest");

  string selected;
  const tensorflow::Status status = Component::Select(component_spec, &selected);
  EXPECT_THAT(status,
              test::IsErrorWithCodeAndSubstr(
                  tensorflow::error::FAILED_PRECONDITION,
                  "Classes 'ImTheBest2' and 'ImTheBest1' "
                  "both think they should be preferred to "
                  "each-other. Please add logic to their "
                  "PreferredTo() methods to avoid this."));
}
// Tests that Select() fails when two matching classes each defer to the other
// (ImTheWorst1 and ImTheWorst2 above).
TEST(ComponentTest, ErrorWithNeitherPreferred) {
  ComponentSpec component_spec;
  component_spec.mutable_component_builder()->set_registered_name("ImTheWorst");

  string selected;
  const tensorflow::Status status = Component::Select(component_spec, &selected);
  EXPECT_THAT(status,
              test::IsErrorWithCodeAndSubstr(
                  tensorflow::error::FAILED_PRECONDITION,
                  "Classes 'ImTheWorst2' and 'ImTheWorst1' both think they "
                  "should be dis-preferred to each-other. Please add logic to "
                  "their PreferredTo() methods to avoid this."));
}
// Tests that a generic FooComponent spec (num_actions != 1, so the specialized
// subclass does not match) selects the default implementation.
TEST(ComponentTest, DefaultComponent) {
  ComponentSpec component_spec;
  component_spec.set_num_actions(45);
  component_spec.mutable_component_builder()->set_registered_name(
      "FooComponent");

  string selected;
  TF_EXPECT_OK(Component::Select(component_spec, &selected));
  EXPECT_EQ(selected, "FooComponent");
}
// Tests that a FooComponent spec with exactly one action selects the
// specialized implementation, which claims preference over the generic one.
TEST(ComponentTest, SpecializedComponent) {
  ComponentSpec component_spec;
  component_spec.set_num_actions(1);
  component_spec.mutable_component_builder()->set_registered_name(
      "FooComponent");

  string selected;
  TF_EXPECT_OK(Component::Select(component_spec, &selected));
  EXPECT_EQ(selected, "SpecializedFooComponent");
}
// Tests that Select() returns NOT_FOUND when there is no matching component.
TEST(ComponentTest, NoMatchingComponentNotFound) {
  ComponentSpec component_spec;
  component_spec.mutable_component_builder()->set_registered_name("unknown");

  string selected;
  const tensorflow::Status status = Component::Select(component_spec, &selected);
  EXPECT_THAT(status, test::IsErrorWithCodeAndSubstr(
                          tensorflow::error::NOT_FOUND,
                          "Could not find a best spec for component"));
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/component_transformation.h"
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "dragnn/runtime/component.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Reads a text-format MasterSpec from |input_master_spec_path|, applies all
// registered ComponentTransformers to each component, and writes the result
// to |output_master_spec_path|.  Returns non-OK on any read/transform/write
// failure.
tensorflow::Status TransformComponents(const string &input_master_spec_path,
                                       const string &output_master_spec_path) {
  tensorflow::Env *const env = tensorflow::Env::Default();

  // Parse the input spec from disk.
  MasterSpec master_spec;
  TF_RETURN_IF_ERROR(
      tensorflow::ReadTextProto(env, input_master_spec_path, &master_spec));

  // Transform each component in place.
  for (ComponentSpec &component_spec : *master_spec.mutable_component()) {
    TF_RETURN_IF_ERROR(ComponentTransformer::ApplyAll(&component_spec));
  }

  // Serialize the transformed spec.
  return tensorflow::WriteTextProto(env, output_master_spec_path, master_spec);
}
// Applies every registered ComponentTransformer to |component_spec|, looping
// until a full pass makes no change.  Works on a copy so the input is left
// untouched on error, and caps the number of passes to guard against
// transformers that never converge.
tensorflow::Status ComponentTransformer::ApplyAll(
    ComponentSpec *component_spec) {
  // Cap on transformation rounds, so cyclic transformers cannot loop forever.
  static constexpr int kMaxNumIterations = 1000;

  // Collect the registered transformer names; std::set keeps them sorted so
  // each pass applies transformers in a deterministic order.
  std::set<string> names;
  const Registry::Registrar *registrar = registry()->components;
  while (registrar != nullptr) {
    names.insert(registrar->name());
    registrar = registrar->next();
  }

  // Instantiate one transformer per registered name.
  std::vector<std::unique_ptr<ComponentTransformer>> transformers;
  transformers.reserve(names.size());
  for (const string &name : names) transformers.emplace_back(Create(name));

  // Work on a local copy so nothing is modified if a transformer fails.
  ComponentSpec local_spec = *component_spec;
  for (int iteration = 0; iteration < kMaxNumIterations; ++iteration) {
    const ComponentSpec original_spec = local_spec;
    for (const auto &transformer : transformers) {
      // Re-derive the component type on every application, since a previous
      // transformer in this pass may have changed the builder name.
      const string component_type =
          GetNormalizedComponentBuilderName(local_spec);
      TF_RETURN_IF_ERROR(transformer->Transform(component_type, &local_spec));
    }

    if (tensorflow::protobuf::util::MessageDifferencer::Equals(local_spec,
                                                               original_spec)) {
      // A full pass made no change: converged.  Commit the result.
      *component_spec = local_spec;
      return tensorflow::Status::OK();
    }
  }

  return tensorflow::errors::Internal("Failed to converge within ",
                                      kMaxNumIterations,
                                      " ComponentTransformer iterations");
}
} // namespace runtime
} // namespace dragnn
REGISTER_SYNTAXNET_CLASS_REGISTRY("DRAGNN Runtime Component Transformer",
dragnn::runtime::ComponentTransformer);
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Utils for transforming ComponentSpecs, typically (but not necessarily) in
// ways that are intended to improve speed. For example, a transformer might
// detect a favorable component configuration and replace a generic Component
// implementation with a faster version.
#ifndef DRAGNN_RUNTIME_COMPONENT_TRANSFORMATION_H_
#define DRAGNN_RUNTIME_COMPONENT_TRANSFORMATION_H_
#include <string>
#include "dragnn/protos/spec.pb.h"
#include "syntaxnet/base.h"
#include "syntaxnet/registry.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Loads a MasterSpec from the |input_master_spec_path|, applies all registered
// ComponentTransformers to it (see ComponentTransformer::ApplyAll() below), and
// writes it to the |output_master_spec_path|. On error, returns non-OK.
//
// Side note: This function has a file-path-based API so it can be easily
// wrapped in a stand-alone binary.
tensorflow::Status TransformComponents(const string &input_master_spec_path,
const string &output_master_spec_path);
// Interface for modules that can transform a ComponentSpec, which allows
// transformations to be plugged in on a decentralized basis.
class ComponentTransformer : public RegisterableClass<ComponentTransformer> {
 public:
  // Not copyable or assignable; instances are created via the registry.
  ComponentTransformer(const ComponentTransformer &that) = delete;
  ComponentTransformer &operator=(const ComponentTransformer &that) = delete;
  virtual ~ComponentTransformer() = default;

  // Repeatedly loops through all registered transformers and applies them to
  // the |component_spec| until no more changes occur. For determinism, each
  // loop applies the transformers in ascending order of registered name. On
  // error, returns non-OK and modifies nothing.
  static tensorflow::Status ApplyAll(ComponentSpec *component_spec);

 protected:
  ComponentTransformer() = default;

 private:
  // Helps prevent use of the Create() method.
  using RegisterableClass<ComponentTransformer>::Create;

  // Modifies the |component_spec|, which is currently configured to use the
  // |component_type|, if compatible. On error, returns non-OK and modifies
  // nothing. Note that it is not an error if the |component_spec| is simply
  // not compatible with the desired transformation.
  virtual tensorflow::Status Transform(const string &component_type,
                                       ComponentSpec *component_spec) = 0;
};
} // namespace runtime
} // namespace dragnn
DECLARE_SYNTAXNET_CLASS_REGISTRY("DRAGNN Runtime Component Transformer",
dragnn::runtime::ComponentTransformer);
} // namespace syntaxnet
// Registers the ComponentTransformer |subclass| under its own class name, so
// ApplyAll() will instantiate and apply it.
#define DRAGNN_RUNTIME_REGISTER_COMPONENT_TRANSFORMER(subclass) \
  REGISTER_SYNTAXNET_CLASS_COMPONENT( \
      ::syntaxnet::dragnn::runtime::ComponentTransformer, #subclass, subclass)
#endif // DRAGNN_RUNTIME_COMPONENT_TRANSFORMATION_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/component_transformation.h"
#include <string>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Transformer that fails if the component type is "fail".
class MaybeFail : public ComponentTransformer {
 public:
  // Implements ComponentTransformer.  Returns INVALID_ARGUMENT for component
  // type "fail" and leaves every spec unmodified otherwise.
  tensorflow::Status Transform(const string &component_type,
                               ComponentSpec *) override {
    if (component_type == "fail") {
      return tensorflow::errors::InvalidArgument("Boom!");
    }
    return tensorflow::Status::OK();
  }
};

DRAGNN_RUNTIME_REGISTER_COMPONENT_TRANSFORMER(MaybeFail);
// Base class for transformers that rename a component: any component whose
// name equals a configured "from" value has its name replaced by "to".
class ChangeNameBase : public ComponentTransformer {
 public:
  // Creates a transformer that changes the component name from |from| to |to|.
  explicit ChangeNameBase(const string &from, const string &to)
      : from_(from), to_(to) {}

  // Implements ComponentTransformer.  Renames matching components; all other
  // specs pass through untouched.
  tensorflow::Status Transform(const string &,
                               ComponentSpec *component_spec) override {
    if (component_spec->name() != from_) return tensorflow::Status::OK();
    component_spec->set_name(to_);
    return tensorflow::Status::OK();
  }

 private:
  const string from_;  // component name to look for
  const string to_;    // replacement component name
};
// These will convert chain1 => chain2 => chain3.  Each class performs one
// renaming step; since ApplyAll() iterates until convergence, both steps end
// up applied to a component named "chain1".
class Chain1To2 : public ChangeNameBase {
 public:
  Chain1To2() : ChangeNameBase("chain1", "chain2") {}
};

class Chain2To3 : public ChangeNameBase {
 public:
  Chain2To3() : ChangeNameBase("chain2", "chain3") {}
};
// Adds "." to the name of the component, if it begins with "cycle".  Such a
// component changes on every pass, so ApplyAll() can never converge; used to
// exercise the iteration limit.
class Cycle : public ComponentTransformer {
 public:
  // Implements ComponentTransformer.
  tensorflow::Status Transform(const string &,
                               ComponentSpec *component_spec) override {
    // substr(0, 5) is safe even on shorter names; they simply won't match.
    if (component_spec->name().substr(0, 5) == "cycle") {
      component_spec->mutable_name()->append(".");
    }
    return tensorflow::Status::OK();
  }
};

// Intentionally registered out of order to exercise sorting on registered name.
DRAGNN_RUNTIME_REGISTER_COMPONENT_TRANSFORMER(Chain1To2);
DRAGNN_RUNTIME_REGISTER_COMPONENT_TRANSFORMER(Chain2To3);
DRAGNN_RUNTIME_REGISTER_COMPONENT_TRANSFORMER(Cycle);
// Arbitrary bogus path.
constexpr char kInvalidPath[] = "path/to/some/invalid/file";

// Creates and returns a test-scoped temporary directory that is distinct on
// every call (suffixed with a monotonically increasing counter).
string GetUniqueTemporaryDir() {
  static int counter = 0;
  const string subdir = tensorflow::strings::StrCat("tmp_", counter++);
  const string output_dir =
      tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), subdir);
  TF_CHECK_OK(tensorflow::Env::Default()->RecursivelyCreateDir(output_dir));
  return output_dir;
}
// Parses and returns the MasterSpec encoded in the text-format proto |text|.
// Check-fails if parsing fails, which is acceptable in tests.
MasterSpec ParseSpec(const string &text) {
  MasterSpec parsed;
  CHECK(TextFormat::ParseFromString(text, &parsed));
  return parsed;
}
// Tests that TransformComponents() fails if the input master spec path is
// invalid.
TEST(TransformComponentsTest, InvalidInputMasterSpecPath) {
  const string temp_dir = GetUniqueTemporaryDir();
  const string output_path = tensorflow::io::JoinPath(temp_dir, "output");
  // kInvalidPath does not exist, so reading the input spec must fail.
  EXPECT_FALSE(TransformComponents(kInvalidPath, output_path).ok());
}
// Tests that TransformComponents() fails if the output master spec path is
// invalid.
TEST(TransformComponentsTest, InvalidOutputMasterSpecPath) {
  const string temp_dir = GetUniqueTemporaryDir();
  const string input_path = tensorflow::io::JoinPath(temp_dir, "input");

  // Write a valid (empty) input spec so only the output write can fail.
  const MasterSpec empty_spec;
  TF_ASSERT_OK(tensorflow::WriteTextProto(tensorflow::Env::Default(),
                                          input_path, empty_spec));
  EXPECT_FALSE(TransformComponents(input_path, kInvalidPath).ok());
}
// Tests that TransformComponents() fails if one of the ComponentTransformers
// fails.  The second component uses builder type "fail", which triggers the
// MaybeFail transformer above.
TEST(TransformComponentsTest, FailingComponentTransformer) {
  const string temp_dir = GetUniqueTemporaryDir();
  const string input_path = tensorflow::io::JoinPath(temp_dir, "input");
  const string output_path = tensorflow::io::JoinPath(temp_dir, "output");
  const MasterSpec input_spec = ParseSpec(R"(
    component {
      component_builder { registered_name:'foo' }
    }
    component {
      component_builder { registered_name:'fail' }
    }
  )");
  TF_ASSERT_OK(tensorflow::WriteTextProto(tensorflow::Env::Default(),
                                          input_path, input_spec));
  EXPECT_THAT(TransformComponents(input_path, output_path),
              test::IsErrorWithSubstr("Boom!"));
}
// Tests that TransformComponents() properly applies all transformations: the
// "chain1" component is renamed to "chain3" via the Chain1To2/Chain2To3
// transformers, while the non-matching component is left untouched.
TEST(TransformComponentsTest, Success) {
  const string temp_dir = GetUniqueTemporaryDir();
  const string input_path = tensorflow::io::JoinPath(temp_dir, "input");
  const string output_path = tensorflow::io::JoinPath(temp_dir, "output");
  const MasterSpec input_spec = ParseSpec(R"(
    component {
      name:'chain1'
      component_builder { registered_name:'foo' }
    }
    component {
      name:'irrelevant'
      component_builder { registered_name:'bar' }
    }
  )");
  TF_ASSERT_OK(tensorflow::WriteTextProto(tensorflow::Env::Default(),
                                          input_path, input_spec));

  TF_ASSERT_OK(TransformComponents(input_path, output_path));

  // Read back the transformed spec and compare against the expectation.
  MasterSpec actual_spec;
  TF_ASSERT_OK(tensorflow::ReadTextProto(tensorflow::Env::Default(),
                                         output_path, &actual_spec));
  const MasterSpec expected_spec = ParseSpec(R"(
    component {
      name:'chain3'
      component_builder { registered_name:'foo' }
    }
    component {
      name:'irrelevant'
      component_builder { registered_name:'bar' }
    }
  )");
  EXPECT_THAT(actual_spec, test::EqualsProto(expected_spec));
}
// Tests that ComponentTransformer::ApplyAll() makes the expected modifications,
// including chained modifications (chain1 => chain2 => chain3).
TEST(ComponentTransformerTest, ApplyAllSuccess) {
  ComponentSpec component_spec;
  component_spec.set_name("chain1");
  component_spec.mutable_component_builder()->set_registered_name("foo");

  // Expected output matches the input except for the fully-chained name.
  ComponentSpec expected_spec = component_spec;
  expected_spec.set_name("chain3");

  TF_ASSERT_OK(ComponentTransformer::ApplyAll(&component_spec));
  EXPECT_THAT(component_spec, test::EqualsProto(expected_spec));
}
// Tests that ComponentTransformer::ApplyAll() limits the number of iterations.
// The "cycle" name triggers the Cycle transformer above, which modifies the
// spec on every pass and therefore never converges.
TEST(ComponentTransformerTest, ApplyAllLimitIterations) {
  ComponentSpec component_spec;
  component_spec.mutable_component_builder()->set_registered_name("foo");
  component_spec.set_name("cycle");
  EXPECT_THAT(ComponentTransformer::ApplyAll(&component_spec),
              test::IsErrorWithSubstr("Failed to converge"));
}
// Tests that ComponentTransformer::ApplyAll() fails if one of the
// ComponentTransformers fails.  Builder type "fail" triggers MaybeFail above.
TEST(ComponentTransformerTest, ApplyAllFailure) {
  ComponentSpec component_spec;
  component_spec.mutable_component_builder()->set_registered_name("fail");
  EXPECT_THAT(ComponentTransformer::ApplyAll(&component_spec),
              test::IsErrorWithSubstr("Boom!"));
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/conversion.h"
#include <memory>
#include <utility>
#include "dragnn/protos/runtime.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/array_variable_store_builder.h"
#include "dragnn/runtime/master.h"
#include "dragnn/runtime/trained_model_variable_store.h"
#include "dragnn/runtime/variable_store.h"
#include "dragnn/runtime/variable_store_wrappers.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/env.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Converts the variables used by the MasterSpec at |master_spec_path| from
// the TF SavedModel at |saved_model_dir| into ArrayVariableStore format,
// writing a text-format spec to |variables_spec_path| and the raw data to
// |variables_data_path|.  Returns non-OK on any failure.
tensorflow::Status ConvertVariables(const string &saved_model_dir,
                                    const string &master_spec_path,
                                    const string &variables_spec_path,
                                    const string &variables_data_path) {
  // Open the trained model.  Keep a typed pointer for the Reset() call; the
  // unique_ptr owns the object from the moment of construction.
  auto *tf_store = new TrainedModelVariableStore();
  std::unique_ptr<VariableStore> store(tf_store);
  TF_RETURN_IF_ERROR(tf_store->Reset(saved_model_dir));

  // Wrap the TF store to enable averaging and capturing.  The wrapping order
  // matters:
  //
  // The averaging wrapper currently needs to allow fall-back versions, since
  // derived parameters (used for the LSTM network) read averaged versions via
  // their TensorFlow-internal logic.
  //
  // The capturing wrapper must be the outermost, so variable names, formats,
  // and content are captured exactly as the components would receive them.
  store.reset(new TryAveragedVariableStoreWrapper(std::move(store), true));
  store.reset(new FlexibleMatrixVariableStoreWrapper(std::move(store)));
  auto *capture = new CaptureUsedVariableStoreWrapper(std::move(store));
  store.reset(capture);

  // Initialize a master using the wrapped store, which populates |capture|
  // with every variable the components actually use.
  MasterSpec master_spec;
  TF_RETURN_IF_ERROR(tensorflow::ReadTextProto(tensorflow::Env::Default(),
                                               master_spec_path, &master_spec));
  Master master;
  TF_RETURN_IF_ERROR(master.Initialize(master_spec, std::move(store)));

  // Convert the used variables into an ArrayVariableStore.
  ArrayVariableStoreSpec variables_spec;
  string variables_data;
  TF_RETURN_IF_ERROR(ArrayVariableStoreBuilder::Build(
      capture->variables(), &variables_spec, &variables_data));

  // Write the converted variables.
  TF_RETURN_IF_ERROR(tensorflow::WriteTextProto(
      tensorflow::Env::Default(), variables_spec_path, variables_spec));
  TF_RETURN_IF_ERROR(tensorflow::WriteStringToFile(
      tensorflow::Env::Default(), variables_data_path, variables_data));
  return tensorflow::Status::OK();
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Utils for converting pre-trained models into a production-ready format.
#ifndef DRAGNN_RUNTIME_CONVERSION_H_
#define DRAGNN_RUNTIME_CONVERSION_H_
#include <string>
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Converts selected variables from a pre-trained TF model into the format used
// by the ArrayVariableStore. Only converts the variables required to run the
// components in a given MasterSpec.
//
// Inputs:
// saved_model_dir: TF SavedModel directory.
// master_spec_path: Text-format MasterSpec proto.
//
// Outputs:
// variables_spec_path: Text-format ArrayVariableStoreSpec proto.
// variables_data_path: Byte array representing an ArrayVariableStore.
//
// This function will instantiate and initialize a Master using the MasterSpec
// at the |master_path|, so any registered components used by that MasterSpec
// must be linked into the binary.
//
// Side note: This function has a file-path-based API so it can be easily
// wrapped in a stand-alone binary.
tensorflow::Status ConvertVariables(const string &saved_model_dir,
const string &master_spec_path,
const string &variables_spec_path,
const string &variables_data_path);
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_CONVERSION_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/conversion.h"
#include <string>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/runtime.pb.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Fixture that bundles the input, output, golden, and bogus file paths used by
// the ConvertVariables() tests below.
class ConvertVariablesTest : public ::testing::Test {
 protected:
  // The input files.
  const string kSavedModelDir = tensorflow::io::JoinPath(
      test::GetTestDataPrefix(), "dragnn/runtime/testdata/rnn_tagger");
  const string kMasterSpecPath = tensorflow::io::JoinPath(
      test::GetTestDataPrefix(),
      "dragnn/runtime/testdata/rnn_tagger/assets.extra/master_spec");

  // Writable output files.
  const string kVariablesSpecPath =
      tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), "variables_spec");
  const string kVariablesDataPath =
      tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), "variables_data");

  // Bogus file for tests; used to trigger failures on each argument in turn.
  const string kInvalidPath = "path/to/some/invalid/file";

  // Expected output files (goldens for the regression test).
  const string kExpectedVariablesSpecPath = tensorflow::io::JoinPath(
      test::GetTestDataPrefix(),
      "dragnn/runtime/testdata/conversion_output_variables_spec");
  const string kExpectedVariablesDataPath = tensorflow::io::JoinPath(
      test::GetTestDataPrefix(),
      "dragnn/runtime/testdata/conversion_output_variables_data");

  // Local relative paths to the output files; targets for regenerating the
  // goldens (see the disabled branch in RegressionTest).
  const string kLocalVariablesSpecPath =
      "dragnn/runtime/testdata/"
      "conversion_output_variables_spec";
  const string kLocalVariablesDataPath =
      "dragnn/runtime/testdata/"
      "conversion_output_variables_data";
};
// Tests that the conversion fails if the saved model is invalid.
TEST_F(ConvertVariablesTest, InvalidSavedModel) {
  EXPECT_FALSE(ConvertVariables(kInvalidPath, kMasterSpecPath,
                                kVariablesSpecPath, kVariablesDataPath)
                   .ok());
}
// Tests that the conversion fails if the master spec is invalid.
TEST_F(ConvertVariablesTest, InvalidMasterSpec) {
  EXPECT_FALSE(ConvertVariables(kSavedModelDir, kInvalidPath,
                                kVariablesSpecPath, kVariablesDataPath)
                   .ok());
}
// Tests that the conversion fails if the variables spec path is unwritable.
TEST_F(ConvertVariablesTest, InvalidVariablesSpec) {
  EXPECT_FALSE(ConvertVariables(kSavedModelDir, kMasterSpecPath, kInvalidPath,
                                kVariablesDataPath)
                   .ok());
}
// Tests that the conversion fails if the variables data path is unwritable.
TEST_F(ConvertVariablesTest, InvalidVariablesData) {
  EXPECT_FALSE(ConvertVariables(kSavedModelDir, kMasterSpecPath,
                                kVariablesSpecPath, kInvalidPath)
                   .ok());
}
// Tests that the conversion succeeds on the pre-trained inputs and reproduces
// expected outputs.
TEST_F(ConvertVariablesTest, RegressionTest) {
  TF_EXPECT_OK(ConvertVariables(kSavedModelDir, kMasterSpecPath,
                                kVariablesSpecPath, kVariablesDataPath));

  // Read back the freshly-converted outputs.
  ArrayVariableStoreSpec actual_variables_spec;
  string actual_variables_data;
  TF_ASSERT_OK(tensorflow::ReadTextProto(
      tensorflow::Env::Default(), kVariablesSpecPath, &actual_variables_spec));
  TF_ASSERT_OK(tensorflow::ReadFileToString(
      tensorflow::Env::Default(), kVariablesDataPath, &actual_variables_data));

  // Regeneration switch: flip this condition to true (locally, never submit)
  // to overwrite the golden files with the current conversion outputs instead
  // of comparing against them.
  if (false) {
    TF_ASSERT_OK(tensorflow::WriteTextProto(tensorflow::Env::Default(),
                                            kLocalVariablesSpecPath,
                                            actual_variables_spec));
    TF_ASSERT_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(),
                                               kLocalVariablesDataPath,
                                               actual_variables_data));
  } else {
    // Normal mode: compare actual outputs against the checked-in goldens.
    ArrayVariableStoreSpec expected_variables_spec;
    string expected_variables_data;
    TF_ASSERT_OK(tensorflow::ReadTextProto(tensorflow::Env::Default(),
                                           kExpectedVariablesSpecPath,
                                           &expected_variables_spec));
    TF_ASSERT_OK(tensorflow::ReadFileToString(tensorflow::Env::Default(),
                                              kExpectedVariablesDataPath,
                                              &expected_variables_data));
    EXPECT_THAT(actual_variables_spec,
                test::EqualsProto(expected_variables_spec));
    EXPECT_EQ(actual_variables_data, expected_variables_data);
  }
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Tool for converting trained models for use in the runtime.
#include <set>
#include <string>
#include <vector>
#include "dragnn/runtime/component_transformation.h"
#include "dragnn/runtime/conversion.h"
#include "dragnn/runtime/myelin/myelination.h"
#include "dragnn/runtime/xla/xla_compilation.h"
#include "syntaxnet/base.h"
#include "sling/base/flags.h" // TF does not support flags, but SLING does
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
DEFINE_string(saved_model_dir, "", "Path to TF SavedModel directory.");
DEFINE_string(master_spec_file, "", "Path to text-format MasterSpec proto.");
DEFINE_string(
myelin_components, "",
"Comma-delimited list of components to compile using Myelin, if any");
DEFINE_string(
xla_components, "",
"Comma-delimited list of components to compile using XLA, if any.");
DEFINE_string(xla_model_name, "", "Name to apply to XLA-based components.");
DEFINE_string(
output_dir, "",
"Path to an output directory. This will be filled with the following "
"files and subdirectories. MasterSpec: Converted text-format MasterSpec "
"proto. ArrayVariableStoreSpec: Converted text-format variable spec. "
"ArrayVariableStoreData: Converted variable data. myelin/*: Compiled "
"Myelin components, if any. xla/*: Compiled XLA components, if any.");
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Splits the |list| on commas and returns the set of elements.
std::set<string> Split(const string &list) {
const std::vector<string> elements = tensorflow::str_util::Split(list, ",");
return std::set<string>(elements.begin(), elements.end());
}
// Creates an empty directory at the |path|. If the directory exists, it is
// recursively deleted first.
void CreateEmptyDir(const string &path) {
// Ensure that the directory exists; otherwise DeleteRecursively() may fail.
TF_QCHECK_OK(tensorflow::Env::Default()->RecursivelyCreateDir(path));
int64 unused_undeleted_files, unused_undeleted_dirs;
TF_QCHECK_OK(tensorflow::Env::Default()->DeleteRecursively(
path, &unused_undeleted_files, &unused_undeleted_dirs));
TF_QCHECK_OK(tensorflow::Env::Default()->RecursivelyCreateDir(path));
}
// Performs Myelin compilation on the MasterSpec at |master_spec_path|, if
// --myelin_components is non-empty.  Returns the path to the compiled
// MasterSpec, or the original path if no compilation was requested.
string CompileMyelin(const string &master_spec_path) {
  const std::set<string> to_compile = Split(FLAGS_myelin_components);
  if (to_compile.empty()) return master_spec_path;  // nothing requested

  LOG(INFO) << "Compiling Myelin in MasterSpec " << master_spec_path;
  const string myelin_dir = tensorflow::io::JoinPath(FLAGS_output_dir, "myelin");
  CreateEmptyDir(myelin_dir);
  TF_QCHECK_OK(MyelinateCells(FLAGS_saved_model_dir, master_spec_path,
                              to_compile, myelin_dir));
  return tensorflow::io::JoinPath(myelin_dir, "master-spec");
}
// Performs XLA compilation on the MasterSpec at |master_spec_path|, if
// --xla_components is non-empty.  Returns the path to the compiled
// MasterSpec, or the original path if no compilation was requested.
string CompileXla(const string &master_spec_path) {
  const std::set<string> to_compile = Split(FLAGS_xla_components);
  if (to_compile.empty()) return master_spec_path;  // nothing requested

  LOG(INFO) << "Compiling XLA in MasterSpec " << master_spec_path;
  const string xla_dir = tensorflow::io::JoinPath(FLAGS_output_dir, "xla");
  CreateEmptyDir(xla_dir);
  TF_QCHECK_OK(XlaCompileCells(FLAGS_saved_model_dir, master_spec_path,
                               to_compile, FLAGS_xla_model_name, xla_dir));
  return tensorflow::io::JoinPath(xla_dir, "master-spec");
}
// Transforms the MasterSpec at |master_spec_path|, writing the result into the
// output directory. Returns the path to the transformed MasterSpec.
string Transform(const string &master_spec_path) {
  LOG(INFO) << "Transforming MasterSpec " << master_spec_path;
  const string transformed_spec_path =
      tensorflow::io::JoinPath(FLAGS_output_dir, "MasterSpec");
  TF_QCHECK_OK(TransformComponents(master_spec_path, transformed_spec_path));
  return transformed_spec_path;
}
// Performs final variable conversion on the MasterSpec at |master_spec_path|.
void Convert(const string &master_spec_path) {
LOG(INFO) << "Converting MasterSpec " << master_spec_path;
const string variables_data_path =
tensorflow::io::JoinPath(FLAGS_output_dir, "ArrayVariableStoreData");
const string variables_spec_path =
tensorflow::io::JoinPath(FLAGS_output_dir, "ArrayVariableStoreSpec");
TF_QCHECK_OK(ConvertVariables(FLAGS_saved_model_dir, master_spec_path,
variables_spec_path, variables_data_path));
}
// Implements main().
void Main() {
CreateEmptyDir(FLAGS_output_dir);
string master_spec_path = FLAGS_master_spec_file;
master_spec_path = CompileMyelin(master_spec_path);
master_spec_path = CompileXla(master_spec_path);
master_spec_path = Transform(master_spec_path);
Convert(master_spec_path);
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
int main(int argc, char **argv) {
  // Parse and strip recognized flags from |argv| before running the converter.
  sling::Flag::ParseCommandLineFlags(&argc, argv, true);
  syntaxnet::dragnn::runtime::Main();
  return 0;
}
#!/bin/bash
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Test for converter tool. To update the testdata, run the test with a single
# command-line argument specifying the path to the testdata directory.
set -e  # exit immediately on command failure
set -u  # treat expansion of unset variables as an error

# Infer the location of the data dependencies.
if [[ -d "${BASH_SOURCE[0]}.runfiles" ]]; then
  # Use the ".runfiles" directory if available (this typically happens when
  # running manually). SyntaxNet does not specify a workspace name, so the
  # runfiles are placed in ".runfiles/__main__". If SyntaxNet is configured
  # with a workspace name, then change "__main__" to that name. See
  # https://github.com/bazelbuild/bazel/wiki/Updating-the-runfiles-tree-structure
  RUNFILES="${BASH_SOURCE[0]}.runfiles/__main__"
else
  # Otherwise, use this recipe borrowed from
  # https://github.com/bazelbuild/bazel/blob/7d265e07e7a1e37f04d53342710e4f21d9ee8083/examples/shell/test.sh#L21
  # shellcheck disable=SC2091
  RUNFILES="${RUNFILES:-"$("$(cd "$(dirname "${BASH_SOURCE[0]}")")"; pwd)"}"
fi
readonly RUNFILES

# Converter binary and its test inputs.
readonly RUNTIME="${RUNFILES}/dragnn/runtime"
readonly CONVERTER="${RUNTIME}/converter"
readonly SAVED_MODEL="${RUNTIME}/testdata/rnn_tagger"
readonly MASTER_SPEC="${SAVED_MODEL}/assets.extra/master_spec"

# Golden outputs and the scratch directory for actual converter output.
readonly EXPECTED="${RUNTIME}/testdata/converter_output"
readonly OUTPUT="${TEST_TMPDIR:-/tmp/$$}/converted"
# Fails the test: prints a message to stderr and exits non-zero.
function fail() {
  echo "$@" >&2
  exit 1
}
# Asserts that the file named by $1 exists.
function assert_file_exists() {
  [[ -f "$1" ]] || fail "missing file: $1"
}
# Asserts that the files named by $1 and $2 both exist and have identical
# content, printing a unified diff on mismatch.
function assert_file_content_eq() {
  assert_file_exists "$1"
  assert_file_exists "$2"
  diff -u "$1" "$2" || fail "files differ: $1 $2"
}
# Run the converter on the test saved model, writing into a fresh scratch dir.
rm -rf "${OUTPUT}"
"${CONVERTER}" \
  --saved_model_dir="${SAVED_MODEL}" \
  --master_spec_file="${MASTER_SPEC}" \
  --output_dir="${OUTPUT}" \
  --logtostderr

# With no args, compare each output file to the golden copy; with an arg (the
# testdata directory), regenerate the golden files instead.
for file in \
  'MasterSpec' \
  'ArrayVariableStoreData' \
  'ArrayVariableStoreSpec'; do
  if [[ $# -gt 0 ]]; then
    # Update expected output.
    rm -f "$1/${file}"
    cp -f "${OUTPUT}/${file}" "$1/${file}"
  else
    # Compare to expected output.
    assert_file_content_eq "${OUTPUT}/${file}" "${EXPECTED}/${file}"
  fi
done

# Clean up the scratch directory.
rm -rf "${OUTPUT}"
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// The DynamicComponent is the runtime analogue of the DynamicComponentBuilder
// in the Python codebase. The role of the DynamicComponent is to manage the
// loop over transition steps, including:
// * Allocating stepwise memory for network states and operands.
// * Performing some computation at each step.
// * Advancing the transition state until terminal.
//
// Note that the number of transitions taken on any given evaluation of the
// DynamicComponent cannot be determined in advance.
//
// The core computational work is delegated to a NetworkUnit, which is evaluated
// at each transition step. This makes the DynamicComponent flexible, since it
// can be applied to any NetworkUnit implementation, but it can be significantly
// more efficient to use a task-specific component implementation. For example,
// the "shift-only" transition system merely scans the input tokens, and in that
// case we could replace the incremental loop with a "bulk" computation.
#include <stddef.h>
#include <string>
#include <vector>
#include "dragnn/core/compute_session.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/network_unit.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/transition_system_traits.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Performs an incremental computation, one transition at a time.  At each step
// the wrapped network unit is evaluated, then the transition system is
// advanced: from the oracle when deterministic, else from the logits.
class DynamicComponent : public Component {
 protected:
  // Implements Component.
  tensorflow::Status Initialize(const ComponentSpec &component_spec,
                                VariableStore *variable_store,
                                NetworkStateManager *network_state_manager,
                                ExtensionManager *extension_manager) override;
  tensorflow::Status Evaluate(SessionState *session_state,
                              ComputeSession *compute_session,
                              ComponentTrace *component_trace) const override;

  // This class is intended to support all DynamicComponent layers.  We
  // currently prefer to accept every spec whose builder is "DynamicComponent"
  // and throw errors in Initialize() if a particular feature is not supported.
  bool Supports(const ComponentSpec &spec,
                const string &normalized_builder_name) const override {
    return normalized_builder_name == "DynamicComponent";
  }

  // This class is not optimized, so any other supported subclasses of Component
  // should be preferred.
  bool PreferredTo(const Component &other) const override { return false; }

 private:
  // Name of this component.
  string name_;

  // Network unit that produces logits.
  std::unique_ptr<NetworkUnit> network_unit_;

  // Whether the transition system is deterministic.
  bool deterministic_ = false;

  // Handle to the network unit logits.  Valid iff |deterministic_| is false.
  LayerHandle<float> logits_handle_;
};
tensorflow::Status DynamicComponent::Initialize(
    const ComponentSpec &component_spec, VariableStore *variable_store,
    NetworkStateManager *network_state_manager,
    ExtensionManager *extension_manager) {
  name_ = component_spec.name();
  if (!component_spec.attention_component().empty()) {
    return tensorflow::errors::Unimplemented("Attention is not supported");
  }

  // Create and initialize the network unit named in the |component_spec|.
  TF_RETURN_IF_ERROR(NetworkUnit::CreateOrError(
      NetworkUnit::GetClassName(component_spec), &network_unit_));
  TF_RETURN_IF_ERROR(network_unit_->Initialize(component_spec, variable_store,
                                               network_state_manager,
                                               extension_manager));

  // Logits are unnecessary when the component is deterministic.
  deterministic_ = TransitionSystemTraits(component_spec).is_deterministic;
  if (!deterministic_) {
    const string logits_name = network_unit_->GetLogitsName();
    if (logits_name.empty()) {
      return tensorflow::errors::InvalidArgument(
          "Network unit does not produce logits: ",
          component_spec.network_unit().ShortDebugString());
    }

    // Look up the logits layer and check that its dimension matches the number
    // of actions in the transition system.
    size_t dimension = 0;
    TF_RETURN_IF_ERROR(network_state_manager->LookupLayer(
        name_, logits_name, &dimension, &logits_handle_));
    if (dimension != component_spec.num_actions()) {
      return tensorflow::errors::InvalidArgument(
          "Dimension mismatch between network unit logits (", dimension,
          ") and ComponentSpec.num_actions (", component_spec.num_actions(),
          ") in component '", name_, "'");
    }
  }
  return tensorflow::Status::OK();
}
// No batches or beams.
constexpr int kNumItems = 1;

// Runs the transition loop: allocates a step, evaluates the network unit, and
// advances the transition state, until the transition system is terminal.  The
// number of iterations is not known in advance.
tensorflow::Status DynamicComponent::Evaluate(
    SessionState *session_state, ComputeSession *compute_session,
    ComponentTrace *component_trace) const {
  NetworkStates &network_states = session_state->network_states;
  for (size_t step_index = 0; !compute_session->IsTerminal(name_);
       ++step_index) {
    network_states.AddStep();
    TF_RETURN_IF_ERROR(
        network_unit_->Evaluate(step_index, session_state, compute_session));

    // If the component is deterministic, take the oracle transition instead of
    // predicting the next transition using the logits.
    if (deterministic_) {
      compute_session->AdvanceFromOracle(name_);
    } else {
      // AddStep() may invalidate the logits (due to reallocation), so the layer
      // lookup cannot be hoisted out of this loop.
      const Vector<float> logits(
          network_states.GetLayer(logits_handle_).row(step_index));
      if (!compute_session->AdvanceFromPrediction(name_, logits.data(),
                                                  kNumItems, logits.size())) {
        return tensorflow::errors::Internal(
            "Error in ComputeSession::AdvanceFromPrediction()");
      }
    }
  }
  return tensorflow::Status::OK();
}
DRAGNN_RUNTIME_REGISTER_COMPONENT(DynamicComponent);
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <algorithm>
#include <limits>
#include <memory>
#include <string>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/network_unit.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include <gmock/gmock.h>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
using ::testing::_;
using ::testing::Return;
// Dimension of the test logits layer and the number of transition steps that
// the transition loop is configured to take.
constexpr size_t kStepsDim = 41;
constexpr size_t kNumSteps = 23;

// Test network unit whose logits layer, named "steps", has every element of
// row |i| filled with the value |i|.
class StepsNetwork : public NetworkUnit {
 public:
  // Implements NetworkUnit.
  tensorflow::Status Initialize(const ComponentSpec &component_spec,
                                VariableStore *variable_store,
                                NetworkStateManager *network_state_manager,
                                ExtensionManager *extension_manager) override {
    return network_state_manager->AddLayer("steps", kStepsDim, &steps_handle_);
  }
  string GetLogitsName() const override { return "steps"; }
  tensorflow::Status Evaluate(size_t step_index, SessionState *session_state,
                              ComputeSession *compute_session) const override {
    // Write the step index into every element of the current row.
    const MutableVector<float> row =
        session_state->network_states.GetLayer(steps_handle_).row(step_index);
    for (float &value : row) value = step_index;
    return tensorflow::Status::OK();
  }

 private:
  // Handle to the "steps" logits layer.
  LayerHandle<float> steps_handle_;
};
DRAGNN_RUNTIME_REGISTER_NETWORK_UNIT(StepsNetwork);
// As above, but does not report a logits layer.  The empty logits name tells
// DynamicComponent that this network produces no logits, even though the
// "steps" layer itself still exists.
class NoLogitsNetwork : public StepsNetwork {
 public:
  // Implements NetworkUnit.
  string GetLogitsName() const override { return ""; }
};
DRAGNN_RUNTIME_REGISTER_NETWORK_UNIT(NoLogitsNetwork);
// Fixture that builds, initializes, and evaluates a DynamicComponent against
// one of the test network units, capturing its "steps" layer for inspection.
class DynamicComponentTest : public NetworkTestBase {
 protected:
  // Creates a component, initializes it based on the |component_spec_text| and
  // |network_unit_name|, and evaluates it. On error, returns non-OK.
  tensorflow::Status Run(const string &component_spec_text,
                         const string &network_unit_name) {
    ComponentSpec component_spec;
    CHECK(TextFormat::ParseFromString(component_spec_text, &component_spec));
    component_spec.set_name(kTestComponentName);
    component_spec.mutable_network_unit()->set_registered_name(
        network_unit_name);

    // Neither DynamicComponent nor the test networks use linked embeddings, so
    // a trivial network suffices.
    AddComponent(kTestComponentName);
    TF_RETURN_IF_ERROR(
        Component::CreateOrError("DynamicComponent", &component_));
    TF_RETURN_IF_ERROR(component_->Initialize(component_spec, &variable_store_,
                                              &network_state_manager_,
                                              &extension_manager_));

    network_states_.Reset(&network_state_manager_);
    StartComponent(0);  // DynamicComponent will add steps
    session_state_.extensions.Reset(&extension_manager_);

    TF_RETURN_IF_ERROR(
        component_->Evaluate(&session_state_, &compute_session_, nullptr));

    // Capture the layer that the test network units write into.
    steps_ = GetLayer(kTestComponentName, "steps");
    return tensorflow::Status::OK();
  }

  // Component under test.
  std::unique_ptr<Component> component_;

  // Contents of the "steps" layer after a successful Run().
  Matrix<float> steps_;
};
// Tests that DynamicComponent fails if the spec uses attention, which is
// rejected in DynamicComponent::Initialize().
TEST_F(DynamicComponentTest, UnsupportedAttention) {
  EXPECT_THAT(Run("attention_component: 'foo'", "NoLogitsNetwork"),
              test::IsErrorWithSubstr("Attention is not supported"));
}
// Tests that DynamicComponent fails if the network does not produce logits.
// The default spec is non-deterministic, so logits are required.
TEST_F(DynamicComponentTest, NoLogits) {
  EXPECT_THAT(Run("", "NoLogitsNetwork"),
              test::IsErrorWithSubstr("Network unit does not produce logits"));
}
// Tests that DynamicComponent fails if the logits do not have the required
// dimension: the StepsNetwork logits are kStepsDim=41, not 42.
TEST_F(DynamicComponentTest, MismatchedLogitsDimension) {
  EXPECT_THAT(
      Run("num_actions: 42", "StepsNetwork"),
      test::IsErrorWithSubstr("Dimension mismatch between network unit logits "
                              "(41) and ComponentSpec.num_actions (42)"));
}
// Tests that DynamicComponent fails if ComputeSession::AdvanceFromPrediction()
// returns false.
TEST_F(DynamicComponentTest, FailToAdvanceFromPrediction) {
  // Never terminal, and the first prediction fails.
  EXPECT_CALL(compute_session_, IsTerminal(_)).WillRepeatedly(Return(false));
  EXPECT_CALL(compute_session_, AdvanceFromPrediction(_, _, _, _))
      .WillOnce(Return(false));

  EXPECT_THAT(Run("num_actions: 41", "StepsNetwork"),
              test::IsErrorWithSubstr(
                  "Error in ComputeSession::AdvanceFromPrediction()"));
}
// Tests that DynamicComponent evaluates its network unit once per transition,
// each time passing the proper step index.
TEST_F(DynamicComponentTest, Steps) {
  SetupTransitionLoop(kNumSteps);

  // Accept kNumSteps transition steps.
  EXPECT_CALL(compute_session_, AdvanceFromPrediction(_, _, _, _))
      .Times(kNumSteps)
      .WillRepeatedly(Return(true));
  TF_ASSERT_OK(Run("num_actions: 41", "StepsNetwork"));

  // Each row of the "steps" layer should be filled with its own step index.
  ASSERT_EQ(steps_.num_rows(), kNumSteps);
  for (size_t step_index = 0; step_index < kNumSteps; ++step_index) {
    ExpectVector(steps_.row(step_index), kStepsDim, step_index);
  }
}
// Tests that DynamicComponent calls ComputeSession::AdvanceFromOracle() and
// does not use logits when the component is deterministic.
// NOTE: Fixes the misspelled test name "Determinstic" -> "Deterministic".
TEST_F(DynamicComponentTest, Deterministic) {
  SetupTransitionLoop(kNumSteps);

  // Take the oracle transition instead of predicting from logits.
  EXPECT_CALL(compute_session_, AdvanceFromOracle(_)).Times(kNumSteps);
  TF_EXPECT_OK(Run("num_actions: 1", "NoLogitsNetwork"));

  // The NoLogitsNetwork still produces the "steps" layer, even if it does not
  // mark them as its logits.
  ASSERT_EQ(steps_.num_rows(), kNumSteps);
  for (size_t step_index = 0; step_index < kNumSteps; ++step_index) {
    ExpectVector(steps_.row(step_index), kStepsDim, step_index);
  }
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment