Commit edea2b67 authored by Terry Koo's avatar Terry Koo
Browse files

Remove runtime because reasons.

parent a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/xla/xla_compilation.h"
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "dragnn/protos/export.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/trained_model.h"
#include "dragnn/runtime/xla/xla_cell_converter.h"
#include "dragnn/runtime/xla/xla_graph_utils.h"
#include "dragnn/runtime/xla/xla_spec_utils.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Rewrites the Component subclass named in the |component_spec| to its
// XLA-based equivalent. On error, returns non-OK and modifies nothing.
tensorflow::Status XlaCompileComponentSubclass(ComponentSpec *component_spec) {
  const string builder_name =
      GetNormalizedComponentBuilderName(*component_spec);
  if (builder_name != "DynamicComponent") {
    return tensorflow::errors::Unimplemented(
        "No XLA-based version of Component subclass '", builder_name, "'");
  }

  // By convention, the XLA-based version of "FooComponent" should be named
  // "XlaFooComponent".
  const string xla_name = tensorflow::strings::StrCat("Xla", builder_name);
  component_spec->mutable_component_builder()->set_registered_name(xla_name);
  return tensorflow::Status::OK();
}
// Appends the component specs in the |master_spec| whose names match
// |component_names| to |matching_components|, in the iteration order of
// |component_names|. On error, returns non-OK; |matching_components| may have
// been partially appended.
tensorflow::Status GetMatchingComponentSpecs(
    const std::set<string> &component_names, MasterSpec *master_spec,
    std::vector<ComponentSpec *> *matching_components) {
  // Index the components in the |master_spec| by name, rejecting duplicates.
  std::map<string, ComponentSpec *> components;
  for (ComponentSpec &component_spec : *master_spec->mutable_component()) {
    if (!components.emplace(component_spec.name(), &component_spec).second) {
      return tensorflow::errors::InvalidArgument("Duplicate component name: ",
                                                 component_spec.name());
    }
  }

  // Append the components named in the |component_names|. Reuse the iterator
  // from find() instead of a second operator[] lookup.
  for (const string &component_name : component_names) {
    const auto it = components.find(component_name);
    if (it == components.end()) {
      return tensorflow::errors::InvalidArgument("Unknown component name: ",
                                                 component_name);
    }
    matching_components->push_back(it->second);
  }
  return tensorflow::Status::OK();
}
} // namespace
// Compiles the named components of a DRAGNN model for XLA. See the
// description in xla_compilation.h for the full list of MasterSpec
// modifications. On error, returns non-OK; output files may be partially
// written.
tensorflow::Status XlaCompileCells(const string &saved_model_dir,
                                   const string &master_spec_path,
                                   const std::set<string> &component_names,
                                   const string &model_name,
                                   const string &output_dir) {
  // Load the MasterSpec and select the components to compile.
  MasterSpec master_spec;
  TF_RETURN_IF_ERROR(tensorflow::ReadTextProto(tensorflow::Env::Default(),
                                               master_spec_path, &master_spec));
  std::vector<ComponentSpec *> components;
  TF_RETURN_IF_ERROR(
      GetMatchingComponentSpecs(component_names, &master_spec, &components));

  // Returns the path to the output frozen GraphDef file for the
  // |component_spec|.
  const auto get_frozen_graph_def_path =
      [&](const ComponentSpec &component_spec) {
        return tensorflow::io::JoinPath(
            output_dir,
            tensorflow::strings::StrCat(component_spec.name(),
                                        kFrozenGraphDefResourceFileSuffix));
      };

  // Perform some changes to the MasterSpec first, to catch issues before
  // loading the trained models, which is slow.
  for (ComponentSpec *component_spec : components) {
    // Add a resource for the frozen GraphDef file to each component.  The file
    // will be created in a second pass, after loading the trained model.
    TF_RETURN_IF_ERROR(AddFrozenGraphDefResource(
        get_frozen_graph_def_path(*component_spec), component_spec));

    // Replace the Component subclass with an XLA-based version.
    TF_RETURN_IF_ERROR(XlaCompileComponentSubclass(component_spec));

    // Set embedding_dim=-1 for all channels: per the header comments, XLA
    // takes raw feature IDs and activation vectors as cell inputs instead of
    // pre-embedded or pre-multiplied values.
    for (auto &fixed_channel : *component_spec->mutable_fixed_feature()) {
      fixed_channel.set_embedding_dim(-1);
    }
    for (auto &linked_channel : *component_spec->mutable_linked_feature()) {
      linked_channel.set_embedding_dim(-1);
    }
  }

  // Create output directory which contains the new master spec and
  // the frozen graphs.
  TF_RETURN_IF_ERROR(
      tensorflow::Env::Default()->RecursivelyCreateDir(output_dir));

  // Convert each component into a frozen GraphDef and write it.  Also may
  // add a CompilationSpec.
  TrainedModel trained_model;
  TF_RETURN_IF_ERROR(trained_model.Reset(saved_model_dir));
  for (ComponentSpec *component_spec : components) {
    tensorflow::GraphDef frozen_graph_def;
    CellSubgraphSpec cell_subgraph_spec;
    TF_RETURN_IF_ERROR(
        XlaCellConverter::Convert(component_spec->name(), trained_model,
                                  &frozen_graph_def, &cell_subgraph_spec));
    TF_RETURN_IF_ERROR(SaveFrozenGraphDef(
        get_frozen_graph_def_path(*component_spec), frozen_graph_def));

    // A non-empty |model_name| requests AOT metadata: record the model name
    // and the component's CellSubgraphSpec in a CompilationSpec extension.
    if (!model_name.empty()) {
      auto *compilation_spec = component_spec->MutableExtension(
          CompilationSpec::component_spec_extension);
      compilation_spec->set_model_name(model_name);
      *compilation_spec->mutable_cell_subgraph_spec() = cell_subgraph_spec;
    }
  }

  // Write the updated MasterSpec alongside the frozen graphs.
  TF_RETURN_IF_ERROR(tensorflow::WriteTextProto(
      tensorflow::Env::Default(),
      tensorflow::io::JoinPath(output_dir, "master-spec"), master_spec));
  return tensorflow::Status::OK();
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Utils for modifying pre-trained models to use XLA.
#ifndef DRAGNN_RUNTIME_XLA_XLA_COMPILATION_H_
#define DRAGNN_RUNTIME_XLA_XLA_COMPILATION_H_
#include <set>
#include <string>
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Modifies a DRAGNN model to use XLA.
//
// Loads a TF SavedModel from the |saved_model_dir| and a text-format MasterSpec
// from the |master_spec_path|. Converts each component in |component_names|
// into a frozen TF GraphDef (see xla_cell_converter.h) and writes the results
// to the |output_dir| as files "<output_dir>/<component_name>-frozen".
// Modifies the relevant ComponentSpecs in the MasterSpec to use XLA as
// described below, and writes it to "<output_dir>/master-spec".
//
// MasterSpec modifications:
// * Adds a resource to each ComponentSpec that points at the relevant
// frozen GraphDef file in the |output_dir|.
// * Replaces the Component subclass specified in each ComponentSpec with the
// XLA-based equivalent, which should be named "Xla<subclass_name>";
// e.g., XlaDynamicComponent.
// * If |model_name| is non-empty, adds a CompilationSpec extension to each
// ComponentSpec with |model_name| and its corresponding CellSubgraphSpec.
// This is necessary for XLA AOT compilation.
// * Sets FixedFeatureChannel.embedding_dim to -1 in all channels, because
// XLA takes feature IDs as input instead of fixed embedding sums.
// * Sets LinkedFeatureChannel.embedding_dim to -1 in all channels, because
// XLA handles the linked embedding matrix multiplication (if any) and
// always takes the original activation vector as input.
//
// On error, returns non-OK. Possible errors include:
// * Any file I/O or proto parsing error.
// * The MasterSpec has a duplicate component name.
// * One of the |component_names| does not match anything in the MasterSpec.
// * The MasterSpec already has XLA GraphDef resources.
// * One of the components is not supported by XLA.
// * Error raised by XlaCellConverter during conversion.
//
// Side note: This function has a file-path-based API so it can be easily
// wrapped in a stand-alone binary.
tensorflow::Status XlaCompileCells(const string &saved_model_dir,
const string &master_spec_path,
const std::set<string> &component_names,
const string &model_name,
const string &output_dir);
// TODO(googleuser): Add equivalent class for Myelinator.
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_XLA_XLA_COMPILATION_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/xla/xla_compilation.h"
#include <memory>
#include <string>
#include <utility>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/xla/xla_spec_utils.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Arbitrary bogus path.
constexpr char kInvalidPath[] = "path/to/some/invalid/file";

// Relative path to a MasterSpec.
constexpr char kMasterSpecPath[] =
    "dragnn/runtime/testdata/rnn_tagger/assets.extra/master_spec";

// Relative path to a saved model.
constexpr char kSavedModelDir[] = "dragnn/runtime/testdata/rnn_tagger";

// Relative path to a directory containing expected output.  Resolved via
// GetInput() when reading expectations.
constexpr char kExpectedOutputDir[] =
    "dragnn/runtime/xla/testdata/xla_compilation_output";

// Local relative path to the expected output directory.  Same directory as
// |kExpectedOutputDir|, but used as a plain relative path when rewriting test
// expectations in CompareOrRewriteTestData().
constexpr char kLocalOutputDir[] =
    "dragnn/runtime/xla/testdata/xla_compilation_output";
// Returns the set of components in the MasterSpec at |kMasterSpecPath|.
std::set<string> GetComponentNames() { return {"rnn", "tagger"}; }
// Returns the absolute path to the test input at the |relative_path|.
string GetInput(const string &relative_path) {
  const string prefix = test::GetTestDataPrefix();
  return tensorflow::io::JoinPath(prefix, relative_path);
}
// Returns a fresh output directory on each call, so tests do not clobber one
// another's outputs.
string GetUniqueOutputDir() {
  static int counter = 0;
  const string basename = tensorflow::strings::StrCat("output_", counter);
  ++counter;
  return tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), basename);
}
// Compares the content of the file named |basename| in the |actual_output_dir|
// with the file |testname| in |kExpectedOutputDir|.
//
// To regenerate test expectations, temporarily set |kRewriteTestData| to true
// and the actual file content will be written to |kLocalOutputDir| instead of
// compared.
void CompareOrRewriteTestData(const string &actual_output_dir,
                              const string &basename, const string &testname) {
  // Flip to true (locally only) to overwrite the expected files.
  constexpr bool kRewriteTestData = false;

  string actual_data;
  TF_ASSERT_OK(tensorflow::ReadFileToString(
      tensorflow::Env::Default(),
      tensorflow::io::JoinPath(actual_output_dir, basename), &actual_data));

  if (kRewriteTestData) {
    TF_ASSERT_OK(tensorflow::WriteStringToFile(
        tensorflow::Env::Default(),
        tensorflow::io::JoinPath(kLocalOutputDir, testname), actual_data));
  } else {
    string expected_data;
    TF_ASSERT_OK(tensorflow::ReadFileToString(
        tensorflow::Env::Default(),
        GetInput(tensorflow::io::JoinPath(kExpectedOutputDir, testname)),
        &expected_data));

    // Note: EXPECT_EQ is avoided because printing the diff on failure leads
    // to timeouts; compare with EXPECT_TRUE and report only the file names.
    EXPECT_TRUE(actual_data == expected_data)
        << "Actual and expected file contents differ for " << basename
        << "; (actual in " << actual_output_dir << ")";
  }
}
// Compares the content of the file named |basename| in the |actual_output_dir|
// with the file with the same |basename| in |kExpectedOutputDir|. Can also be
// modified to write the actual file content to |kLocalOutputDir|, for updating
// test expectations.
void CompareOrRewriteTestData(const string &actual_output_dir,
                              const string &basename) {
  // Convenience overload: the expected file shares the actual file's basename.
  CompareOrRewriteTestData(actual_output_dir, basename, basename);
}
// Reads a text-format MasterSpec from the |master_spec_path|, clears resource
// file patterns, and writes it back to the |master_spec_path|. The resource
// file patterns would otherwise cause spurious mismatches.
void ClearResourceFilePatterns(const string &master_spec_path) {
MasterSpec master_spec;
TF_ASSERT_OK(tensorflow::ReadTextProto(tensorflow::Env::Default(),
master_spec_path, &master_spec));
for (ComponentSpec &component_spec : *master_spec.mutable_component()) {
for (Resource &resource : *component_spec.mutable_resource()) {
for (Part &part : *resource.mutable_part()) {
part.clear_file_pattern();
}
}
}
TF_ASSERT_OK(tensorflow::WriteTextProto(tensorflow::Env::Default(),
master_spec_path, master_spec));
}
// Tests that XlaCompileCells() fails if the saved model is invalid.
TEST(XlaCompileCellsTest, InvalidSavedModel) {
  const tensorflow::Status status = XlaCompileCells(
      kInvalidPath, GetInput(kMasterSpecPath), {}, "", GetUniqueOutputDir());
  EXPECT_FALSE(status.ok());
}
// Tests that XlaCompileCells() fails if the master spec is invalid.
TEST(XlaCompileCellsTest, InvalidMasterSpec) {
  const tensorflow::Status status = XlaCompileCells(
      GetInput(kSavedModelDir), kInvalidPath, {}, "", GetUniqueOutputDir());
  EXPECT_FALSE(status.ok());
}
// Tests that XlaCompileCells() fails if the MasterSpec contains a duplicate
// component.
TEST(XlaCompileCellsTest, DuplicateComponent) {
  // Two components with the same name: the by-name index cannot be built.
  const string kSpec = "component { name:'foo' } component { name:'foo' }";
  const string master_spec_path = tensorflow::io::JoinPath(
      tensorflow::testing::TmpDir(), "master-spec-with-duplicate");
  TF_ASSERT_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(),
                                             master_spec_path, kSpec));
  EXPECT_THAT(XlaCompileCells(GetInput(kSavedModelDir), master_spec_path, {},
                              "", GetUniqueOutputDir()),
              test::IsErrorWithSubstr("Duplicate component name: foo"));
}
// Tests that XlaCompileCells() fails if one of the requested components does
// not appear in the MasterSpec.
TEST(XlaCompileCellsTest, FilterWithUnknownComponent) {
  // The requested name "missing" matches neither 'foo' nor 'bar'.
  const string kSpec = "component { name:'foo' } component { name:'bar' }";
  const string master_spec_path = tensorflow::io::JoinPath(
      tensorflow::testing::TmpDir(), "master-spec-foo-bar");
  TF_ASSERT_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(),
                                             master_spec_path, kSpec));
  EXPECT_THAT(XlaCompileCells(GetInput(kSavedModelDir), master_spec_path,
                              {"missing"}, "", GetUniqueOutputDir()),
              test::IsErrorWithSubstr("Unknown component name: missing"));
}
// Tests that XlaCompileCells() fails if a component already has a frozen
// GraphDef.
TEST(XlaCompileCellsTest, AlreadyHasFrozenGraphDef) {
  // The component already carries a resource with the frozen-GraphDef name,
  // so adding another must be rejected.
  const string kSpec =
      tensorflow::strings::StrCat("component { name: 'foo' resource { name: '",
                                  kFrozenGraphDefResourceName, "' } }");
  const string master_spec_path = tensorflow::io::JoinPath(
      tensorflow::testing::TmpDir(), "master-spec-with-flows");
  TF_ASSERT_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(),
                                             master_spec_path, kSpec));
  EXPECT_THAT(XlaCompileCells(GetInput(kSavedModelDir), master_spec_path,
                              {"foo"}, "", GetUniqueOutputDir()),
              test::IsErrorWithSubstr(
                  "already contains a frozen TF GraphDef resource"));
}
// Tests that XlaCompileCells() fails on the wrong Component type.
TEST(XlaCompileCellsTest, WrongComponentType) {
  // 'WrongComponent' has no XLA-based equivalent, so subclass rewriting fails.
  const string kSpec =
      "component { name: 'foo' component_builder { registered_name: "
      "'WrongComponent' } }";
  const string master_spec_path =
      tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), "master-spec");
  TF_ASSERT_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(),
                                             master_spec_path, kSpec));
  EXPECT_THAT(
      XlaCompileCells(GetInput(kSavedModelDir), master_spec_path, {"foo"}, "",
                      GetUniqueOutputDir()),
      test::IsErrorWithSubstr(
          "No XLA-based version of Component subclass 'WrongComponent'"));
}
// Tests that XlaCompileCells() succeeds on the pre-trained inputs and
// reproduces expected outputs.
TEST(XlaCompileCellsTest, RegressionTest) {
  const string output_dir = GetUniqueOutputDir();
  TF_ASSERT_OK(XlaCompileCells(GetInput(kSavedModelDir),
                               GetInput(kMasterSpecPath), GetComponentNames(),
                               "", output_dir));
  // File patterns embed per-run paths; clear them before golden comparison.
  ClearResourceFilePatterns(
      tensorflow::io::JoinPath(output_dir, "master-spec"));
  CompareOrRewriteTestData(output_dir, "master-spec");
  // Each component must also produce its expected frozen GraphDef file.
  for (const string &component_name : GetComponentNames()) {
    const string graph_def_basename = tensorflow::strings::StrCat(
        component_name, kFrozenGraphDefResourceFileSuffix);
    CompareOrRewriteTestData(output_dir, graph_def_basename);
  }
}
// Tests that XlaCompileCells() succeeds on the pre-trained inputs and
// reproduces expected outputs when a model name is supplied, which adds
// CompilationSpec extensions for AOT compilation (hence the separate
// "master-spec-aot" golden file).
TEST(XlaCompileCellsTest, RegressionTestWithModelNameForAot) {
  const string output_dir = GetUniqueOutputDir();
  TF_ASSERT_OK(XlaCompileCells(GetInput(kSavedModelDir),
                               GetInput(kMasterSpecPath), GetComponentNames(),
                               "model_v1", output_dir));
  // File patterns embed per-run paths; clear them before golden comparison.
  ClearResourceFilePatterns(
      tensorflow::io::JoinPath(output_dir, "master-spec"));
  CompareOrRewriteTestData(output_dir, "master-spec", "master-spec-aot");
  // The frozen GraphDefs are unaffected by |model_name|.
  for (const string &component_name : GetComponentNames()) {
    const string graph_def_basename = tensorflow::strings::StrCat(
        component_name, kFrozenGraphDefResourceFileSuffix);
    CompareOrRewriteTestData(output_dir, graph_def_basename);
  }
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <memory>
#include <string>
#include "dragnn/core/compute_session.h"
#include "dragnn/protos/export.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/fixed_embeddings.h"
#include "dragnn/runtime/linked_embeddings.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/xla/sequence_xla_dynamic_component_mixin.h"
#include "dragnn/runtime/xla/xla_dynamic_component_base.h"
#include "dragnn/runtime/xla/xla_graph_utils.h"
#include "dragnn/runtime/xla/xla_spec_utils.h"
#include "syntaxnet/base.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
#include "tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// An XLA-based version of DynamicComponent using the XLA JIT API.
//
// It uses the XLA JIT API to compile the graph, and uses the frozen GraphDef
// referred to in the component spec.
class XlaDynamicComponent : public XlaDynamicComponentBase {
 protected:
  // Unlike other specializations, this component will only be active if the
  // spec is explicitly modified to support XLA (and frozen graph resources are
  // generated).
  bool Supports(const ComponentSpec &spec,
                const string &normalized_builder_name) const override {
    return normalized_builder_name == "XlaDynamicComponent";
  }

  // Never preferred: any other supporting component wins the tie-break.
  bool PreferredTo(const Component &other) const override { return false; }

  // Gets the frozen GraphDef using the |component_spec| and compiles it.
  // The |cell_subgraph_spec| contained within it is filled in. On error,
  // returns non-OK.
  tensorflow::Status InitializeFromComponentSpec(
      const ComponentSpec &component_spec,
      CellSubgraphSpec *cell_subgraph_spec) override;

  // Returns the static data of the JIT-compiled cell. Fatal if called before
  // InitializeFromComponentSpec() has populated |jit_|.
  const tensorflow::XlaCompiledCpuFunction::StaticData &XlaStaticData()
      const override {
    if (jit_ == nullptr) {
      LOG(FATAL) << "XlaStaticData() called before "
                    "InitializeFromComponentSpec() for component "
                 << name();
    }
    return jit_->StaticData();
  }

 private:
  // Cell that contains the compiled code for this component.
  std::unique_ptr<tensorflow::XlaJitCompiledCpuFunction> jit_;
};
tensorflow::Status XlaDynamicComponent::InitializeFromComponentSpec(
    const ComponentSpec &component_spec, CellSubgraphSpec *cell_subgraph_spec) {
  // Locate and load the frozen GraphDef resource attached to the spec.
  const Resource *resource = nullptr;
  TF_RETURN_IF_ERROR(LookupFrozenGraphDefResource(component_spec, &resource));
  const string &frozen_graph_def_path = resource->part(0).file_pattern();
  tensorflow::GraphDef frozen_graph_def;
  TF_RETURN_IF_ERROR(
      LoadFrozenGraphDef(frozen_graph_def_path, &frozen_graph_def));

  // Gets the CellSubgraphSpec from the frozen GraphDef and constructs
  // the XLA Config required for compilation.
  tensorflow::tf2xla::Config xla_config;
  TF_RETURN_IF_ERROR(GetSpecAndMakeXlaConfig(frozen_graph_def,
                                             cell_subgraph_spec, &xla_config));

  // Compiles the cell; |jit_| is only assigned if compilation succeeds.
  TF_ASSIGN_OR_RETURN(
      jit_, tensorflow::XlaJitCompiledCpuFunction::Compile(
                frozen_graph_def, xla_config, xla::ExecutableBuildOptions()));
  return tensorflow::Status::OK();
}
// Register the component so it can be selected by registered name.
DRAGNN_RUNTIME_REGISTER_COMPONENT(XlaDynamicComponent);

// Sequence-based version of the above.
using SequenceXlaDynamicComponent =
    SequenceXlaDynamicComponentMixin<XlaDynamicComponent>;
DRAGNN_RUNTIME_REGISTER_COMPONENT(SequenceXlaDynamicComponent);
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/xla/xla_dynamic_component_base.h"
#include <string.h>
#include <algorithm>
#include "dragnn/protos/export.pb.h"
#include "dragnn/runtime/transition_system_traits.h"
#include "dragnn/runtime/xla/xla_spec_utils.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/errors.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
constexpr char XlaDynamicComponentBase::kLogitsName[];
// Checks that the |component_spec| is usable by XLA-based components: no
// attention, and every fixed/linked channel has embedding_dim == -1.
tensorflow::Status XlaDynamicComponentBase::Validate(
    const ComponentSpec &component_spec) {
  if (!component_spec.attention_component().empty()) {
    return tensorflow::errors::Unimplemented("Attention is not supported");
  }

  for (const auto &channel : component_spec.fixed_feature()) {
    if (channel.embedding_dim() == -1) continue;
    return tensorflow::errors::InvalidArgument(
        "XLA requires non-embedded fixed features");
  }

  for (const auto &channel : component_spec.linked_feature()) {
    if (channel.embedding_dim() == -1) continue;
    return tensorflow::errors::InvalidArgument(
        "XLA requires non-multiplied linked features");
  }

  return tensorflow::Status::OK();
}
// Checks that the XLA tensor |shape| named |name| holds elements of the given
// |type| and is "vector-like" (at most one dimension larger than 1). If
// |dimension| is non-negative, also requires exactly that many elements. On
// success, writes the element count to |elements_out|; otherwise returns
// non-OK.
tensorflow::Status XlaDynamicComponentBase::ValidateTensor(
    const string &name, const xla::PrimitiveType type, int dimension,
    const xla::Shape &shape, int *elements_out) {
  if (shape.element_type() != type) {
    return tensorflow::errors::InvalidArgument(
        "XLA tensor '", name, "' has wrong type ",
        xla::PrimitiveType_Name(shape.element_type()), " (expected ",
        xla::PrimitiveType_Name(type), ")");
  }

  // Count dimensions larger than 1 and accumulate the total element count.
  int64 num_elements = 1;
  int nontrivial_dims = 0;
  for (const int64 dim : shape.dimensions()) {
    if (dim <= 1) continue;
    ++nontrivial_dims;
    num_elements *= dim;
  }

  if (nontrivial_dims > 1) {
    return tensorflow::errors::InvalidArgument(
        "XLA tensor has non-vector-like shape: '", name, "' ",
        xla::ShapeUtil::HumanString(shape));
  }

  if (dimension >= 0 && num_elements != dimension) {
    return tensorflow::errors::InvalidArgument(
        "XLA input shape has the wrong dimension '", name, "' ",
        xla::ShapeUtil::HumanString(shape), " (expected ", dimension, ")");
  }

  *elements_out = static_cast<int>(num_elements);
  return tensorflow::Status::OK();
}
// Finds the XLA input argument named |name|, validates that it is a
// vector-like tensor of the given |type| and |dimension| (see
// ValidateTensor()), and fills in |input_handle|. On error, returns non-OK and
// leaves the handle invalid (index == -1).
tensorflow::Status XlaDynamicComponentBase::LookupInputVector(
    const string &name, const xla::PrimitiveType type, int dimension,
    const tensorflow::XlaCompiledCpuFunction &instance,
    InputHandle *input_handle) const {
  input_handle->index = -1;  // set to invalid if we error out
  // NOTE(review): indices at or beyond parameters_size() are also treated as
  // "not found" -- presumably such args are not program parameters; confirm.
  const int index = instance.LookupArgIndex(name);
  if (index == -1 || index >= program_shape_->parameters_size()) {
    return tensorflow::errors::NotFound("No XLA tensor named '", name, "'");
  }
  const xla::Shape &shape = program_shape_->parameters(index);
  TF_RETURN_IF_ERROR(
      ValidateTensor(name, type, dimension, shape, &input_handle->elements));
  input_handle->index = index;
  return tensorflow::Status::OK();
}
// Finds the XLA output named |name|, validates that it is a vector-like tensor
// of the given |type| and |dimension|, and fills in |output_handle| including
// its byte size. On error, returns non-OK and leaves the handle invalid
// (index == -1).
tensorflow::Status XlaDynamicComponentBase::LookupOutputVector(
    const string &name, const xla::PrimitiveType type, int dimension,
    const tensorflow::XlaCompiledCpuFunction &instance,
    OutputHandle *output_handle) const {
  output_handle->index = -1;  // set to invalid if we error out
  const int index = instance.LookupResultIndex(name);
  if (index == -1) {
    return tensorflow::errors::NotFound("No XLA tensor named '", name, "'");
  }
  // Outputs are elements of the program's result tuple; validate the index
  // against the tuple before dereferencing.
  const xla::Shape &result_shape = program_shape_->result();
  if (result_shape.element_type() != xla::TUPLE) {
    return tensorflow::errors::InvalidArgument("XLA output is not a tuple");
  }
  if (index >= result_shape.tuple_shapes_size()) {
    return tensorflow::errors::InvalidArgument("Invalid XLA output index: ",
                                               index);
  }
  const xla::Shape &shape = result_shape.tuple_shapes(index);
  TF_RETURN_IF_ERROR(
      ValidateTensor(name, type, dimension, shape, &output_handle->elements));
  output_handle->index = index;
  output_handle->bytes = xla::ShapeUtil::ByteSizeOf(shape);
  return tensorflow::Status::OK();
}
// Binds one XLA input per fixed feature ID. Each fixed channel contributes
// |channel_size| single-element S32 inputs, stored at the channel's base
// offset in |input_ids_|. On error, returns non-OK.
tensorflow::Status XlaDynamicComponentBase::InitializeInputIds(
    const tensorflow::XlaCompiledCpuFunction &instance) {
  const int num_channels = fixed_embedding_manager_.num_channels();
  input_ids_.resize(fixed_embedding_manager_.num_embeddings());
  for (int channel_id = 0; channel_id < num_channels; ++channel_id) {
    // No channel should be embedded -- presumably guaranteed by Validate()'s
    // embedding_dim == -1 check; confirm.
    DCHECK(!fixed_embedding_manager_.is_embedded(channel_id));
    const int channel_base = fixed_embedding_manager_.channel_base(channel_id);
    const int channel_size = fixed_embedding_manager_.channel_size(channel_id);
    for (int index = 0; index < channel_size; ++index) {
      InputId &input = input_ids_[channel_base + index];
      const string name = MakeXlaInputFixedFeatureIdName(channel_id, index);
      // Each feature ID is a single int32 input.
      TF_RETURN_IF_ERROR(
          LookupInputVector(name, xla::S32, 1, instance, &input.id));
      VLOG(1) << "Component '" << name_ << "' fixed channel " << channel_id
              << " index " << index << ": Added feature ID";
    }
  }
  return tensorflow::Status::OK();
}
// Binds the XLA inputs for each linked embedding channel: a required F32
// activation vector of the channel's embedding dimension, and an optional
// single-element out-of-bounds indicator that exists only for channels whose
// activations are multiplied by an embedding matrix. On error, returns non-OK.
tensorflow::Status XlaDynamicComponentBase::InitializeInputLinks(
    const tensorflow::XlaCompiledCpuFunction &instance) {
  const int num_channels = linked_embedding_manager_.num_channels();
  input_links_.resize(num_channels);
  for (int channel_id = 0; channel_id < num_channels; ++channel_id) {
    InputLink &input = input_links_[channel_id];
    const int dimension = linked_embedding_manager_.embedding_dim(channel_id);
    const string activations_name =
        MakeXlaInputLinkedActivationVectorName(channel_id);
    const string out_of_bounds_name =
        MakeXlaInputLinkedOutOfBoundsIndicatorName(channel_id);
    // The activation vector input is mandatory.
    TF_RETURN_IF_ERROR(LookupInputVector(activations_name, xla::F32, dimension,
                                         instance, &input.activations));
    VLOG(1) << "Component '" << name_ << "' linked channel " << channel_id
            << ": Added activations";
    // Allow NOT_FOUND, for linked embedding channels that don't multiply the
    // input activations with an embedding matrix.
    const tensorflow::Status status = LookupInputVector(
        out_of_bounds_name, xla::F32, 1, instance, &input.out_of_bounds);
    if (status.ok()) {
      VLOG(1) << "Component '" << name_ << "' linked channel " << channel_id
              << ": Added out-of-bounds indicator for multiplication";
    } else if (status.code() == tensorflow::error::NOT_FOUND) {
      VLOG(1) << "Component '" << name_ << "' linked channel " << channel_id
              << ": No out-of-bounds indicator; not multiplied";
    } else {
      // Any other error (e.g. wrong type or shape) is fatal.
      return status;
    }
  }
  return tensorflow::Status::OK();
}
// Binds the XLA inputs that feed recurrent layers their previous-step values.
// Only TYPE_RECURRENT entries of the |cell_subgraph_spec| are bound; each
// recurrent layer must already exist in the |manager|, and its dimension is
// enforced on the corresponding XLA input. On error, returns non-OK.
tensorflow::Status XlaDynamicComponentBase::InitializeInputRecurrences(
    const CellSubgraphSpec &cell_subgraph_spec,
    const NetworkStateManager &manager,
    const tensorflow::XlaCompiledCpuFunction &instance) {
  for (const auto &cell_input : cell_subgraph_spec.input()) {
    if (cell_input.type() != CellSubgraphSpec::Input::TYPE_RECURRENT) continue;
    const string &layer_name = cell_input.name();
    input_recurrences_.emplace_back();
    InputRecurrence &input = input_recurrences_.back();
    const string name = MakeXlaInputRecurrentLayerName(layer_name);
    // LookupLayer() fills in the layer's actual dimension.
    size_t dimension = 1;
    TF_RETURN_IF_ERROR(
        manager.LookupLayer(name_, layer_name, &dimension, &input.handle));
    TF_RETURN_IF_ERROR(LookupInputVector(name, xla::F32, dimension, instance,
                                         &input.previous_output));
    VLOG(1) << "Component '" << name_ << "' recurrence '" << layer_name
            << "': Added link to previous output";
  }
  return tensorflow::Status::OK();
}
// Adds a network-state layer for each cell output in the |cell_subgraph_spec|.
// Outputs that share an underlying tensor do not get separate layers: the
// first output for a tensor becomes a real layer, and later outputs become
// aliases of it. On error, returns non-OK.
tensorflow::Status XlaDynamicComponentBase::InitializeOutputLayers(
    const CellSubgraphSpec &cell_subgraph_spec, NetworkStateManager *manager,
    const tensorflow::XlaCompiledCpuFunction &instance) {
  // Mapping from output tensor name to layer name, for detecting layer aliases.
  std::map<string, string> tensor_to_layer;
  for (const auto &cell_output : cell_subgraph_spec.output()) {
    const string &layer_name = cell_output.name();
    output_layers_.emplace_back();
    OutputLayer &output = output_layers_.back();
    const string name = MakeXlaOutputLayerName(layer_name);
    // Add a new output layer or create an alias to an existing one.
    if (tensor_to_layer.find(cell_output.tensor()) == tensor_to_layer.end()) {
      // First time this tensor is seen: look it up (dimension unconstrained,
      // hence -1) and add a layer of the discovered size.
      TF_RETURN_IF_ERROR(
          LookupOutputVector(name, xla::F32, -1, instance, &output.layer));
      tensor_to_layer[cell_output.tensor()] = layer_name;
      const size_t dimension = output.layer.elements;
      TF_RETURN_IF_ERROR(
          manager->AddLayer(layer_name, dimension, &output.handle));
      VLOG(1) << "Component '" << name_ << "' output '" << layer_name
              << "': Added new layer";
    } else {
      const string &original_name = tensor_to_layer[cell_output.tensor()];
      output_layers_.pop_back();  // not a "real" output
      TF_RETURN_IF_ERROR(manager->AddLayerAlias(layer_name, original_name));
      VLOG(1) << "Component '" << name_ << "' output '" << layer_name
              << "': Alias of '" << original_name << "'";
    }
  }
  return tensorflow::Status::OK();
}
tensorflow::Status XlaDynamicComponentBase::InitializeConstantVectors() {
  // Find the maximum recurrent layer dimension; the |zeros_| must be this big.
  int max_dimension = 1;  // ensure at least one element, for |zero_|
  for (const InputRecurrence &input : input_recurrences_) {
    max_dimension = std::max(max_dimension, input.previous_output.elements);
  }

  // Allocate the backing array and parcel it out into sub-views.  The first
  // sub-view holds the single element of |one_|; the second holds
  // |max_dimension| elements.
  const std::vector<size_t> sizes = {sizeof(float),
                                     max_dimension * sizeof(float)};
  array_.Reset(ComputeTotalBytesWithAlignmentPadding(sizes));
  memset(array_.view().data(), 0, array_.view().size());  // = 0.0 for float
  std::vector<MutableAlignedView> views;
  TF_RETURN_IF_ERROR(array_.view().Split(sizes, &views));
  DCHECK_EQ(views.size(), 2);

  // Promote to typed vectors.  Note that |zero_| and |zeros_| deliberately
  // overlap in views[1]: |zero_| is a 1-element prefix of |zeros_|, which is
  // safe because both are all-zero and never written after this point.
  one_ = Vector<float>(views[0]);
  zero_ = Vector<float>(views[1], 1);
  zeros_ = Vector<float>(views[1]);
  DCHECK_EQ(zero_.size(), 1);
  DCHECK_EQ(one_.size(), 1);
  DCHECK_EQ(zeros_.size(), max_dimension);

  // All memory was already zeroed, so only |one_| needs to be initialized.
  // A mutable alias of views[0] is used because |one_| itself is read-only.
  MutableVector<float> mutable_one(views[0]);
  mutable_one[0] = 1.0;
  return tensorflow::Status::OK();
}
tensorflow::Status XlaDynamicComponentBase::MaybeInitializeLogits(
    const ComponentSpec &component_spec, const NetworkStateManager &manager) {
  // A deterministic transition system always follows its oracle, so the
  // logits layer is never read and need not be looked up.
  deterministic_ = TransitionSystemTraits(component_spec).is_deterministic;
  if (deterministic_) return tensorflow::Status::OK();

  // Locate the logits layer and make sure its dimension agrees with the
  // number of actions declared in the spec.
  size_t logits_dimension = 0;
  TF_RETURN_IF_ERROR(manager.LookupLayer(name_, kLogitsName, &logits_dimension,
                                         &logits_handle_));
  if (logits_dimension == component_spec.num_actions()) {
    return tensorflow::Status::OK();
  }
  return tensorflow::errors::InvalidArgument(
      "Dimension mismatch between classification logits (", logits_dimension,
      ") and ComponentSpec.num_actions (", component_spec.num_actions(),
      ") in component '", name_, "'");
}
tensorflow::Status XlaDynamicComponentBase::Initialize(
    const ComponentSpec &component_spec, VariableStore *variable_store,
    NetworkStateManager *network_state_manager,
    ExtensionManager *extension_manager) {
  name_ = component_spec.name();
  TF_RETURN_IF_ERROR(Validate(component_spec));
  CellSubgraphSpec cell_subgraph_spec;
  TF_RETURN_IF_ERROR(
      InitializeFromComponentSpec(component_spec, &cell_subgraph_spec));

  // Cache the XLA StaticData after InitializeFromComponentSpec().  Per the
  // XlaStaticData() contract, it may only be called after that succeeds.
  static_data_ = &XlaStaticData();

  // Make a temporary instance to determine shape and input/output indices.
  tensorflow::XlaCompiledCpuFunction instance(
      *static_data_, tensorflow::XlaCompiledCpuFunction::AllocMode::
                         RESULTS_PROFILES_AND_TEMPS_ONLY);
  program_shape_ = instance.ProgramShape();
  if (program_shape_ == nullptr) {
    // Note: this fails when the proto dependency is missing.
    return tensorflow::errors::InvalidArgument("XLA program shape missing");
  }
  VLOG(1) << "XLA program shape = " << program_shape_->DebugString();

  // Configure the inputs and outputs of the XLA cell.  As with NetworkUnit
  // and NetworkUnitBase, output layers and input features must be initialized
  // in a particular order to enable recurrent inputs.  Specifically, we must
  // populate output layers first, so they are available for recurrent access,
  // both by the |input_recurrences_| and the |linked_embedding_manager_|.
  TF_RETURN_IF_ERROR(InitializeOutputLayers(cell_subgraph_spec,
                                            network_state_manager, instance));
  TF_RETURN_IF_ERROR(fixed_embedding_manager_.Reset(
      component_spec, variable_store, network_state_manager));
  TF_RETURN_IF_ERROR(linked_embedding_manager_.Reset(
      component_spec, variable_store, network_state_manager));
  TF_RETURN_IF_ERROR(InitializeInputIds(instance));
  TF_RETURN_IF_ERROR(InitializeInputLinks(instance));
  TF_RETURN_IF_ERROR(InitializeInputRecurrences(
      cell_subgraph_spec, *network_state_manager, instance));
  // InitializeConstantVectors() requires the |input_recurrences_| above.
  TF_RETURN_IF_ERROR(InitializeConstantVectors());
  TF_RETURN_IF_ERROR(
      MaybeInitializeLogits(component_spec, *network_state_manager));

  // Embedding extensions are shared across components; the compiled-cell
  // instance is local because each component may use a different cell.
  extension_manager->GetShared(&fixed_embeddings_handle_);
  extension_manager->GetShared(&linked_embeddings_handle_);
  extension_manager->AddLocal(&instance_handle_);
  return tensorflow::Status::OK();
}
tensorflow::Status XlaDynamicComponentBase::Evaluate(
    SessionState *session_state, ComputeSession *compute_session,
    ComponentTrace *component_trace) const {
  NetworkStates &network_states = session_state->network_states;
  FixedEmbeddings &fixed_embeddings =
      session_state->extensions.Get(fixed_embeddings_handle_);
  LinkedEmbeddings &linked_embeddings =
      session_state->extensions.Get(linked_embeddings_handle_);
  tensorflow::XlaCompiledCpuFunction &instance = GetInstance(session_state);

  // Run the transition loop: one cell evaluation per transition step, until
  // the underlying transition system reports a terminal state.
  for (size_t step_index = 0; !compute_session->IsTerminal(name());
       ++step_index) {
    network_states.AddStep();
    TF_RETURN_IF_ERROR(fixed_embeddings.Reset(&fixed_embedding_manager(),
                                              network_states, compute_session));
    TF_RETURN_IF_ERROR(linked_embeddings.Reset(
        &linked_embedding_manager(), network_states, compute_session));

    // Bind inputs into the |instance|.
    BindInputIds(fixed_embeddings, &instance);
    BindInputLinks(linked_embeddings, &instance);
    BindInputRecurrences(step_index, network_states, &instance);

    // Invoke the cell in the |instance|.
    if (!instance.Run()) {
      return tensorflow::errors::Internal("Error executing cell for ", name(),
                                          ": ", instance.error_msg());
    }

    // Realizes the binding: copy outputs out of the |instance|.  This must
    // happen after Run(); see BindOutput() for why outputs are copied.
    BindOutputLayers(step_index, network_states, &instance);
    MaybeTrace(step_index, &instance, component_trace);

    // If the component is deterministic, take the oracle transition instead of
    // predicting the next transition using the logits.
    if (deterministic()) {
      compute_session->AdvanceFromOracle(name());
    } else {
      // AddStep() may invalidate the logits (due to reallocation), so the layer
      // lookup cannot be hoisted out of this loop.
      const Vector<float> logits(
          network_states.GetLayer(logits_handle()).row(step_index));
      if (!compute_session->AdvanceFromPrediction(
              name(), logits.data(), kEvaluateNumItems, logits.size())) {
        return tensorflow::errors::Internal(
            "Error in ComputeSession::AdvanceFromPrediction()");
      }
    }
  }
  return tensorflow::Status::OK();
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_XLA_XLA_DYNAMIC_COMPONENT_BASE_H_
#define DRAGNN_RUNTIME_XLA_XLA_DYNAMIC_COMPONENT_BASE_H_
#include <stddef.h>
#include <string.h>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "dragnn/protos/export.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/fixed_embeddings.h"
#include "dragnn/runtime/linked_embeddings.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/type_keyed_set.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Base class for XLA-based versions of DynamicComponent.
//
// Roughly, this is a base class for a version of DynamicComponent where the
// per-transition-step computation is performed by a XLA cell instead of a
// NetworkUnit. This class implements Initialize() and Evaluate(). It has
// the most generality w.r.t. input features and links, but suffers from
// ComputeSession overhead. Subclasses which provide specialized logic that
// replaces the generic ComputeSession should override Evaluate().
//
// XLA JIT and AOT versions of this class must supply appropriate versions
// of InitializeFromComponentSpec() and XlaStaticData().
//
// At initialization time, this class creates lists of configuration structs
// that associate each input or output of the XLA cell with an operand that
// the DRAGNN runtime manages. See, e.g., InputId and InitializeInputIds().
//
// At inference time, subclasses can bind the relevant DRAGNN runtime operands
// to the inputs and outputs of the XLA instance (see, e.g., BindInputIds())
// and evaluate the XLA cell. Like DynamicComponent, the cell should be
// evaluated once per transition and the results used to advance the transition
// system state.
//
// Except as noted below, this is a drop-in replacement for DynamicComponent:
// * The name of the logits layer is hard-coded (see kLogitsName).
// * The fixed and linked channels must have embedding_dim=-1, because the fixed
// lookups and linked multiplications are handled within XLA.
//
// The XlaDynamicComponent subclass provides a general-purpose implementation
// of Evaluate(). Other subclasses provide optimized implementations subject to
// restrictions on the possible network configuration.
class XlaDynamicComponentBase : public Component {
 public:
  // Implements Component.
  tensorflow::Status Initialize(const ComponentSpec &component_spec,
                                VariableStore *variable_store,
                                NetworkStateManager *network_state_manager,
                                ExtensionManager *extension_manager) override;
  tensorflow::Status Evaluate(SessionState *session_state,
                              ComputeSession *compute_session,
                              ComponentTrace *component_trace) const override;

 protected:
  // Initializes the XLA function using the |component_spec|. When successful,
  // the relevant |cell_subgraph_spec| is filled in, and XlaStaticData() is safe
  // to call. On error, returns non-OK.
  virtual tensorflow::Status InitializeFromComponentSpec(
      const ComponentSpec &component_spec,
      CellSubgraphSpec *cell_subgraph_spec) = 0;

  // Returns the StaticData that identifies a specific XLA compiled cell
  // function. It is a fatal error to call this before a successful call to
  // InitializeFromComponentSpec().
  virtual const tensorflow::XlaCompiledCpuFunction::StaticData &XlaStaticData()
      const = 0;

 private:
  // Handle to one of the inputs. The |index| is into an array of
  // pointers used by XlaCompiledCpuFunction. The input vector has
  // the given number of |elements|.
  struct InputHandle {
    int index = -1;
    int elements = 0;
  };

  // Handle to one of the outputs. This |index| is into an array of pointers
  // into the results tuple used by XlaCompiledCpuFunction. The output vector
  // has the given number of |elements| and occupies |bytes| bytes.
  struct OutputHandle {
    int index = -1;
    int elements = 0;
    int64 bytes = 0;
  };

 protected:
  // Configuration for a fixed feature ID input.
  struct InputId {
    // Tensor to feed with the fixed feature ID.
    InputHandle id;
  };

  // Configuration for a linked feature embedding input.
  struct InputLink {
    // Tensor to feed with the linked activation vector.
    InputHandle activations;

    // Tensor to feed with the linked out-of-bounds indicator, or -1 if the
    // embedding does not need to be multiplied.
    InputHandle out_of_bounds;
  };

  // Configuration for a recurrent input: an output layer of this component
  // that is fed back into the cell on the next transition step.
  struct InputRecurrence {
    // Handle of the output layer that is recurrently fed back.
    LayerHandle<float> handle;

    // Tensor to feed with the previous output activation vector.
    InputHandle previous_output;
  };

  // Configuration for an output layer.
  struct OutputLayer {
    // Handle of the output layer.
    LayerHandle<float> handle;

    // Tensor that writes to the layer.
    OutputHandle layer;
  };

  // Name of the layer containing logits. Unlike DynamicComponent, this class
  // does not use the NetworkUnit abstraction and assumes that the logits will
  // be stored in this layer.
  // TODO(googleuser): Make this configurable, if needed. The logits layer could
  // be given a special alias, for example.
  static constexpr char kLogitsName[] = "logits";

  // Points the cell input |handle| in the |instance| at the |vector|.
  // Must be called before invoking the cell.
  template <class T>
  static void BindInput(Vector<T> vector, const InputHandle &handle,
                        tensorflow::XlaCompiledCpuFunction *instance);

  // Copies the cell output |handle| in the |instance| to the |vector|.
  // Must be called after invoking the cell.
  //
  // TODO(googleuser): Consider wrapping XlaCompiledCpuFunction along with a map
  // from output indices to layer pointers, so this actually binds before the
  // call to Run(). Then add a separate function that realizes the output
  // binding, copying after Run().
  template <class T>
  static void BindOutput(MutableVector<T> vector, const OutputHandle &handle,
                         tensorflow::XlaCompiledCpuFunction *instance);

  // Binds the feature IDs in the |fixed_embeddings| to the |instance| as
  // configured by the |input_ids_|.
  void BindInputIds(const FixedEmbeddings &fixed_embeddings,
                    tensorflow::XlaCompiledCpuFunction *instance) const;

  // Binds the |embedding| and, if applicable, |is_out_of_bounds| to the
  // |input_link| in the |instance|.
  void BindInputLink(Vector<float> embedding, bool is_out_of_bounds,
                     const InputLink &input_link,
                     tensorflow::XlaCompiledCpuFunction *instance) const;

  // Binds the activation vectors in the |linked_embeddings| to the |instance|
  // as configured by the |input_links_|.
  void BindInputLinks(const LinkedEmbeddings &linked_embeddings,
                      tensorflow::XlaCompiledCpuFunction *instance) const;

  // Binds the output of the step before |step_index| in the |network_states| to
  // the |instance| as configured by the |input_recurrences_|.
  void BindInputRecurrences(size_t step_index,
                            const NetworkStates &network_states,
                            tensorflow::XlaCompiledCpuFunction *instance) const;

  // Binds the output layers for the |step_index| in the |network_states| to the
  // |instance| as configured by the |output_layers_|.
  void BindOutputLayers(size_t step_index, const NetworkStates &network_states,
                        tensorflow::XlaCompiledCpuFunction *instance) const;

  // Returns the reusable XLA instance in the |session_state|.
  tensorflow::XlaCompiledCpuFunction &GetInstance(
      SessionState *session_state) const;

  // If |component_trace| is non-null, ensures that |step_index|+1 steps exist
  // and traces the |instance| in the |step_index|'th step.
  void MaybeTrace(size_t step_index,
                  tensorflow::XlaCompiledCpuFunction *instance,
                  ComponentTrace *component_trace) const;

  // Accessors.
  const string &name() const { return name_; }
  const FixedEmbeddingManager &fixed_embedding_manager() const {
    return fixed_embedding_manager_;
  }
  const LinkedEmbeddingManager &linked_embedding_manager() const {
    return linked_embedding_manager_;
  }
  const std::vector<InputId> &input_ids() const { return input_ids_; }
  const std::vector<InputLink> &input_links() const { return input_links_; }
  const std::vector<InputRecurrence> &input_recurrences() const {
    return input_recurrences_;
  }
  const std::vector<OutputLayer> &output_layers() const {
    return output_layers_;
  }
  bool deterministic() const { return deterministic_; }
  LayerHandle<float> logits_handle() const { return logits_handle_; }

 private:
  // Forbid batches and beams.
  static constexpr int kEvaluateNumItems = 1;

  // Required alignment of pointers to input tensors.
  static constexpr size_t kXlaByteAlignment =
      tensorflow::Allocator::kAllocatorAlignment;

  // Returns non-OK if the |component_spec| specifies any unsupported settings.
  // This includes both settings that are not yet implemented and those that are
  // fundamentally incompatible with this class.
  static tensorflow::Status Validate(const ComponentSpec &component_spec);

  // Returns non-OK if the tensor called |name| isn't compatible with |type| or
  // has an invalid |shape| given |dimension| for use as an input or output.
  // If OK, |elements_out| contains the number of elements in the vector.
  static tensorflow::Status ValidateTensor(const string &name,
                                           const xla::PrimitiveType type,
                                           int dimension,
                                           const xla::Shape &shape,
                                           int *elements_out);

  // Points the |input_handle| or |output_handle| at the cell tensor named
  // |name|, which must have a vector-like shape (i.e., having at most one
  // dimension > 1) and must match the |type|. The |instance| is used to
  // determine the mapping from |name| to the handle. If the |dimension| is
  // >= 0, then the tensor must be the same size. On error, returns non-OK.
  // Returns NOT_FOUND iff the |name| does not name a cell tensor.
  tensorflow::Status LookupInputVector(
      const string &name, const xla::PrimitiveType type, int dimension,
      const tensorflow::XlaCompiledCpuFunction &instance,
      InputHandle *input_handle) const;
  tensorflow::Status LookupOutputVector(
      const string &name, const xla::PrimitiveType type, int dimension,
      const tensorflow::XlaCompiledCpuFunction &instance,
      OutputHandle *output_handle) const;

  // Initializes the |input_ids_| based on the |fixed_embedding_manager_| and
  // |instance|. On error, returns non-OK.
  tensorflow::Status InitializeInputIds(
      const tensorflow::XlaCompiledCpuFunction &instance);

  // Initializes the |input_links_| based on the |linked_embedding_manager_| and
  // |instance|. On error, returns non-OK.
  tensorflow::Status InitializeInputLinks(
      const tensorflow::XlaCompiledCpuFunction &instance);

  // Initializes the |input_recurrences_| based on the |cell_subgraph_spec|,
  // |manager|, and |instance|. Requires that layers have been added to the
  // |manager|. On error, returns non-OK.
  tensorflow::Status InitializeInputRecurrences(
      const CellSubgraphSpec &cell_subgraph_spec,
      const NetworkStateManager &manager,
      const tensorflow::XlaCompiledCpuFunction &instance);

  // Initializes the |output_layers_| based on the |cell_subgraph_spec|,
  // |manager|, and |instance|. Adds layers to the |manager|. On error,
  // returns non-OK.
  tensorflow::Status InitializeOutputLayers(
      const CellSubgraphSpec &cell_subgraph_spec, NetworkStateManager *manager,
      const tensorflow::XlaCompiledCpuFunction &instance);

  // Initializes the constant vectors (|zero_|, |one_|, and |zeros_|) and their
  // backing |array_|. Requires that the |input_recurrences_| are initialized.
  tensorflow::Status InitializeConstantVectors();

  // Initializes the |logits_handle_| based on the |component_spec| and
  // |manager|, if needed.
  tensorflow::Status MaybeInitializeLogits(const ComponentSpec &component_spec,
                                           const NetworkStateManager &manager);

  // Name of this component.
  string name_;

  // Managers for the fixed and linked embeddings used by the component.
  FixedEmbeddingManager fixed_embedding_manager_;
  LinkedEmbeddingManager linked_embedding_manager_;

  // Fixed and linked embeddings.
  SharedExtensionHandle<FixedEmbeddings> fixed_embeddings_handle_;
  SharedExtensionHandle<LinkedEmbeddings> linked_embeddings_handle_;

  // The StaticData that identifies the XLA compiled function that implements
  // the network cell. Cached to reduce virtual call overhead.
  const tensorflow::XlaCompiledCpuFunction::StaticData *static_data_ = nullptr;

  // Description of shapes and types of the compiled function, with indices that
  // correspond to InputHandle and OutputHandle index values.
  const xla::ProgramShape *program_shape_ = nullptr;

  // List of fixed feature ID inputs, aligned with the relevant FixedEmbeddings.
  std::vector<InputId> input_ids_;

  // List of linked feature inputs, aligned with the relevant LinkedEmbeddings.
  std::vector<InputLink> input_links_;

  // List of recurrent inputs, in no particular order.
  std::vector<InputRecurrence> input_recurrences_;

  // List of output layers, in no particular order.
  std::vector<OutputLayer> output_layers_;

  // A few constant vectors and their backing array.
  UniqueAlignedArray array_;
  Vector<float> zero_;   // [0.0], for linked out-of-bounds indicators
  Vector<float> one_;    // [1.0], for linked out-of-bounds indicators
  Vector<float> zeros_;  // [0.0...0.0], for linked activation vectors

  // Whether the transition system is deterministic.
  bool deterministic_ = false;

  // Handle to the classification logits. Valid iff |deterministic_| is false.
  LayerHandle<float> logits_handle_;

  // Compiled function that implements the network cell. Local, since each
  // component can have a different cell.
  LocalExtensionHandle<tensorflow::XlaCompiledCpuFunction> instance_handle_;
};
// Implementation details below.
template <class T>
void XlaDynamicComponentBase::BindInput(
    Vector<T> vector, const InputHandle &handle,
    tensorflow::XlaCompiledCpuFunction *instance) {
  DCHECK_GE(handle.index, 0);
  // XLA requires input buffers to satisfy its alignment contract; see
  // kXlaByteAlignment.
  DCHECK_EQ(reinterpret_cast<size_t>(vector.data()) % kXlaByteAlignment, 0);

  // Since XLA only consumes non-const pointers, const_cast() is required.
  // XLA will not modify the contents of the |vector|, provided it is bound
  // to a cell input.
  instance->set_arg_data(
      handle.index,
      const_cast<void *>(reinterpret_cast<const void *>(vector.data())));
}
template <class T>
void XlaDynamicComponentBase::BindOutput(
    MutableVector<T> vector, const OutputHandle &handle,
    tensorflow::XlaCompiledCpuFunction *instance) {
  DCHECK_GE(handle.index, 0);
  // XLA retains control over the allocation of outputs, and the pointer
  // to the output must be determined using result_data() after every call
  // to Run(). The outputs are copied into the session tensors.  Note that,
  // despite the name, this "bind" is therefore an eager copy of
  // |handle.bytes| bytes, not a pointer binding; see the TODO on the
  // declaration.
  std::memcpy(vector.data(), instance->result_data(handle.index), handle.bytes);
}
inline void XlaDynamicComponentBase::BindInputIds(
const FixedEmbeddings &fixed_embeddings,
tensorflow::XlaCompiledCpuFunction *instance) const {
for (size_t i = 0; i < input_ids_.size(); ++i) {
BindInput(fixed_embeddings.ids(i), input_ids_[i].id, instance);
}
}
inline void XlaDynamicComponentBase::BindInputLink(
    Vector<float> embedding, bool is_out_of_bounds, const InputLink &input_link,
    tensorflow::XlaCompiledCpuFunction *instance) const {
  BindInput(embedding, input_link.activations, instance);

  // An out-of-bounds indicator tensor only exists for some links; -1 marks
  // its absence.
  if (input_link.out_of_bounds.index == -1) return;
  const Vector<float> &indicator = is_out_of_bounds ? one_ : zero_;
  BindInput(indicator, input_link.out_of_bounds, instance);
}
inline void XlaDynamicComponentBase::BindInputLinks(
const LinkedEmbeddings &linked_embeddings,
tensorflow::XlaCompiledCpuFunction *instance) const {
for (size_t i = 0; i < input_links_.size(); ++i) {
BindInputLink(linked_embeddings.embedding(i),
linked_embeddings.is_out_of_bounds(i), input_links_[i],
instance);
}
}
inline void XlaDynamicComponentBase::BindInputRecurrences(
    size_t step_index, const NetworkStates &network_states,
    tensorflow::XlaCompiledCpuFunction *instance) const {
  const bool is_first_step = (step_index == 0);
  for (const InputRecurrence &recurrence : input_recurrences_) {
    if (is_first_step) {
      // There is no previous step, so feed a zero vector.  Recall that
      // |zeros_| was constructed to be large enough for any recurrence.
      BindInput(zeros_, recurrence.previous_output, instance);
    } else {
      const Vector<float> previous(
          network_states.GetLayer(recurrence.handle).row(step_index - 1));
      BindInput(previous, recurrence.previous_output, instance);
    }
  }
}
inline void XlaDynamicComponentBase::BindOutputLayers(
    size_t step_index, const NetworkStates &network_states,
    tensorflow::XlaCompiledCpuFunction *instance) const {
  // Copy each cell output into the |step_index|'th row of its layer.
  for (const OutputLayer &output : output_layers_) {
    MutableVector<float> row =
        network_states.GetLayer(output.handle).row(step_index);
    BindOutput(row, output.layer, instance);
  }
}
inline tensorflow::XlaCompiledCpuFunction &XlaDynamicComponentBase::GetInstance(
    SessionState *session_state) const {
  // Get() lazily constructs the local extension on first use, forwarding the
  // trailing arguments to the XlaCompiledCpuFunction constructor; subsequent
  // calls return the same instance.
  return session_state->extensions.Get(
      instance_handle_, *static_data_,
      tensorflow::XlaCompiledCpuFunction::AllocMode::
          RESULTS_PROFILES_AND_TEMPS_ONLY);
}
inline void XlaDynamicComponentBase::MaybeTrace(
    size_t step_index, tensorflow::XlaCompiledCpuFunction * /*instance*/,
    ComponentTrace *component_trace) const {
  if (component_trace == nullptr) return;

  // Ensure that |step_index|+1 step traces exist.  The explicit cast avoids a
  // signed/unsigned comparison between the proto's int-typed size and the
  // size_t |step_index|.
  const int num_steps = static_cast<int>(step_index) + 1;
  while (component_trace->step_trace_size() < num_steps) {
    component_trace->add_step_trace();
  }

  // TODO(googleuser): Add once the JIT API supports this.
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_XLA_XLA_DYNAMIC_COMPONENT_BASE_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <functional>
#include <memory>
#include <string>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/cell_trace.pb.h"
#include "dragnn/protos/export.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/test/fake_variable_store.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "dragnn/runtime/type_keyed_set.h"
#include "dragnn/runtime/xla/xla_graph_utils.h"
#include "dragnn/runtime/xla/xla_spec_utils.h"
#include "syntaxnet/base.h"
#include <gmock/gmock.h>
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
using ::testing::_;
using ::testing::InSequence;
using ::testing::Invoke;
using ::testing::Return;
constexpr int kVocabularySize = 123;
constexpr int kLogitsDim = 11;
constexpr int kNumSteps = 50;
class XlaDynamicComponentTest : public NetworkTestBase {
 protected:
  // Options for building a GraphDef file for tests. By default, this specifies
  // a working GraphDef file, but settings can be perturbed to trigger errors.
  struct GraphDefOptions {
    GraphDefOptions() = default;

    // Dimension of the classification logits.
    int logits_dim = kLogitsDim;

    // Name of the variable containing the classification logits.
    string logits_name = "logits";

    // Type of the feature ID input.
    xla::PrimitiveType id_type = xla::S32;

    // Dimension of the feature ID input.
    int id_dim = 1;
  };

  // Builds and writes a simple frozen GraphDef file. By default it produces a
  // valid frozen GraphDef, but arguments can be overridden for error testing.
  // Returns the path to the file.
  static string WriteFrozenGraphDef() {
    return WriteFrozenGraphDef(GraphDefOptions());
  }

  // Returns the TensorFlow dtype corresponding to the XLA primitive |type|,
  // or DT_INVALID for types these tests do not use.
  static tensorflow::DataType TensorFlowType(xla::PrimitiveType type) {
    switch (type) {
      case xla::S32:
        return tensorflow::DT_INT32;
      case xla::S64:
        return tensorflow::DT_INT64;
      case xla::F32:
        return tensorflow::DT_FLOAT;
      default:
        break;
    }
    return tensorflow::DT_INVALID;
  }

  // As above, but with explicit |options|.  The graph implements a single
  // embedding lookup: logits = embedding_matrix[id].
  static string WriteFrozenGraphDef(const GraphDefOptions &options) {
    CellSubgraphSpec spec;
    tensorflow::GraphDef graph;

    // A fixed feature ID input.
    auto *input = spec.add_input();
    input->set_name("fixed_channel_0_index_0_ids");
    input->set_tensor("cell/id:0");
    input->set_type(CellSubgraphSpec::Input::TYPE_FEATURE);

    // The retrieved embedding row, as logits.
    auto *output = spec.add_output();
    output->set_name(options.logits_name);
    output->set_tensor("cell/lookup:0");

    // Add CellSubgraphSpec node, serialized into a string-typed Const.
    tensorflow::Tensor spec_tensor(tensorflow::DT_STRING,
                                   tensorflow::TensorShape({1}));
    spec.SerializeToString(&spec_tensor.vec<string>()(0));
    tensorflow::TensorProto spec_tensor_proto;
    spec_tensor.AsProtoField(&spec_tensor_proto);
    TF_CHECK_OK(
        tensorflow::NodeDefBuilder(kFrozenCellSubgraphSpecNodeName, "Const")
            .Attr("dtype", tensorflow::DT_STRING)
            .Attr("value", spec_tensor_proto)
            .Attr("shape", tensorflow::TensorShape({1}))
            .Finalize(graph.add_node()));

    // Fixed feature ID input placeholder node.
    TF_CHECK_OK(tensorflow::NodeDefBuilder("cell/id", "Placeholder")
                    .Attr("dtype", TensorFlowType(options.id_type))
                    .Attr("shape", tensorflow::TensorShape({options.id_dim}))
                    .Finalize(graph.add_node()));

    // An embedding matrix constant. Each embedding is filled with its index.
    tensorflow::Tensor embeddings(
        tensorflow::DT_FLOAT,
        tensorflow::TensorShape({kVocabularySize, options.logits_dim}));
    auto raw_tensor = embeddings.tensor<float, 2>();
    for (int row = 0; row < kVocabularySize; ++row) {
      for (int column = 0; column < options.logits_dim; ++column) {
        raw_tensor(row, column) = row;
      }
    }
    tensorflow::TensorProto embeddings_proto;
    embeddings.AsProtoTensorContent(&embeddings_proto);
    TF_CHECK_OK(tensorflow::NodeDefBuilder("cell/embedding_matrix", "Const")
                    .Attr("dtype", tensorflow::DT_FLOAT)
                    .Attr("value", embeddings_proto)
                    .Finalize(graph.add_node()));

    // A Gather op that looks up the |id| in the |embeddings|, and returns the
    // result in the |logits|.
    TF_CHECK_OK(tensorflow::NodeDefBuilder("cell/lookup", "Gather")
                    .Input("cell/embedding_matrix", 0, tensorflow::DT_FLOAT)
                    .Input("cell/id", 0, TensorFlowType(options.id_type))
                    .Attr("validate_indices", true)
                    .Finalize(graph.add_node()));

    const string path =
        tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), "graph-frozen");
    TF_CHECK_OK(SaveFrozenGraphDef(path, graph));
    return path;
  }

  // Creates a component, initializes it based on the |component_spec_text| and
  // |flow_path|, and evaluates it. The |component_trace| is overwritten with
  // traces, if non-null. On error, returns non-OK.
  tensorflow::Status Run(const string &component_spec_text = "",
                         const string &flow_path = WriteFrozenGraphDef(),
                         ComponentTrace *component_trace = nullptr) {
    ComponentSpec component_spec;
    CHECK(TextFormat::ParseFromString(component_spec_text, &component_spec));
    if (!component_spec.has_num_actions()) {
      component_spec.set_num_actions(kLogitsDim);
    }
    component_spec.set_name(kTestComponentName);

    // Every test graph has one non-embedded fixed feature (channel 0); any
    // features in |component_spec_text| come after it.
    auto *fixed_feature = component_spec.add_fixed_feature();
    fixed_feature->set_embedding_dim(-1);
    fixed_feature->set_size(1);
    TF_RETURN_IF_ERROR(AddFrozenGraphDefResource(flow_path, &component_spec));

    AddComponent(kTestComponentName);
    TF_RETURN_IF_ERROR(
        Component::CreateOrError("XlaDynamicComponent", &component_));
    TF_RETURN_IF_ERROR(component_->Initialize(component_spec, &variable_store_,
                                              &network_state_manager_,
                                              &extension_manager_));

    network_states_.Reset(&network_state_manager_);
    StartComponent(0);  // XlaDynamicComponent will add steps
    session_state_.extensions.Reset(&extension_manager_);
    TF_RETURN_IF_ERROR(component_->Evaluate(&session_state_, &compute_session_,
                                            component_trace));
    return tensorflow::Status::OK();
  }

  // The component under test, created by Run().
  std::unique_ptr<Component> component_;
};
// Tests that XlaDynamicComponent fails if the spec uses attention.
TEST_F(XlaDynamicComponentTest, UnsupportedAttention) {
  // A spec that uses attention must be rejected during validation.
  const tensorflow::Status status = Run("attention_component:'foo'");
  EXPECT_THAT(status, test::IsErrorWithSubstr("Attention is not supported"));
}
// Tests that XlaDynamicComponent fails if the spec has embedded fixed
// features.
TEST_F(XlaDynamicComponentTest, InvalidFixedFeatureIsEmbedded) {
  // Fixed features must use embedding_dim=-1; the lookup happens inside XLA.
  const tensorflow::Status status = Run("fixed_feature { embedding_dim:1 }");
  EXPECT_THAT(status, test::IsErrorWithSubstr(
                          "XLA requires non-embedded fixed features"));
}
// Tests that XlaDynamicComponent fails if the ComponentSpec has a fixed
// feature that does not appear in the graph.
TEST_F(XlaDynamicComponentTest, InvalidFixedFeatureNotInGraph) {
  const string expected_error = tensorflow::strings::StrCat(
      "No XLA tensor named '", MakeXlaInputFixedFeatureIdName(1, 0), "'");
  EXPECT_THAT(Run("fixed_feature { embedding_dim:-1 size:1 }"),
              test::IsErrorWithSubstr(expected_error));
}
// Tests that XlaDynamicComponent fails if the spec has multiplied linked
// features (i.e., a linked feature with a positive embedding_dim).
TEST_F(XlaDynamicComponentTest, InvalidLinkedFeatureIsMultiplied) {
  EXPECT_THAT(
      Run("linked_feature { embedding_dim:1 }"),
      test::IsErrorWithSubstr("XLA requires non-multiplied linked features"));
}
// Tests that XlaDynamicComponent fails if the ComponentSpec has a linked
// feature that does not appear in the graph.
TEST_F(XlaDynamicComponentTest, InvalidLinkedFeatureNotInGraph) {
  const string spec_text = tensorflow::strings::StrCat(
      "linked_feature { source_component:'", kTestComponentName,
      "' source_layer:'logits' embedding_dim:-1 size:1 }");
  const string expected_error = tensorflow::strings::StrCat(
      "No XLA tensor named '", MakeXlaInputLinkedActivationVectorName(0), "'");
  EXPECT_THAT(Run(spec_text), test::IsErrorWithSubstr(expected_error));
}
// Tests that XlaDynamicComponent fails if the GraphDef file does not exist.
TEST_F(XlaDynamicComponentTest, InvalidPath) {
  const tensorflow::Status status = Run("", "/invalid/path");
  EXPECT_THAT(status, test::IsErrorWithSubstr("No such file or directory"));
}
// Tests that XlaDynamicComponent fails if the logits dimension does not
// match ComponentSpec.num_actions.
TEST_F(XlaDynamicComponentTest, WrongLogitsDimension) {
  GraphDefOptions options;
  options.logits_dim = kLogitsDim + 1;  // mismatches the default num_actions
  const tensorflow::Status status = Run("", WriteFrozenGraphDef(options));
  EXPECT_THAT(status,
              test::IsErrorWithSubstr(
                  "Dimension mismatch between classification logits"));
}
// Tests that XlaDynamicComponent fails if there is no "logits" layer.
TEST_F(XlaDynamicComponentTest, WrongLogitsName) {
  GraphDefOptions options;
  options.logits_name = "not_logits";
  const tensorflow::Status status = Run("", WriteFrozenGraphDef(options));
  EXPECT_THAT(status, test::IsErrorWithSubstr("Unknown layer 'logits'"));
}
// Tests that XlaDynamicComponent fails to compile if one of the XLA
// tensors has the wrong type.
TEST_F(XlaDynamicComponentTest, FailToCompile) {
  GraphDefOptions options;
  options.id_type = xla::F32;  // F32 feature IDs are rejected by XLA
  const tensorflow::Status status = Run("", WriteFrozenGraphDef(options));
  EXPECT_THAT(status, test::IsErrorWithSubstr(
                          "float is not in the list of allowed values"));
}
// Tests that XlaDynamicComponent fails if one of the XLA tensors is not
// vector-like.
TEST_F(XlaDynamicComponentTest, NotVectorLike) {
  GraphDefOptions options;
  options.id_dim = 2;  // forces a shape the converter rejects
  const tensorflow::Status status = Run("", WriteFrozenGraphDef(options));
  EXPECT_THAT(status,
              test::IsErrorWithSubstr("XLA tensor has non-vector-like shape"));
}
// Tests that XlaDynamicComponent fails if AdvanceFromPrediction() fails.
TEST_F(XlaDynamicComponentTest, FailToAdvanceFromPrediction) {
  // One step of features is extracted, but advancing from the prediction
  // fails, so evaluation must surface an error.
  EXPECT_CALL(compute_session_, GetInputFeatures(_, _, _, _, _))
      .WillOnce(Invoke(ExtractFeatures(0, {{10, 1.0}})));
  EXPECT_CALL(compute_session_, IsTerminal(_)).WillRepeatedly(Return(false));
  EXPECT_CALL(compute_session_, AdvanceFromPrediction(_, _, _, _))
      .WillOnce(Return(false));

  EXPECT_THAT(Run(), test::IsErrorWithSubstr(
                         "Error in ComputeSession::AdvanceFromPrediction()"));
}
// Tests that XlaDynamicComponent can run a simple non-deterministic frozen
// GraphDef.
TEST_F(XlaDynamicComponentTest, SimpleNonDeterministicFlow) {
  SetupTransitionLoop(kNumSteps);

  // Every advance-from-prediction must succeed for the loop to complete.
  EXPECT_CALL(compute_session_, AdvanceFromPrediction(_, _, _, _))
      .Times(kNumSteps)
      .WillRepeatedly(Return(true));
  {  // Extract a sequence of feature IDs equal to 2 * step_index.
    ASSERT_LE(2 * kNumSteps, kVocabularySize);
    InSequence scoped;
    for (int step_index = 0; step_index < kNumSteps; ++step_index) {
      EXPECT_CALL(compute_session_, GetInputFeatures(_, _, _, _, _))
          .WillOnce(Invoke(ExtractFeatures(0, {{2 * step_index, 1.0}})));
    }
  }
  TF_ASSERT_OK(Run());

  // One row of logits per transition step.
  const Matrix<float> logits(GetLayer(kTestComponentName, "logits"));
  ASSERT_EQ(logits.num_rows(), kNumSteps);
  ASSERT_EQ(logits.num_columns(), kLogitsDim);

  // Since each row of the embedding matrix is filled with its index, the
  // logits should be equal to the feature IDs.
  for (int step_index = 0; step_index < kNumSteps; ++step_index) {
    ExpectVector(logits.row(step_index), kLogitsDim, 2 * step_index);
  }
}
// Tests that XlaDynamicComponent can run a simple deterministic frozen
// GraphDef.
TEST_F(XlaDynamicComponentTest, SimpleDeterministicFlow) {
  SetupTransitionLoop(kNumSteps);

  // With num_actions:1 the component advances from the oracle instead of
  // from predictions; expect one oracle advance per step.
  EXPECT_CALL(compute_session_, AdvanceFromOracle(kTestComponentName))
      .Times(kNumSteps);
  {  // Extract a sequence of feature IDs equal to 2 * step_index.
    ASSERT_LE(2 * kNumSteps, kVocabularySize);
    InSequence scoped;
    for (int step_index = 0; step_index < kNumSteps; ++step_index) {
      EXPECT_CALL(compute_session_, GetInputFeatures(_, _, _, _, _))
          .WillOnce(Invoke(ExtractFeatures(0, {{2 * step_index, 1.0}})));
    }
  }

  // The graph's logits dimension must match num_actions (here, 1).
  GraphDefOptions options;
  options.logits_dim = 1;
  TF_ASSERT_OK(Run("num_actions:1", WriteFrozenGraphDef(options)));
}
// Tests that XlaDynamicComponent can run a simple frozen GraphDef with tracing
// enabled.
TEST_F(XlaDynamicComponentTest, SimpleFlowWithTracing) {
  SetupTransitionLoop(kNumSteps);
  EXPECT_CALL(compute_session_, AdvanceFromPrediction(_, _, _, _))
      .Times(kNumSteps)
      .WillRepeatedly(Return(true));
  {  // Extract a sequence of feature IDs equal to 2 * step_index.
    ASSERT_LE(2 * kNumSteps, kVocabularySize);
    InSequence scoped;
    for (int step_index = 0; step_index < kNumSteps; ++step_index) {
      EXPECT_CALL(compute_session_, GetInputFeatures(_, _, _, _, _))
          .WillOnce(Invoke(ExtractFeatures(0, {{2 * step_index, 1.0}})));
    }
  }

  // Run with a non-null trace; expect one step trace per transition.
  ComponentTrace component_trace;
  TF_ASSERT_OK(Run("", WriteFrozenGraphDef(), &component_trace));
  ASSERT_EQ(component_trace.step_trace_size(), kNumSteps);
  for (const ComponentStepTrace &step_trace : component_trace.step_trace()) {
    // Cell traces are currently absent: the assertion below pins the count at
    // zero until the JIT API can produce them.
    // TODO(googleuser): Add once the JIT API supports this.
    EXPECT_EQ(step_trace.ExtensionSize(CellTrace::step_trace_extension), 0);
  }
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Writes a file containing a text tf2xla::Config proto that is extracted
// from a frozen binary GraphDef file for a DRAGNN component.
//
// Usage: xla_extract_config input-graph-def output-config
// input-graph-def: input frozen tensorflow.GraphDef binary proto
// output-config: extracted tensorflow.tf2xla.Config text proto
#include <string.h>
#include "dragnn/protos/export.pb.h"
#include "dragnn/runtime/xla/xla_graph_utils.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Writes the Config extracted from |input_graph_def| to |output_config|.
// On error, returns non-OK.
tensorflow::Status XlaExtractConfig(const char *input_graph_def,
                                    const char *output_config) {
  // Load the frozen graph, then derive the tf2xla Config from the
  // CellSubgraphSpec embedded within it.
  tensorflow::GraphDef frozen_graph;
  TF_RETURN_IF_ERROR(LoadFrozenGraphDef(input_graph_def, &frozen_graph));

  CellSubgraphSpec spec;
  tensorflow::tf2xla::Config config;
  TF_RETURN_IF_ERROR(GetSpecAndMakeXlaConfig(frozen_graph, &spec, &config));

  return WriteTextProto(tensorflow::Env::Default(), output_config, config);
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Entry point: validates the two positional arguments and delegates to
// XlaExtractConfig(). Crashes with a usage message on bad arguments.
int main(int argc, char **argv) {
  tensorflow::port::InitMain(argv[0], &argc, &argv);

  const bool args_ok =
      argc == 3 && strlen(argv[1]) != 0 && strlen(argv[2]) != 0;
  if (!args_ok) {
    LOG(FATAL)
        << "Usage: xla_extract_config input-graph-def output-config\n"
           " input-graph-def: input frozen tensorflow.GraphDef binary proto\n"
           " output-config: extracted tensorflow.tf2xla.Config text proto\n";
  }

  TF_CHECK_OK(syntaxnet::dragnn::runtime::XlaExtractConfig(argv[1], argv[2]));
  return 0;
}
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Writes a Bazel file containing a definition for XLA_AOT_COMPONENTS. The
// value is an array; each element is an array of strings containing information
// needed to build the XLA AOT library for a graph, and the DRAGNN component
// that uses it.
//
// This file is loaded and then used by the dragnn_xla_aot_components() build
// rule (see xla_build_defs.bzl). Its contents are verified to be current by the
// dragnn_xla_aot_bazel_test() build rule, which runs this program.
//
// This program processes a set of MasterSpecs; the benefits for processing
// a set of MasterSpecs together are:
// - only a single build rule is necessary for adding component libraries;
// - duplicates of model/components across MasterSpecs are flagged as errors.
//
// Usage: xla_extract_names_from_specs graph-base [master-spec-path]+ bazel-path
// graph-base: base path to remove on GraphDefs in MasterSpecs
// master-specs: DRAGNN model MasterSpecs (includes base-path)
// bazel-path: Bazel definition output file
#include <string>
#include <vector>
#include "dragnn/runtime/xla/xla_spec_build_utils.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
int main(int argc, char **argv) {
tensorflow::port::InitMain(argv[0], &argc, &argv);
if (argc < 5) {
LOG(FATAL)
<< "Usage: xla_extract_names_from_specs"
" graph-base [master-spec-path]+ bazel-path\n"
" graph-base: base path to remove on GraphDefs in MasterSpecs\n"
" master-specs: DRAGNN model MasterSpecs (includes base-path)\n"
" bazel-path: Bazel definition output file\n";
}
const char *base_path = argv[1];
std::vector<string> master_spec_paths;
for (int i = 2; i < argc - 1; i++) {
master_spec_paths.push_back(argv[i]);
}
const string &bazel_path = argv[argc - 1];
string bazel_def;
tensorflow::strings::StrAppend(
&bazel_def,
"\"\"\"Generated by xla_extract_names_from_specs. "
"Do not edit.\"\"\"\n\n");
TF_CHECK_OK(syntaxnet::dragnn::runtime::MasterSpecsToBazelDef(
"XLA_AOT_COMPONENTS", base_path, master_spec_paths, &bazel_def));
TF_CHECK_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(),
bazel_path, bazel_def));
return 0;
}
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/xla/xla_graph_utils.h"
#include <cstddef>
#include <map>
#include <set>
#include <utility>
#include "dragnn/runtime/xla/xla_spec_utils.h"
#include "syntaxnet/base.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
const char *const kFrozenCellSubgraphSpecNodeName = "CellSubgraphSpec";
namespace {
// Splits |tensor_name| ("name" or "name:index") into a node name and output
// index and stores them in |id|. On error, returns non-OK.
tensorflow::Status FillXlaTensorId(const string &tensor_name,
                                   tensorflow::tf2xla::TensorId *id) {
  string node_name;
  uint32 output_index;
  TF_RETURN_IF_ERROR(ParseTensorName(tensor_name, &node_name, &output_index));
  id->set_node_name(node_name);
  id->set_output_index(output_index);
  return tensorflow::Status::OK();
}
// Loads the |shape_proto| from the placeholder |node|. On error, returns
// non-OK.
tensorflow::Status GetPlaceholderShape(
    const tensorflow::NodeDef &node,
    tensorflow::TensorShapeProto *shape_proto) {
  // Only Placeholder nodes are expected here; their "shape" attr carries the
  // feed shape.
  if (node.op() == "Placeholder") {
    return tensorflow::GetNodeAttr(node, "shape", shape_proto);
  }
  return tensorflow::errors::InvalidArgument("Input node '", node.name(),
                                             "' is not a Placeholder");
}
} // namespace
// Loads |graph_def| from |frozen_graph_def_path|. A ".pbtxt" suffix selects
// text format; anything else is read as a binary proto.
tensorflow::Status LoadFrozenGraphDef(const string &frozen_graph_def_path,
                                      tensorflow::GraphDef *graph_def) {
  tensorflow::Env *const env = tensorflow::Env::Default();
  if (!tensorflow::str_util::EndsWith(frozen_graph_def_path, ".pbtxt")) {
    return tensorflow::ReadBinaryProto(env, frozen_graph_def_path, graph_def);
  }
  return tensorflow::ReadTextProto(env, frozen_graph_def_path, graph_def);
}
// Saves |graph_def| to |frozen_graph_def_path| in binary format. Uses
// deterministic serialization so repeated saves of the same graph produce
// byte-identical files (protobuf map/attr order is not otherwise stable).
// On error, returns non-OK.
tensorflow::Status SaveFrozenGraphDef(const string &frozen_graph_def_path,
                                      const tensorflow::GraphDef &graph_def) {
  const std::size_t size = graph_def.ByteSizeLong();
  string data(size, '\0');
  if (size > 0) {
    tensorflow::protobuf::io::ArrayOutputStream array_stream(&data[0], size);
    tensorflow::protobuf::io::CodedOutputStream output_stream(&array_stream);
    output_stream.SetSerializationDeterministic(true);
    graph_def.SerializeWithCachedSizes(&output_stream);
    // The cast avoids a signed/unsigned comparison between size_t and the
    // int64 returned by ByteCount(), which is never negative here.
    if (output_stream.HadError() ||
        size != static_cast<std::size_t>(output_stream.ByteCount())) {
      return tensorflow::errors::InvalidArgument("Cannot serialize GraphDef");
    }
  }
  return tensorflow::WriteStringToFile(tensorflow::Env::Default(),
                                       frozen_graph_def_path, data);
}
// Fills in |name| and |index| given the |tensor_name| of the form "name" or
// "name:index". On error, changes nothing and returns non-OK.
tensorflow::Status ParseTensorName(const string &tensor_name, string *name,
                                   uint32 *index) {
  // Reject the empty name explicitly; previously it silently produced an
  // empty node name with index 0.
  if (tensor_name.empty()) {
    return tensorflow::errors::InvalidArgument("Empty tensor name");
  }
  if (tensor_name[0] == '^') {
    return tensorflow::errors::InvalidArgument(
        "Cannot parse name of control input '", tensor_name, "'");
  }

  // Parse into a local first so |index| is untouched on error, as the header
  // contract promises ("changes nothing"); safe_strtou32 may otherwise
  // clobber the output on a failed parse.
  uint32 parsed_index = 0;  // no colon; assume output 0
  const auto colon_index = tensor_name.rfind(':');
  if (colon_index != string::npos) {
    const string output_str = tensor_name.substr(colon_index + 1);
    if (!tensorflow::strings::safe_strtou32(output_str, &parsed_index)) {
      return tensorflow::errors::InvalidArgument("Malformed tensor name ",
                                                 tensor_name);
    }
  }

  // NB: If |colon_index| is string::npos, takes the whole string as desired.
  *name = tensor_name.substr(0, colon_index);
  *index = parsed_index;
  return tensorflow::Status::OK();
}
// Extracts the |cell_subgraph_spec| embedded in the frozen |graph_def| and
// builds the tf2xla |xla_config| (feeds from spec inputs, fetches from spec
// outputs, first occurrence wins for aliased output tensors). On error,
// returns non-OK.
tensorflow::Status GetSpecAndMakeXlaConfig(
    const tensorflow::GraphDef &graph_def, CellSubgraphSpec *cell_subgraph_spec,
    tensorflow::tf2xla::Config *xla_config) {
  // Maps the node name to its corresponding node in the GraphDef.
  std::map<string, const tensorflow::NodeDef *> node_name_map;
  for (const tensorflow::NodeDef &node : graph_def.node()) {
    node_name_map[node.name()] = &node;
  }

  // Looks for a node called |name| in |graph_def|. If present, returns OK
  // and fills in |*node|, otherwise returns non-OK.
  auto lookup_node = [&](const string &name, const tensorflow::NodeDef **node) {
    const auto it = node_name_map.find(name);
    if (it == node_name_map.end()) {
      return tensorflow::errors::NotFound("Cannot find node ", name);
    }
    *node = it->second;
    return tensorflow::Status::OK();
  };

  // Retrieves the CellSubgraphSpec from the frozen graph. Uses the shared
  // constant instead of duplicating the "CellSubgraphSpec" literal.
  const tensorflow::NodeDef *spec_node = nullptr;
  TF_RETURN_IF_ERROR(lookup_node(kFrozenCellSubgraphSpecNodeName, &spec_node));
  const auto value_it = spec_node->attr().find("value");
  if (value_it == spec_node->attr().end()) {
    return tensorflow::errors::NotFound("Cannot find CellSubgraphSpec value");
  }
  // Guard against a malformed node whose value tensor has no string entries;
  // calling string_val(0) on an empty repeated field would crash.
  if (value_it->second.tensor().string_val_size() == 0) {
    return tensorflow::errors::InvalidArgument(
        "CellSubgraphSpec value tensor is empty");
  }
  if (!cell_subgraph_spec->ParseFromString(
          value_it->second.tensor().string_val(0))) {
    return tensorflow::errors::InvalidArgument(
        "Failed to parse CellSubgraphSpec");
  }
  VLOG(1) << "CellSubgraphSpec: " << cell_subgraph_spec->DebugString();

  // Builds the Config feeds: one per spec input, with the shape recovered
  // from the corresponding Placeholder node.
  for (const auto &input : cell_subgraph_spec->input()) {
    auto *feed = xla_config->add_feed();
    feed->set_name(MakeXlaInputLayerName(input.name()));
    TF_RETURN_IF_ERROR(FillXlaTensorId(input.tensor(), feed->mutable_id()));
    const tensorflow::NodeDef *input_node;
    TF_RETURN_IF_ERROR(lookup_node(feed->id().node_name(), &input_node));
    TF_RETURN_IF_ERROR(GetPlaceholderShape(*input_node, feed->mutable_shape()));
  }

  // Builds the Config fetches and alias map.
  std::set<string> output_tensors;
  for (const auto &output : cell_subgraph_spec->output()) {
    if (output_tensors.insert(output.tensor()).second) {
      // The first time a tensor is encountered, this adds a fetch along with
      // its name. The remaining names associated with the same tensor
      // (aliases) are handled by InitializeOutputLayers.
      auto *fetch = xla_config->add_fetch();
      fetch->set_name(MakeXlaOutputLayerName(output.name()));
      TF_RETURN_IF_ERROR(FillXlaTensorId(output.tensor(), fetch->mutable_id()));
    }
  }
  VLOG(1) << "Config: " << xla_config->DebugString();
  return tensorflow::Status::OK();
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Utils for working with frozen GraphDefs used by XLA-based DRAGNN runtime
// models.
#ifndef DRAGNN_RUNTIME_XLA_XLA_GRAPH_UTILS_H_
#define DRAGNN_RUNTIME_XLA_XLA_GRAPH_UTILS_H_
#include <string>
#include "dragnn/protos/export.pb.h"
#include "syntaxnet/base.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// The name of the node in the frozen GraphDef (for a particular component)
// that contains the serialized CellSubgraphSpec.
extern const char *const kFrozenCellSubgraphSpecNodeName;
// Loads a GraphDef file from the |frozen_graph_def_path| into the |graph_def|.
// Assumes binary proto unless |frozen_graph_def_path| ends with ".pbtxt", in
// which case it assumes text proto format. On error, returns non-OK.
tensorflow::Status LoadFrozenGraphDef(const string &frozen_graph_def_path,
tensorflow::GraphDef *graph_def);
// Saves a GraphDef |graph_def| in the file |frozen_graph_def_path|. Uses
// deterministic serialization to avoid churn due to attr map order.
// Always writes in binary format. On error, returns non-OK.
tensorflow::Status SaveFrozenGraphDef(const string &frozen_graph_def_path,
const tensorflow::GraphDef &graph_def);
// Fills in |name| and |index| given the |tensor_name| of the form
// "name" or "name:index". On error, changes nothing and returns non-OK.
tensorflow::Status ParseTensorName(const string &tensor_name, string *name,
uint32 *index);
// Given a frozen |graph_def|, extracts the |cell_subgraph_spec| stored within
// it, and generates the |xla_config| proto. Whenever an output tensor is
// aliased, the output in |xla_config| is taken from the first occurrence of
// the
// tensor in |cell_subgraph_spec| (aliases are resolved in the XLA component
// in InitializeOutputLayers). On error, returns non-OK.
tensorflow::Status GetSpecAndMakeXlaConfig(
const tensorflow::GraphDef &graph_def, CellSubgraphSpec *cell_subgraph_spec,
tensorflow::tf2xla::Config *xla_config);
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_XLA_XLA_GRAPH_UTILS_H_
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment