Commit edea2b67 authored by Terry Koo's avatar Terry Koo
Browse files

Remove runtime because reasons.

parent a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/variable_store.h"
#include <stddef.h>
#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/runtime.pb.h"
#include "dragnn/runtime/test/fake_variable_store.h"
#include "dragnn/runtime/test/helpers.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Tests that VariableStore::Lookup() reports an error when asked for a vector
// whose underlying area has no sub-views (a vector requires exactly one).
TEST(VariableStoreTest, LookupEmptyVector) {
  SimpleFakeVariableStore store;
  Vector<uint32> result;
  store.MockLookup<uint32>({0}, {});
  EXPECT_THAT(store.Lookup("empty", &result),
              test::IsErrorWithSubstr(
                  "Vector variable 'empty' should have 1 sub-view but has 0"));
}
// Tests that VariableStore::Lookup() fails to retrieve a vector if the
// variable's dimensions disagree with the retrieved data, or are missing.
TEST(VariableStoreTest, LookupVectorWrongDimensions) {
  SimpleFakeVariableStore store;
  Vector<float> vector;
  // Dimensions should indicate number of logical elements (1), not bytes (4).
  store.MockLookup<char>({4}, {{'1', '2', '3', '4'}});
  EXPECT_THAT(store.Lookup("wrongdim_1", &vector),
              test::IsErrorWithSubstr(
                  "Vector size (1) disagrees with dimensions[0] (4)"));
  // Missing dimensions raise errors.
  store.MockLookup<char>({}, {{'1', '2', '3', '4'}});
  EXPECT_THAT(store.Lookup("nodims", &vector),
              test::IsErrorWithSubstr("Expected 1 dimensions, got 0"));
}
// Tests that VariableStore::Lookup() fails to retrieve a vector if the
// underlying area is not divisible into elements of sizeof(T) bytes, and
// succeeds when the area divides evenly.
TEST(VariableStoreTest, LookupVector) {
  SimpleFakeVariableStore store;
  Vector<uint32> vector32;
  Vector<uint64> vector64;
  // 6 bytes is not divisible by sizeof(uint32) == 4.
  store.MockLookup<char>({6}, {{'1', '2', '3', '4', '5', '6'}});
  EXPECT_THAT(
      store.Lookup("123456", &vector32),
      test::IsErrorWithSubstr(
          "Vector variable '123456' does not divide into elements of size 4"));
  // Nor by sizeof(uint64) == 8.
  store.MockLookup<char>({6}, {{'1', '2', '3', '4', '5', '6'}});
  EXPECT_THAT(
      store.Lookup("123456", &vector64),
      test::IsErrorWithSubstr(
          "Vector variable '123456' does not divide into elements of size 8"));
  // 8 bytes divides into two uint32 elements; check byte-level content.
  store.MockLookup<char>({2}, {{'1', '2', '3', '4', '5', '6', '7', '8'}});
  TF_EXPECT_OK(store.Lookup("12345678", &vector32));
  EXPECT_EQ(vector32.size(), 2);
  const string bytes32(reinterpret_cast<const char *>(vector32.data()), 8);
  EXPECT_EQ(bytes32, "12345678");
  // A single uint64 element is retrieved verbatim.
  store.MockLookup<uint64>({1}, {{7777}});
  TF_EXPECT_OK(store.Lookup("12345678", &vector64));
  EXPECT_EQ(vector64.size(), 1);
  EXPECT_EQ(vector64[0], 7777);
}
// Tests that the VariableStore fails to lookup a matrix if its dimensions are
// mismatched or missing.
TEST(VariableStoreTest, LookupMatrixWrongDimensions) {
  SimpleFakeVariableStore store;
  Matrix<float> matrix;
  // Missing dimensions raise errors; a matrix needs exactly two.
  store.MockLookup<char>({}, {{'1', '2', '3', '4'}});
  EXPECT_THAT(store.Lookup("nodims", &matrix),
              test::IsErrorWithSubstr("Expected 2 dimensions, got 0"));
  // Wrong number of columns returned.
  store.MockLookup<char>({1, 2}, {{'1', '2', '3', '4'}});
  EXPECT_THAT(store.Lookup("wrongcols", &matrix),
              test::IsErrorWithSubstr(
                  "Matrix columns (1) disagrees with dimensions[1] (2)"));
  // Wrong number of rows returned.
  store.MockLookup<char>({3, 1}, {{'1', '2', '3', '4'}});
  EXPECT_THAT(store.Lookup("wrongrows", &matrix),
              test::IsErrorWithSubstr(
                  "Matrix rows (1) disagrees with dimensions[0] (3)"));
}
// Tests that VariableStore::Lookup() fails to retrieve a row-major matrix if
// the underlying area is not divisible into elements of sizeof(T) bytes, and
// succeeds row-by-row when it is.
TEST(VariableStoreTest, LookupRowMajorMatrix) {
  SimpleFakeVariableStore store;
  Matrix<uint32> matrix32;
  Matrix<uint64> matrix64;
  // 6-byte rows do not divide into uint32 (4-byte) elements.
  store.MockLookup<char>(
      {6, 2}, ReplicateRows<char>({'1', '2', '3', '4', '5', '6'}, 6));
  EXPECT_THAT(
      store.Lookup("123456", &matrix32),
      test::IsErrorWithSubstr(
          "Matrix variable '123456' does not divide into elements of size 4"));
  // Nor into uint64 (8-byte) elements.
  store.MockLookup<char>(
      {6, 2}, ReplicateRows<char>({'1', '2', '3', '4', '5', '6'}, 6));
  EXPECT_THAT(
      store.Lookup("123456", &matrix64),
      test::IsErrorWithSubstr(
          "Matrix variable '123456' does not divide into elements of size 8"));
  // 8-byte rows divide into two uint32 columns; check each row's bytes.
  store.MockLookup<char>(
      {8, 2}, ReplicateRows<char>({'1', '2', '3', '4', '5', '6', '7', '8'}, 8));
  TF_EXPECT_OK(store.Lookup("12345678", &matrix32));
  EXPECT_EQ(matrix32.num_rows(), 8);
  EXPECT_EQ(matrix32.num_columns(), 2);
  for (size_t i = 0; i < matrix32.num_rows(); ++i) {
    const string bytes32(reinterpret_cast<const char *>(matrix32.row(i).data()),
                         8);
    EXPECT_EQ(bytes32, "12345678");
  }
  // Single-column uint64 rows are retrieved verbatim.
  store.MockLookup({8, 1}, ReplicateRows<uint64>({7777}, 8));
  TF_EXPECT_OK(store.Lookup("12345678", &matrix64));
  EXPECT_EQ(matrix64.num_rows(), 8);
  EXPECT_EQ(matrix64.num_columns(), 1);
  for (size_t i = 0; i < matrix64.num_rows(); ++i) {
    EXPECT_EQ(matrix64.row(i)[0], 7777);
  }
}
// Tests that the VariableStore fails to lookup a blocked matrix if its
// dimensions are mismatched.  A blocked matrix requires three dimensions:
// rows, columns, and block size.
TEST(VariableStoreTest, BlockedLookupWrongDimensions) {
  SimpleFakeVariableStore store;
  BlockedMatrix<float> matrix;
  // Missing dimensions raise errors.
  store.MockLookup<char>({}, {{'1', '2', '3', '4'}});
  EXPECT_THAT(store.Lookup("nodims", &matrix),
              test::IsErrorWithSubstr("Expected 3 dimensions, got 0"));
  // Wrong number of columns returned.
  store.MockLookup<char>({1, 2, 1}, {{'1', '2', '3', '4'}});
  EXPECT_THAT(store.Lookup("wrongcols", &matrix),
              test::IsErrorWithSubstr("Rows * cols (2) != area view size (1)"));
  // Wrong number of rows returned.
  store.MockLookup<char>({3, 1, 1}, {{'1', '2', '3', '4'}});
  EXPECT_THAT(store.Lookup("wrongrows", &matrix),
              test::IsErrorWithSubstr("Rows * cols (3) != area view size (1)"));
  // Wrong area view size: two floats (8 bytes) where the declared block size
  // (1) times sizeof(float) (4) implies 4 bytes.
  store.MockLookup<float>({1, 1, 1}, {{1.0f, 2.0f}});
  EXPECT_THAT(
      store.Lookup("wrongviewsize", &matrix),
      test::IsErrorWithSubstr("Area view size (8) doesn't correspond to block "
                              "size (1) times data type size (4)"));
}
// Tests a successful blocked-matrix lookup of doubles: dimensions and content
// must round-trip through the store exactly.
TEST(VariableStoreTest, DoubleBlockedLookup) {
  // BlockedMatrix::Reset() will fail if there is any alignment padding, so we
  // construct an appropriate block size.
  static_assert(internal::kAlignmentBytes % sizeof(double) == 0,
                "Alignment requirement is too small");
  constexpr int kBlockSize = internal::kAlignmentBytes / sizeof(double);
  constexpr int kNumSubMatrices = 3;
  constexpr int kNumRows = 10;
  constexpr int kNumColumns = kNumSubMatrices * kBlockSize;
  constexpr int kNumBlocks = kNumSubMatrices * kNumRows;
  // Fill a data matrix with consecutively increasing values, one block per row,
  // so content mismatches are easy to localize.
  std::vector<std::vector<double>> data;
  double value = 0.0;
  for (int block = 0; block < kNumBlocks; ++block) {
    data.emplace_back();
    for (int i = 0; i < kBlockSize; ++i) data.back().push_back(value++);
  }
  SimpleFakeVariableStore store;
  BlockedMatrix<double> matrix;
  store.MockLookup<double>({kNumRows, kNumColumns, kBlockSize}, data);
  TF_EXPECT_OK(store.Lookup("small_matrix_lookup", &matrix));
  EXPECT_EQ(matrix.num_rows(), kNumRows);
  EXPECT_EQ(matrix.num_columns(), kNumColumns);
  EXPECT_EQ(matrix.block_size(), kBlockSize);
  EXPECT_EQ(matrix.num_vectors(), kNumBlocks);
  // The retrieved blocks must replay the same increasing sequence.
  double expected = 0.0;
  for (int i = 0; i < kNumBlocks; ++i) {
    for (double value : matrix.vector(i)) EXPECT_EQ(value, expected++);
  }
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/variable_store_wrappers.h"
#include <algorithm>
#include <tuple>
#include <utility>
#include <vector>
#include "dragnn/runtime/flexible_matrix_kernel.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Returns the name under which the exponential-moving-averaged version of the
// variable named |name| is stored.
string GetAveragedName(const string &name) {
  return name + "/ExponentialMovingAverage";
}
// Rounds |rows| up to the nearest multiple of |multiple|; e.g., PadRows(6, 4)
// returns 8.  Requires that |multiple| is positive.  Used to pre-calculate the
// dimension of a blocked matrix without reading the entire matrix.
int PadRows(int rows, int multiple) {
  DCHECK_GT(multiple, 0);
  const int num_blocks = (rows + multiple - 1) / multiple;  // ceil division
  return num_blocks * multiple;
}
// Estimates the effective speed of a blocked matrix kernel.  Blocked kernels
// may do extra work because each AVX/SSE register holds multiple values, so
// the |base_gflops| rate is discounted by the fraction of useful rows.
float EffectiveGflops(int rows, int block_dim, float base_gflops) {
  const float padded_rows = PadRows(rows, block_dim);
  const float utilization = rows / padded_rows;
  return utilization * base_gflops;
}
} // namespace
// Takes ownership of the |variable_store| and records whether fallback to
// non-averaged variables is permitted.
TryAveragedVariableStoreWrapper::TryAveragedVariableStoreWrapper(
    std::unique_ptr<VariableStore> variable_store, bool allow_fallback)
    : wrapped_variable_store_(std::move(variable_store)),
      allow_fallback_(allow_fallback) {}
tensorflow::Status TryAveragedVariableStoreWrapper::Lookup(
    const string &name, VariableSpec::Format format,
    std::vector<size_t> *dimensions, AlignedArea *area) {
  // Prefer the moving-averaged version of the variable, if the wrapped store
  // has one.
  const string averaged_name = GetAveragedName(name);
  const tensorflow::Status averaged_status =
      wrapped_variable_store_->Lookup(averaged_name, format, dimensions, area);
  if (averaged_status.ok()) {
    LOG(INFO) << "Using averaged variable: " << averaged_name;
    return averaged_status;
  }

  // Otherwise, either substitute the non-averaged version or fail, depending
  // on the configured fallback policy.
  if (allow_fallback_) {
    LOG(INFO) << "Falling back to non-averaged variable: " << name;
    return wrapped_variable_store_->Lookup(name, format, dimensions, area);
  }
  return tensorflow::errors::InvalidArgument(
      "Failed to retrieve averaged variable '", averaged_name,
      "' for variable '", name, "': ", averaged_status.error_message());
}
// Forwards Close() to the wrapped store.
tensorflow::Status TryAveragedVariableStoreWrapper::Close() {
  return wrapped_variable_store_->Close();
}
// Takes ownership of the |variable_store| whose lookups will be captured.
CaptureUsedVariableStoreWrapper::CaptureUsedVariableStoreWrapper(
    std::unique_ptr<VariableStore> variable_store)
    : wrapped_variable_store_(std::move(variable_store)) {}
// Forwards the lookup to the wrapped store and, on success, records the
// variable's dimensions and area.  A repeated lookup of the same (name,
// format) key overwrites the previously captured value in place, preserving
// the variable's original position in the capture order.
tensorflow::Status CaptureUsedVariableStoreWrapper::Lookup(
    const string &name, VariableSpec::Format format,
    std::vector<size_t> *dimensions, AlignedArea *area) {
  tensorflow::Status status =
      wrapped_variable_store_->Lookup(name, format, dimensions, area);
  if (status.ok()) {
    // Capture the variable if the wrapped store's Lookup() succeeds.  Use a
    // single map lookup (the original code did up to three) and move the
    // captured data into place instead of copying it.
    VariableKey key(name, format);
    VariableValue value(*dimensions, *area);
    const auto it = index_.find(key);
    if (it != index_.end()) {
      // Re-lookup: the key is unchanged, so only replace the value.
      variables_[it->second].second = std::move(value);
    } else {
      // New variable: index it, then append.  |key| must be copied into the
      // index before it is moved into the list.
      index_[key] = variables_.size();
      variables_.emplace_back(std::move(key), std::move(value));
    }
  }
  return status;
}
// Forwards Close() to the wrapped store.
tensorflow::Status CaptureUsedVariableStoreWrapper::Close() {
  return wrapped_variable_store_->Close();
}
// Takes ownership of the |variable_store| used for the underlying lookups.
FlexibleMatrixVariableStoreWrapper::FlexibleMatrixVariableStoreWrapper(
    std::unique_ptr<VariableStore> variable_store)
    : wrapped_variable_store_(std::move(variable_store)) {}
tensorflow::Status FlexibleMatrixVariableStoreWrapper::Lookup(
    const string &name, VariableSpec::Format format,
    std::vector<size_t> *dimensions, AlignedArea *area) {
  // Forward requests that don't match the relevant suffix.
  tensorflow::StringPiece name_piece = name;
  if (!tensorflow::str_util::ConsumeSuffix(&name_piece,
                                           FlexibleMatrixKernel::kSuffix)) {
    return wrapped_variable_store_->Lookup(name, format, dimensions, area);
  }
  // |name_piece| now holds the name with the suffix stripped off.
  const string basename = name_piece.ToString();

  // Fetch the non-blocked, non-transposed version of the matrix.  This wrapper
  // will be nested inside the capturing wrapper, so we can do multiple lookups
  // without capturing more variables than we need.  Only the column count is
  // used below, to estimate throughput.
  Matrix<float> plain_matrix;
  TF_RETURN_IF_ERROR(wrapped_variable_store_->Lookup(basename, &plain_matrix));
  const int output_dimension = plain_matrix.num_columns();

  // Performance estimates for different methods.  A mix of 32/48 blocked
  // matrices got 28 GFLOPS, whereas only unblocked got 2.8 GFLOPS.  Each
  // candidate is (estimated GFLOPS, variable format, variable name).
  using Candidate = std::tuple<float, VariableSpec::Format, string>;
  const std::vector<Candidate> candidates = {
      Candidate(2.8f, VariableSpec::FORMAT_ROW_MAJOR_MATRIX,
                tensorflow::strings::StrCat(basename, "/transposed")),
      Candidate(EffectiveGflops(output_dimension, 32, 25.0f),
                VariableSpec::FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX,
                tensorflow::strings::StrCat(basename, "/matrix/blocked32")),
      Candidate(EffectiveGflops(output_dimension, 48, 25.0f),
                VariableSpec::FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX,
                tensorflow::strings::StrCat(basename, "/matrix/blocked48"))};

  // std::tuple comparison is lexicographic, so max_element() selects the
  // highest estimated GFLOPS, breaking ties by format and then by name.
  const auto max_it = std::max_element(candidates.begin(), candidates.end());
  const VariableSpec::Format argmax_format = std::get<1>(*max_it);
  const string &argmax_name = std::get<2>(*max_it);

  // The requested |format| must match the best format.  If not, return error
  // and wait until the proper format is requested.
  if (format != argmax_format) {
    return tensorflow::errors::FailedPrecondition(
        "Sub-optimal matrix format: ", VariableSpec::Format_Name(format), " (",
        VariableSpec::Format_Name(argmax_format), " is best)");
  }
  return wrapped_variable_store_->Lookup(argmax_name, format, dimensions, area);
}
// Forwards Close() to the wrapped store.
tensorflow::Status FlexibleMatrixVariableStoreWrapper::Close() {
  return wrapped_variable_store_->Close();
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// A set of VariableStore wrappers that provide compositional functionality.
// These are intended for offline processing and experimentation; avoid using
// these in production, where ArrayVariableStore and its subclasses should be
// used instead.
#ifndef DRAGNN_RUNTIME_VARIABLE_STORE_WRAPPERS_H_
#define DRAGNN_RUNTIME_VARIABLE_STORE_WRAPPERS_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "dragnn/protos/runtime.pb.h"
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// A wrapper that looks for an averaged version of each variable in the wrapped
// store, and failing that optionally falls back to the non-averaged version.
class TryAveragedVariableStoreWrapper : public VariableStore {
 public:
  // Wraps the |variable_store|.  If |allow_fallback| is true, then when the
  // averaged version is missing the non-averaged version can be substituted.
  // If false, a missing averaged version is a lookup error.
  explicit TryAveragedVariableStoreWrapper(
      std::unique_ptr<VariableStore> variable_store,
      bool allow_fallback = false);

  // Implements VariableStore.
  using VariableStore::Lookup;  // import Lookup<T>() convenience methods
  tensorflow::Status Lookup(const string &name, VariableSpec::Format format,
                            std::vector<size_t> *dimensions,
                            AlignedArea *area) override;
  tensorflow::Status Close() override;

 private:
  // Wrapped variable store.  Owned; closed via Close().
  const std::unique_ptr<VariableStore> wrapped_variable_store_;

  // Whether to allow fallback to the non-averaged variable.
  const bool allow_fallback_;
};
// A wrapper that captures each successfully retrieved variable.  Useful for
// finding the exact set of variables used by some set of DRAGNN components.
class CaptureUsedVariableStoreWrapper : public VariableStore {
 public:
  // `Variables` is a list of captured variables, in order that they are
  // captured.  We want to preserve the order, so that arrays are sequential in
  // memory.  `VariableKey` is name/format metadata used to uniquely identify
  // a variable; duplicate lookups to the same variable will not capture it
  // twice, and its position in the list will be the first position.
  using VariableKey = std::pair<string, VariableSpec::Format>;
  using VariableValue = std::pair<std::vector<size_t>, AlignedArea>;
  using Variables = std::vector<std::pair<VariableKey, VariableValue>>;

  // Wraps the |variable_store|.
  explicit CaptureUsedVariableStoreWrapper(
      std::unique_ptr<VariableStore> variable_store);

  // Implements VariableStore.
  using VariableStore::Lookup;  // import Lookup<T>() convenience methods
  tensorflow::Status Lookup(const string &name, VariableSpec::Format format,
                            std::vector<size_t> *dimensions,
                            AlignedArea *area) override;
  tensorflow::Status Close() override;

  // Returns the current set of captured variables.  The variable content in
  // the returned mapping is valid while this lives.
  const Variables &variables() const { return variables_; }

 private:
  // Wrapped variable store.  Owned; closed via Close().
  const std::unique_ptr<VariableStore> wrapped_variable_store_;

  // Current set of captured variables, in capture order.
  Variables variables_;

  // Indexes key --> position in variables_ list, to deduplicate lookups.
  std::map<VariableKey, int> index_;
};
// A wrapper that selects a matrix format for the FlexibleMatrixKernel.  This
// could be done in the FlexibleMatrixKernel itself, but factoring it into this
// wrapper allows the selection to occur at model construction time instead of
// at model loading time.
class FlexibleMatrixVariableStoreWrapper : public VariableStore {
 public:
  // Wraps the |variable_store|.
  explicit FlexibleMatrixVariableStoreWrapper(
      std::unique_ptr<VariableStore> variable_store);

  // Looks up the variable named |name| with format |format|, returning its
  // shape in |dimensions| and its data in |area|.  On error, returns non-OK.
  //
  // If the |name| does not end in FlexibleMatrixKernel::kSuffix, passes the
  // request along to the |wrapped_variable_store_|.  Otherwise, if |name| is
  // "foo/<kSuffix>", estimates the throughput of the matrix "foo" in various
  // formats (assuming the workload is matrix-vector multiplications), selects
  // the fastest format, and returns the matrix in that format.
  //
  // It is an error if the selected matrix format does not match the requested
  // variable |format| (e.g., non-blocked vs blocked).  The FlexibleMatrixKernel
  // should request the variable in all relevant variable formats, so eventually
  // it will issue a request in a matching format.
  tensorflow::Status Lookup(const string &name, VariableSpec::Format format,
                            std::vector<size_t> *dimensions,
                            AlignedArea *area) override;
  using VariableStore::Lookup;  // import Lookup<T>() convenience methods

  // Implements VariableStore.
  tensorflow::Status Close() override;

 private:
  // Wrapped variable store.  Owned; closed via Close().
  const std::unique_ptr<VariableStore> wrapped_variable_store_;
};
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_VARIABLE_STORE_WRAPPERS_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/variable_store_wrappers.h"
#include <stddef.h>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/runtime.pb.h"
#include "dragnn/runtime/flexible_matrix_kernel.h"
#include "dragnn/runtime/math/transformations.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/test/fake_variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Returns a variable store with some default entries for tests.  Specifically,
// "foo" has an averaged version while "bar" does not.
std::unique_ptr<VariableStore> NewVariableStore() {
  std::unique_ptr<FakeVariableStore> store(new FakeVariableStore());
  store->AddOrDie("foo", {{1.0, 2.0},  //
                          {3.0, 4.0}});
  store->AddOrDie("foo/ExponentialMovingAverage", {{10.0, 20.0},  //
                                                   {30.0, 40.0}});
  store->AddOrDie("bar", {{10.0, 9.0, 8.0},  //
                          {7.0, 6.0, 5.0}});
  // std::move() is required to convert unique_ptr<FakeVariableStore> to
  // unique_ptr<VariableStore> on return.
  return std::move(store);
}
// Expects that the |vector| holds exactly the elements of |data|.
template <typename T>
void ExpectVector(Vector<T> vector, const std::vector<T> &data) {
  ASSERT_EQ(vector.size(), data.size());
  size_t index = 0;
  for (const T &expected : data) {
    EXPECT_EQ(vector[index], expected);
    ++index;
  }
}
// Expects that the |matrix| holds exactly the rows of |data|.
void ExpectMatrix(Matrix<float> matrix,
                  const std::vector<std::vector<float>> &data) {
  ASSERT_EQ(matrix.num_rows(), data.size());
  if (data.empty()) return;
  ASSERT_EQ(matrix.num_columns(), data[0].size());
  size_t row_index = 0;
  for (const auto &expected_row : data) {
    ExpectVector(matrix.row(row_index), expected_row);
    ++row_index;
  }
}
// Tests that the averaging wrapper uses the averaged version of a variable if
// available, the non-averaged version failing that, and errors out otherwise.
TEST(TryAveragedVariableStoreWrapperTest, FallbackAllowed) {
  TryAveragedVariableStoreWrapper store(NewVariableStore(),
                                        /*allow_fallback=*/true);
  Matrix<float> foo_averaged;
  Matrix<float> bar_non_averaged;
  Matrix<float> unused_matrix;
  // "foo" has an averaged version; "bar" does not but falls back; "missing"
  // has neither.
  TF_ASSERT_OK(store.Lookup("foo", &foo_averaged));
  TF_ASSERT_OK(store.Lookup("bar", &bar_non_averaged));
  EXPECT_THAT(store.Lookup("missing", &unused_matrix),
              test::IsErrorWithSubstr("Unknown variable"));
  TF_EXPECT_OK(store.Close());
  // "foo" must contain the averaged values, not {{1,2},{3,4}}.
  ExpectMatrix(foo_averaged, {{10.0, 20.0},  //
                              {30.0, 40.0}});
  ExpectMatrix(bar_non_averaged, {{10.0, 9.0, 8.0},  //
                                  {7.0, 6.0, 5.0}});
}
// As above, but with fallback disabled (the default behavior): any variable
// lacking an averaged version fails to resolve.
TEST(TryAveragedVariableStoreWrapperTest, FallbackForbidden) {
  TryAveragedVariableStoreWrapper store(NewVariableStore());
  Matrix<float> foo_averaged;
  Matrix<float> bar_non_averaged;
  Matrix<float> unused_matrix;
  TF_ASSERT_OK(store.Lookup("foo", &foo_averaged));
  EXPECT_THAT(store.Lookup("bar", &bar_non_averaged),
              test::IsErrorWithSubstr("Failed to retrieve averaged variable "
                                      "'bar/ExponentialMovingAverage' for "
                                      "variable 'bar'"));
  EXPECT_THAT(store.Lookup("missing", &unused_matrix),
              test::IsErrorWithSubstr("Failed to retrieve averaged variable "
                                      "'missing/ExponentialMovingAverage' for "
                                      "variable 'missing'"));
  TF_EXPECT_OK(store.Close());
  ExpectMatrix(foo_averaged, {{10.0, 20.0},  //
                              {30.0, 40.0}});
}
// Tests that the capturing wrapper correctly records the set of variables that
// have been looked up, in lookup order, excluding failed lookups.
TEST(CaptureUsedVariableStoreWrapperTest, Capturing) {
  CaptureUsedVariableStoreWrapper store(NewVariableStore());
  Vector<float> unused_vector;
  Matrix<float> unused_row_major_matrix;
  // Try a completely missing variable.  As a failed lookup, this should not
  // appear among the captured variables.
  EXPECT_THAT(store.Lookup("missing", &unused_vector),
              test::IsErrorWithSubstr("Unknown variable"));
  // Look up one variable of each type.
  TF_ASSERT_OK(store.Lookup("foo", &unused_vector));
  TF_ASSERT_OK(store.Lookup("bar", &unused_row_major_matrix));
  TF_EXPECT_OK(store.Close());
  // Check the names and formats of the captured variables.
  const auto &variables = store.variables();
  ASSERT_EQ(variables.size(), 2);
  // The variables must be returned in order.  Check their names and format
  // first.
  EXPECT_EQ(variables[0].first.first, "foo");
  EXPECT_EQ(variables[0].first.second, VariableSpec::FORMAT_FLAT);
  EXPECT_EQ(variables[1].first.first, "bar");
  EXPECT_EQ(variables[1].first.second, VariableSpec::FORMAT_ROW_MAJOR_MATRIX);
  // Check the content of 'foo': a 2x2 matrix flattened to 4 elements.
  EXPECT_EQ(variables[0].second.first, std::vector<size_t>{4});
  ExpectVector(Vector<float>(variables[0].second.second.view(0)),
               {1.0, 2.0, 3.0, 4.0});
  // Check the content of 'bar'.
  EXPECT_EQ(variables[1].second.first, std::vector<size_t>({2, 3}));
  ExpectMatrix(Matrix<float>(variables[1].second.second), {{10.0, 9.0, 8.0},  //
                                                           {7.0, 6.0, 5.0}});
}
// Returns a variable store with some blocked and transposed matrices, for
// testing the flexible matrix wrapper.  Every matrix is registered under all
// three candidate names (/transposed, /matrix/blocked32, /matrix/blocked48)
// so any format the wrapper selects can be resolved.
std::unique_ptr<VariableStore> NewBlockedAndTransposedStore() {
  std::unique_ptr<FakeVariableStore> store(new FakeVariableStore());
  // A tiny matrix, which favors the non-blocked format.
  store->AddOrDie("1x1", {{1.0}});
  store->AddOrDie("1x1/transposed", {{1.0}});
  store->AddOrDie("1x1/matrix/blocked32", {{1.0}});
  store->AddOrDie("1x1/matrix/blocked48", {{1.0}});
  // A matrix that is a multiple of 32, which should favor block size 32.
  const std::vector<float> row32(32, 32.0);
  const std::vector<std::vector<float>> data32(16, row32);
  store->AddOrDie("16x32", data32);
  store->AddOrDie("16x32/transposed", data32);
  store->AddOrDie("16x32/matrix/blocked32", data32);
  store->AddOrDie("16x32/matrix/blocked48", data32);
  // A matrix that is a multiple of 48, which should favor block size 48.
  const std::vector<float> row48(48, 48.0);
  const std::vector<std::vector<float>> data48(24, row48);
  store->AddOrDie("24x48", data48);
  store->AddOrDie("24x48/transposed", data48);
  store->AddOrDie("24x48/matrix/blocked32", data48);
  store->AddOrDie("24x48/matrix/blocked48", data48);
  // std::move() is required to convert unique_ptr<FakeVariableStore> to
  // unique_ptr<VariableStore> on return.
  return std::move(store);
}
// Expects that the |blocked_matrix| matches the |num_rows|, |num_columns|, and
// |block_size| and is uniformly filled with the |value|.
void ExpectBlockedMatrix(BlockedMatrix<float> blocked_matrix, size_t num_rows,
                         size_t num_columns, size_t block_size, float value) {
  ASSERT_EQ(blocked_matrix.num_rows(), num_rows);
  ASSERT_EQ(blocked_matrix.num_columns(), num_columns);
  ASSERT_EQ(blocked_matrix.block_size(), block_size);
  const std::vector<float> expected_block(block_size, value);
  const size_t num_vectors = blocked_matrix.num_vectors();
  for (size_t vector_index = 0; vector_index < num_vectors; ++vector_index) {
    ExpectVector(blocked_matrix.vector(vector_index), expected_block);
  }
}
// Tests that the flexible matrix wrapper passes through variables that don't
// end in the right suffix, including failed lookups.
TEST(FlexibleMatrixVariableStoreWrapperTest, PassThroughIrrelevantVariables) {
  FlexibleMatrixVariableStoreWrapper store(NewBlockedAndTransposedStore());
  Vector<float> vector;
  EXPECT_THAT(store.Lookup("missing", &vector),
              test::IsErrorWithSubstr("Unknown variable"));
  // "1x1" lacks the kSuffix, so it is forwarded unmodified.
  TF_ASSERT_OK(store.Lookup("1x1", &vector));
  ExpectVector(vector, {1.0});
  TF_EXPECT_OK(store.Close());
}
// Tests that the flexible matrix wrapper selects the plain matrix format for
// tiny matrices, rejecting requests for the blocked format.
TEST(FlexibleMatrixVariableStoreWrapperTest, SelectPlainMatrixFormat) {
  FlexibleMatrixVariableStoreWrapper store(NewBlockedAndTransposedStore());
  Matrix<float> plain_matrix;
  BlockedMatrix<float> blocked_matrix;
  const string name =
      tensorflow::strings::StrCat("1x1", FlexibleMatrixKernel::kSuffix);
  // A blocked-format request must be rejected as sub-optimal ...
  EXPECT_THAT(store.Lookup(name, &blocked_matrix),
              test::IsErrorWithSubstr("Sub-optimal matrix format"));
  // ... while the plain-format request succeeds.
  TF_ASSERT_OK(store.Lookup(name, &plain_matrix));
  ExpectMatrix(plain_matrix, {{1.0}});
  TF_EXPECT_OK(store.Close());
}
// Tests that the flexible matrix wrapper selects block size 32 for a matrix
// whose size is a multiple of 32, rejecting the plain format.
TEST(FlexibleMatrixVariableStoreWrapperTest, SelectBlocked32MatrixFormat) {
  FlexibleMatrixVariableStoreWrapper store(NewBlockedAndTransposedStore());
  Matrix<float> plain_matrix;
  BlockedMatrix<float> blocked_matrix;
  const string name =
      tensorflow::strings::StrCat("16x32", FlexibleMatrixKernel::kSuffix);
  EXPECT_THAT(store.Lookup(name, &plain_matrix),
              test::IsErrorWithSubstr("Sub-optimal matrix format"));
  TF_ASSERT_OK(store.Lookup(name, &blocked_matrix));
  ExpectBlockedMatrix(blocked_matrix, 16, 32, 32, 32.0);
  TF_EXPECT_OK(store.Close());
}
// Tests that the flexible matrix wrapper selects block size 48 for a matrix
// whose size is a multiple of 48, rejecting the plain format.
TEST(FlexibleMatrixVariableStoreWrapperTest, SelectBlocked48MatrixFormat) {
  FlexibleMatrixVariableStoreWrapper store(NewBlockedAndTransposedStore());
  Matrix<float> plain_matrix;
  BlockedMatrix<float> blocked_matrix;
  const string name =
      tensorflow::strings::StrCat("24x48", FlexibleMatrixKernel::kSuffix);
  EXPECT_THAT(store.Lookup(name, &plain_matrix),
              test::IsErrorWithSubstr("Sub-optimal matrix format"));
  TF_ASSERT_OK(store.Lookup(name, &blocked_matrix));
  ExpectBlockedMatrix(blocked_matrix, 24, 48, 48, 48.0);
  TF_EXPECT_OK(store.Close());
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
# BUILD rules for the XLA-based DRAGNN runtime components: conversion of
# trained cells to XLA, JIT/AOT component implementations, and their tests.
#
# TODO(googleuser): Move XLA libs to dragnn/runtime when stable. Probably there
# should be a refactor with the Myelin libs since they are so similar.

# Per Bazel style (buildifier "load-on-top"), load() statements precede all
# other statements, including package().
load(
    "//dragnn/runtime/xla:xla_build_defs.bzl",
    "dragnn_xla_aot_components",
)
load(
    "//dragnn/runtime:multiarch.bzl",
    "dragnn_cc_multiarch_library",
    "dragnn_cc_multiarch_test",
)

package(default_visibility = ["//visibility:public"])

# Pre-generated XLA compilation outputs, used as test data.
filegroup(
    name = "test_xla_compilation_output",
    srcs = glob(["testdata/xla_compilation_output/**"]),
)

# Command-line tools.
cc_binary(
    name = "xla_extract_config",
    srcs = ["xla_extract_config.cc"],
    deps = [
        ":xla_graph_utils",
        "//dragnn/protos:export_proto_cc",
        "@org_tensorflow//tensorflow/compiler/tf2xla:tf2xla_proto",
        "@org_tensorflow//tensorflow/core:framework",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:protos_all_cc",
    ],
)

cc_binary(
    name = "xla_extract_names_from_specs",
    srcs = ["xla_extract_names_from_specs.cc"],
    deps = [
        ":xla_spec_build_utils",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Conversion of trained DRAGNN cells into XLA-compilable form.
cc_library(
    name = "xla_cell_converter",
    srcs = ["xla_cell_converter.cc"],
    hdrs = ["xla_cell_converter.h"],
    deps = [
        ":xla_graph_utils",
        "//dragnn/protos:export_proto_cc",
        "//dragnn/runtime:trained_model",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:framework_headers_lib",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:protos_all_cc",
    ],
)

cc_test(
    name = "xla_cell_converter_test",
    size = "small",
    timeout = "moderate",
    srcs = ["xla_cell_converter_test.cc"],
    data = ["//dragnn/runtime:test_rnn_tagger"],
    deps = [
        ":xla_cell_converter",
        ":xla_graph_utils",
        ":xla_spec_utils",
        "//dragnn/components/syntaxnet:syntaxnet_component",
        "//dragnn/core/test:generic",
        "//dragnn/protos:export_proto_cc",
        "//dragnn/runtime:alignment",
        "//dragnn/runtime:trained_model",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/compiler/tf2xla:tf2xla_proto",
        "@org_tensorflow//tensorflow/compiler/tf2xla:xla_jit_compiled_cpu_function",
        "@org_tensorflow//tensorflow/compiler/xla:shape_util",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:protos_all_cc",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

cc_library(
    name = "xla_compilation",
    srcs = ["xla_compilation.cc"],
    hdrs = ["xla_compilation.h"],
    deps = [
        ":xla_cell_converter",
        ":xla_graph_utils",
        ":xla_spec_utils",
        "//dragnn/protos:export_proto_cc",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime:component",
        "//dragnn/runtime:trained_model",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:protos_all_cc",
    ],
)

cc_test(
    name = "xla_compilation_test",
    size = "small",
    timeout = "moderate",
    srcs = ["xla_compilation_test.cc"],
    data = [
        ":test_xla_compilation_output",
        "//dragnn/runtime:test_rnn_tagger",
    ],
    deps = [
        ":xla_compilation",
        ":xla_spec_utils",
        "//dragnn/components/syntaxnet:syntaxnet_component",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Runtime component implementations, built for multiple architectures.
dragnn_cc_multiarch_library(
    name = "xla_dynamic_component_base",
    srcs = ["xla_dynamic_component_base.cc"],
    hdrs = ["xla_dynamic_component_base.h"],
    deps = [
        ":xla_spec_utils",
        "//dragnn/protos:export_proto_cc",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/protos:trace_proto_cc",
        "//dragnn/runtime:alignment",
        "//dragnn/runtime:component",
        "//dragnn/runtime:extensions",
        "//dragnn/runtime:fixed_embeddings",
        "//dragnn/runtime:linked_embeddings",
        "//dragnn/runtime:network_states",
        "//dragnn/runtime:session_state",
        "//dragnn/runtime:transition_system_traits",
        "//dragnn/runtime:type_keyed_set",
        "//dragnn/runtime:variable_store",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/compiler/tf2xla:xla_compiled_cpu_function",
        "@org_tensorflow//tensorflow/compiler/xla:shape_util",
        "@org_tensorflow//tensorflow/compiler/xla:xla_data_proto",
        "@org_tensorflow//tensorflow/core:framework_headers_lib",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:protos_all_cc",
    ],
)

dragnn_cc_multiarch_library(
    name = "sequence_xla_dynamic_component_mixin",
    hdrs = ["sequence_xla_dynamic_component_mixin.h"],
    deps = [
        ":xla_dynamic_component_base",
        "//dragnn/core:compute_session",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/protos:trace_proto_cc",
        "//dragnn/runtime:extensions",
        "//dragnn/runtime:network_states",
        "//dragnn/runtime:sequence_features",
        "//dragnn/runtime:sequence_links",
        "//dragnn/runtime:sequence_model",
        "//dragnn/runtime:session_state",
        "//dragnn/runtime:variable_store",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/compiler/tf2xla:xla_compiled_cpu_function",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

dragnn_cc_multiarch_test(
    name = "sequence_xla_dynamic_component_mixin_test",
    size = "small",
    srcs = ["sequence_xla_dynamic_component_mixin_test.cc"],
    deps = [
        ":xla_dynamic_component",
        ":xla_graph_utils",
        ":xla_spec_utils",
        "//dragnn/core:input_batch_cache",
        "//dragnn/core/test:generic",
        "//dragnn/protos:cell_trace_proto_cc",
        "//dragnn/protos:export_proto_cc",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/protos:trace_proto_cc",
        "//dragnn/runtime:component",
        "//dragnn/runtime:extensions",
        "//dragnn/runtime:network_states",
        "//dragnn/runtime:sequence_backend",
        "//dragnn/runtime:sequence_extractor",
        "//dragnn/runtime:sequence_predictor",
        "//dragnn/runtime/math:types",
        "//dragnn/runtime/test:network_test_base",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/compiler/xla:xla_data_proto",
        "@org_tensorflow//tensorflow/core:framework_headers_lib",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:protos_all_cc",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

dragnn_cc_multiarch_library(
    name = "xla_aot_dynamic_component",
    hdrs = ["xla_aot_dynamic_component.h"],
    deps = [
        ":sequence_xla_dynamic_component_mixin",
        ":xla_dynamic_component_base",
        ":xla_graph_utils",
        ":xla_spec_utils",
        "//dragnn/protos:export_proto_cc",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime:component",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/compiler/tf2xla:tf2xla_proto",
        "@org_tensorflow//tensorflow/compiler/tf2xla:xla_compiled_cpu_function",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:protos_all_cc",
    ],
)

dragnn_cc_multiarch_library(
    name = "xla_dynamic_component",
    srcs = ["xla_dynamic_component.cc"],
    deps = [
        ":sequence_xla_dynamic_component_mixin",
        ":xla_dynamic_component_base",
        ":xla_graph_utils",
        ":xla_spec_utils",
        "//dragnn/core:compute_session",
        "//dragnn/protos:export_proto_cc",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/protos:trace_proto_cc",
        "//dragnn/runtime:component",
        "//dragnn/runtime:fixed_embeddings",
        "//dragnn/runtime:linked_embeddings",
        "//dragnn/runtime:network_states",
        "//dragnn/runtime:session_state",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/compiler/tf2xla:tf2xla_proto",
        "@org_tensorflow//tensorflow/compiler/tf2xla:xla_compiled_cpu_function",
        "@org_tensorflow//tensorflow/compiler/tf2xla:xla_jit_compiled_cpu_function",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:protos_all_cc",
    ],
    # Registers components via static initializers; keep even if unreferenced.
    alwayslink = 1,
)

dragnn_cc_multiarch_test(
    name = "xla_dynamic_component_test",
    size = "small",
    srcs = ["xla_dynamic_component_test.cc"],
    deps = [
        ":xla_dynamic_component",
        ":xla_graph_utils",
        ":xla_spec_utils",
        "//dragnn/core/test:generic",
        "//dragnn/protos:cell_trace_proto_cc",
        "//dragnn/protos:export_proto_cc",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/protos:trace_proto_cc",
        "//dragnn/runtime:component",
        "//dragnn/runtime:extensions",
        "//dragnn/runtime:network_states",
        "//dragnn/runtime:session_state",
        "//dragnn/runtime:type_keyed_set",
        "//dragnn/runtime/math:types",
        "//dragnn/runtime/test:fake_variable_store",
        "//dragnn/runtime/test:network_test_base",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/compiler/xla:xla_data_proto",
        "@org_tensorflow//tensorflow/core:framework_headers_lib",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:protos_all_cc",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Utility libraries shared by the tools and components above.
cc_library(
    name = "xla_graph_utils",
    srcs = ["xla_graph_utils.cc"],
    hdrs = ["xla_graph_utils.h"],
    deps = [
        ":xla_spec_utils",
        "//dragnn/protos:export_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/compiler/tf2xla:tf2xla_proto",
        "@org_tensorflow//tensorflow/core:framework_headers_lib",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:protos_all_cc",
    ],
)

cc_test(
    name = "xla_graph_utils_test",
    srcs = ["xla_graph_utils_test.cc"],
    deps = [
        ":xla_graph_utils",
        "//dragnn/core/test:generic",
        "//dragnn/protos:export_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/compiler/tf2xla:tf2xla_proto",
        "@org_tensorflow//tensorflow/core:framework_headers_lib",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:protos_all_cc",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

cc_library(
    name = "xla_spec_build_utils",
    srcs = ["xla_spec_build_utils.cc"],
    hdrs = ["xla_spec_build_utils.h"],
    deps = [
        ":xla_spec_utils",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

cc_test(
    name = "xla_spec_build_utils_test",
    srcs = ["xla_spec_build_utils_test.cc"],
    deps = [
        ":xla_spec_build_utils",
        "//dragnn/core/test:generic",
        "//dragnn/protos:export_proto_cc",
        "//dragnn/protos:spec_proto_cc",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

cc_library(
    name = "xla_spec_utils",
    srcs = ["xla_spec_utils.cc"],
    hdrs = ["xla_spec_utils.h"],
    deps = [
        "//dragnn/protos:export_proto_cc",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

cc_test(
    name = "xla_spec_utils_test",
    srcs = ["xla_spec_utils_test.cc"],
    deps = [
        ":xla_spec_utils",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_XLA_SEQUENCE_XLA_DYNAMIC_COMPONENT_MIXIN_H_
#define DRAGNN_RUNTIME_XLA_SEQUENCE_XLA_DYNAMIC_COMPONENT_MIXIN_H_
#include <stddef.h>
#include <string>
#include <type_traits>
#include <vector>
#include "dragnn/core/compute_session.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/sequence_features.h"
#include "dragnn/runtime/sequence_links.h"
#include "dragnn/runtime/sequence_model.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "dragnn/runtime/xla/xla_dynamic_component_base.h"
#include "syntaxnet/base.h"
#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// A mixin that converts an XlaDynamicComponent variant into a sequence-based
// version. The |Base| must be a subclass of XlaDynamicComponentBase.
//
// The mixin keeps the base class's cell-evaluation machinery but drives it
// from a SequenceModel, which supplies per-element features and links.
template <class Base>
class SequenceXlaDynamicComponentMixin : public Base {
 public:
  static_assert(std::is_base_of<XlaDynamicComponentBase, Base>::value,
                "SequenceXlaDynamicComponentMixin must template on a subclass "
                "of XlaDynamicComponentBase");

  // Implements Component.
  bool Supports(const ComponentSpec &component_spec,
                const string &normalized_builder_name) const override;
  tensorflow::Status Initialize(const ComponentSpec &component_spec,
                                VariableStore *variable_store,
                                NetworkStateManager *network_state_manager,
                                ExtensionManager *extension_manager) override;
  tensorflow::Status Evaluate(SessionState *session_state,
                              ComputeSession *compute_session,
                              ComponentTrace *component_trace) const override;

 private:
  // Binds the fixed feature IDs for the |target_index|'th element of the
  // |features| to the |instance|.  Uses locals in the |network_states|.
  void BindInputIds(const SequenceFeatures &features, int target_index,
                    const NetworkStates &network_states,
                    tensorflow::XlaCompiledCpuFunction *instance) const;

  // Binds the linked embeddings for the |target_index|'th element in the
  // |links| to the |instance|.
  void BindInputLinks(const SequenceLinks &links, int target_index,
                      tensorflow::XlaCompiledCpuFunction *instance) const;

  // Sequence-based model evaluator.
  SequenceModel sequence_model_;

  // Intermediate values used by sequence models; shared across components via
  // the session's extension mechanism.
  SharedExtensionHandle<SequenceModel::EvaluateState> evaluate_state_handle_;
};
template <class Base>
bool SequenceXlaDynamicComponentMixin<Base>::Supports(
    const ComponentSpec &component_spec,
    const string &normalized_builder_name) const {
  // The sequence-based variant is selected by a "Sequence" prefix on the
  // builder name; the remainder must name a component the base supports.
  tensorflow::StringPiece base_name = normalized_builder_name;
  if (!tensorflow::str_util::ConsumePrefix(&base_name, "Sequence")) {
    return false;
  }
  if (!Base::Supports(component_spec, base_name.ToString())) return false;
  return SequenceModel::Supports(component_spec);
}
// Initializes the base component, then layers the sequence model on top of
// the base's embedding managers.  Returns non-OK on any setup failure.
template <class Base>
tensorflow::Status SequenceXlaDynamicComponentMixin<Base>::Initialize(
    const ComponentSpec &component_spec, VariableStore *variable_store,
    NetworkStateManager *network_state_manager,
    ExtensionManager *extension_manager) {
  // Initialize the base class first, so its FixedEmbeddingManager and
  // LinkedEmbeddingManager can be wrapped in sequence-based versions.
  TF_RETURN_IF_ERROR(Base::Initialize(component_spec, variable_store,
                                      network_state_manager,
                                      extension_manager));

  TF_RETURN_IF_ERROR(sequence_model_.Initialize(
      component_spec, Base::kLogitsName, &Base::fixed_embedding_manager(),
      &Base::linked_embedding_manager(), network_state_manager));

  // Reserve a shared slot for per-evaluation scratch state.
  extension_manager->GetShared(&evaluate_state_handle_);
  return tensorflow::Status::OK();
}
template <class Base>
void SequenceXlaDynamicComponentMixin<Base>::BindInputIds(
    const SequenceFeatures &features, int target_index,
    const NetworkStates &network_states,
    tensorflow::XlaCompiledCpuFunction *instance) const {
  const size_t num_channels = features.num_channels();
  for (size_t channel = 0; channel < num_channels; ++channel) {
    // Stage the single feature ID in the channel's local ID vector, then bind
    // that vector to the matching XLA input.
    const MutableVector<int32> ids = network_states.GetLocal(
        Base::fixed_embedding_manager().id_handle(channel, 0));
    ids[0] = features.GetId(channel, target_index);
    Base::BindInput(Vector<int32>(ids), Base::input_ids()[channel].id,
                    instance);
  }
}
template <class Base>
void SequenceXlaDynamicComponentMixin<Base>::BindInputLinks(
    const SequenceLinks &links, int target_index,
    tensorflow::XlaCompiledCpuFunction *instance) const {
  // Reused across channels; Get() overwrites both outputs each iteration.
  Vector<float> activations;
  bool out_of_bounds = false;
  const size_t num_channels = links.num_channels();
  for (size_t channel = 0; channel < num_channels; ++channel) {
    links.Get(channel, target_index, &activations, &out_of_bounds);
    Base::BindInputLink(activations, out_of_bounds,
                        Base::input_links()[channel], instance);
  }
}
// Runs the XLA cell once per sequence element and then applies the sequence
// model's prediction step.  Returns non-OK if preprocessing, any cell
// invocation, or prediction fails.
template <class Base>
tensorflow::Status SequenceXlaDynamicComponentMixin<Base>::Evaluate(
    SessionState *session_state, ComputeSession *compute_session,
    ComponentTrace *component_trace) const {
  NetworkStates &network_states = session_state->network_states;
  SequenceModel::EvaluateState &state =
      session_state->extensions.Get(evaluate_state_handle_);
  TF_RETURN_IF_ERROR(
      sequence_model_.Preprocess(session_state, compute_session, &state));

  // Avoid ComputeSession overhead by directly iterating over the feature IDs.
  // Handle forward and reverse iteration via an index and increment:
  // |step_index| always counts 0..num_steps-1, while |target_index| walks the
  // sequence in the model's direction.
  int target_index = sequence_model_.left_to_right() ? 0 : state.num_steps - 1;
  const int target_increment = sequence_model_.left_to_right() ? 1 : -1;
  tensorflow::XlaCompiledCpuFunction &instance =
      Base::GetInstance(session_state);
  for (size_t step_index = 0; step_index < state.num_steps;
       ++step_index, target_index += target_increment) {
    // Bind inputs and outputs into the |instance|.
    BindInputIds(state.features, target_index, network_states, &instance);
    BindInputLinks(state.links, target_index, &instance);
    Base::BindInputRecurrences(step_index, network_states, &instance);

    // Invoke the cell in the |instance|.
    if (!instance.Run()) {
      return tensorflow::errors::Internal("Error executing cell for ",
                                          Base::name(), ": ",
                                          instance.error_msg());
    }

    // Realizes the binding: copy outputs out of the |instance|.
    Base::BindOutputLayers(step_index, network_states, &instance);
    Base::MaybeTrace(step_index, &instance, component_trace);
  }

  return sequence_model_.Predict(network_states, &state);
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_XLA_SEQUENCE_XLA_DYNAMIC_COMPONENT_MIXIN_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <memory>
#include <string>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/cell_trace.pb.h"
#include "dragnn/protos/export.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/sequence_backend.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "dragnn/runtime/sequence_predictor.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "dragnn/runtime/xla/xla_graph_utils.h"
#include "dragnn/runtime/xla/xla_spec_utils.h"
#include "syntaxnet/base.h"
#include <gmock/gmock.h>
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
using ::testing::Return;

// Number of rows in the test embedding matrix (valid feature IDs are
// [0, kVocabularySize)).
constexpr int kVocabularySize = 123;

// Default dimension of the classification logits produced by the test cell.
constexpr int kLogitsDim = 11;

// Length of the feature ID sequence produced by the test extractor.
constexpr int kNumSteps = 50;
// Sequence extractor that extracts [0, 2, 4, ...].
class EvenNumbers : public SequenceExtractor {
 public:
  // Implements SequenceExtractor.
  bool Supports(const FixedFeatureChannel &,
                const ComponentSpec &) const override {
    return true;
  }
  tensorflow::Status Initialize(const FixedFeatureChannel &,
                                const ComponentSpec &) override {
    return tensorflow::Status::OK();
  }

  // Overwrites |ids| with the first kNumSteps even numbers.
  tensorflow::Status GetIds(InputBatchCache *,
                            std::vector<int32> *ids) const override {
    ids->resize(kNumSteps);
    for (int step = 0; step < kNumSteps; ++step) (*ids)[step] = 2 * step;
    return tensorflow::Status::OK();
  }
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR(EvenNumbers);
// Trivial predictor that does nothing.
class NoPredictions : public SequencePredictor {
 public:
  // Implements SequencePredictor.  (The original comment said
  // "SequenceLinker", which was wrong: this overrides SequencePredictor.)
  bool Supports(const ComponentSpec &) const override { return true; }
  tensorflow::Status Initialize(const ComponentSpec &) override {
    return tensorflow::Status::OK();
  }
  // Ignores the logits and produces no predictions.
  tensorflow::Status Predict(Matrix<float>, InputBatchCache *) const override {
    return tensorflow::Status::OK();
  }
};
DRAGNN_RUNTIME_REGISTER_SEQUENCE_PREDICTOR(NoPredictions);
// Fixture that builds a one-feature test component around a tiny frozen
// GraphDef (embedding lookup as the cell) and runs it end-to-end.
class SequenceXlaDynamicComponentMixinTest : public NetworkTestBase {
 protected:
  SequenceXlaDynamicComponentMixinTest() {
    // Point the mocked ComputeSession at the local input batch and sequence
    // backend so Evaluate() can run without a real pipeline.
    EXPECT_CALL(compute_session_, GetInputBatchCache())
        .WillRepeatedly(Return(&input_));
    EXPECT_CALL(compute_session_, GetReadiedComponent(kTestComponentName))
        .WillRepeatedly(Return(&backend_));
  }

  // Options for building a GraphDef file for tests. By default, this specifies
  // a working GraphDef file, but settings can be perturbed to trigger errors.
  struct GraphDefOptions {
    GraphDefOptions() = default;

    // Dimension of the classification logits.
    int logits_dim = kLogitsDim;

    // Name of the variable containing the classification logits.
    string logits_name = "logits";

    // Type of the feature ID input.
    xla::PrimitiveType id_type = xla::S32;

    // Dimension of the feature ID input.
    int id_dim = 1;
  };

  // Builds and writes a simple frozen GraphDef file. By default it produces a
  // valid frozen GraphDef, but arguments can be overridden for error testing.
  // Returns the path to the file.
  static string WriteFrozenGraphDef() {
    return WriteFrozenGraphDef(GraphDefOptions());
  }

  // Maps an XLA primitive type to the corresponding TensorFlow DataType;
  // returns DT_INVALID for unsupported types.
  static tensorflow::DataType TensorFlowType(xla::PrimitiveType type) {
    switch (type) {
      case xla::S32:
        return tensorflow::DT_INT32;
      case xla::S64:
        return tensorflow::DT_INT64;
      case xla::F32:
        return tensorflow::DT_FLOAT;
      default:
        break;
    }
    return tensorflow::DT_INVALID;
  }

  // Writes a frozen GraphDef per |options|; see the no-arg overload above.
  static string WriteFrozenGraphDef(const GraphDefOptions &options) {
    CellSubgraphSpec spec;
    tensorflow::GraphDef graph;

    // A fixed feature ID input.
    auto *input = spec.add_input();
    input->set_name("fixed_channel_0_index_0_ids");
    input->set_tensor("cell/id:0");
    input->set_type(CellSubgraphSpec::Input::TYPE_FEATURE);

    // The retrieved embedding row, as logits.
    auto *output = spec.add_output();
    output->set_name(options.logits_name);
    output->set_tensor("cell/lookup:0");

    // Add CellSubgraphSpec node.
    tensorflow::Tensor spec_tensor(tensorflow::DT_STRING,
                                   tensorflow::TensorShape({1}));
    spec.SerializeToString(&spec_tensor.vec<string>()(0));
    tensorflow::TensorProto spec_tensor_proto;
    spec_tensor.AsProtoField(&spec_tensor_proto);
    TF_CHECK_OK(
        tensorflow::NodeDefBuilder(kFrozenCellSubgraphSpecNodeName, "Const")
            .Attr("dtype", tensorflow::DT_STRING)
            .Attr("value", spec_tensor_proto)
            .Attr("shape", tensorflow::TensorShape({1}))
            .Finalize(graph.add_node()));

    // Fixed feature ID input placeholder node.
    TF_CHECK_OK(tensorflow::NodeDefBuilder("cell/id", "Placeholder")
                    .Attr("dtype", TensorFlowType(options.id_type))
                    .Attr("shape", tensorflow::TensorShape({options.id_dim}))
                    .Finalize(graph.add_node()));

    // An embedding matrix constant. Each embedding is filled with its index.
    tensorflow::Tensor embeddings(
        tensorflow::DT_FLOAT,
        tensorflow::TensorShape({kVocabularySize, options.logits_dim}));
    auto raw_tensor = embeddings.tensor<float, 2>();
    for (int row = 0; row < kVocabularySize; ++row) {
      for (int column = 0; column < options.logits_dim; ++column) {
        raw_tensor(row, column) = row;
      }
    }
    tensorflow::TensorProto embeddings_proto;
    embeddings.AsProtoTensorContent(&embeddings_proto);
    TF_CHECK_OK(tensorflow::NodeDefBuilder("cell/embedding_matrix", "Const")
                    .Attr("dtype", tensorflow::DT_FLOAT)
                    .Attr("value", embeddings_proto)
                    .Finalize(graph.add_node()));

    // A Gather op that looks up the |id| in the |embeddings|, and returns the
    // result in the |logits|.
    TF_CHECK_OK(tensorflow::NodeDefBuilder("cell/lookup", "Gather")
                    .Input("cell/embedding_matrix", 0, tensorflow::DT_FLOAT)
                    .Input("cell/id", 0, TensorFlowType(options.id_type))
                    .Attr("validate_indices", true)
                    .Finalize(graph.add_node()));

    const string path =
        tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), "graph-frozen");
    TF_CHECK_OK(SaveFrozenGraphDef(path, graph));
    return path;
  }

  // Creates a component, initializes it based on the |component_spec_text| and
  // |flow_path|, and evaluates it. The |component_trace| is overwritten with
  // traces, if non-null. On error, returns non-OK.
  tensorflow::Status Run(const string &component_spec_text = "",
                         const string &flow_path = WriteFrozenGraphDef(),
                         ComponentTrace *component_trace = nullptr) {
    ComponentSpec component_spec;
    CHECK(TextFormat::ParseFromString(component_spec_text, &component_spec));
    if (!component_spec.has_num_actions()) {
      component_spec.set_num_actions(kLogitsDim);
    }
    component_spec.set_name(kTestComponentName);

    // Add one non-embedded fixed feature channel that reads a single ID.
    auto *fixed_feature = component_spec.add_fixed_feature();
    fixed_feature->set_embedding_dim(-1);
    fixed_feature->set_size(1);
    TF_RETURN_IF_ERROR(AddFrozenGraphDefResource(flow_path, &component_spec));
    component_spec.mutable_backend()->set_registered_name("SequenceBackend");

    // Wire in the fake sequence extractor and predictor defined above.
    auto &parameters =
        *component_spec.mutable_component_builder()->mutable_parameters();
    parameters["sequence_extractors"] = "EvenNumbers";
    parameters["sequence_linkers"] = "";
    parameters["sequence_predictor"] = "NoPredictions";

    AddComponent(kTestComponentName);
    TF_RETURN_IF_ERROR(
        Component::CreateOrError("SequenceXlaDynamicComponent", &component_));
    TF_RETURN_IF_ERROR(component_->Initialize(component_spec, &variable_store_,
                                              &network_state_manager_,
                                              &extension_manager_));

    network_states_.Reset(&network_state_manager_);
    StartComponent(0);  // XlaDynamicComponent will add steps
    session_state_.extensions.Reset(&extension_manager_);

    TF_RETURN_IF_ERROR(component_->Evaluate(&session_state_, &compute_session_,
                                            component_trace));
    return tensorflow::Status::OK();
  }

  // Input batch injected into Evaluate() by default.
  InputBatchCache input_;

  // Backend injected into Evaluate().
  SequenceBackend backend_;

  // Component under test; created by Run().
  std::unique_ptr<Component> component_;
};
// Tests that the component rejects specs that use attention.
TEST_F(SequenceXlaDynamicComponentMixinTest, UnsupportedAttention) {
  const string spec_text = "attention_component:'foo'";
  EXPECT_THAT(Run(spec_text),
              test::IsErrorWithSubstr("Attention is not supported"));
}
// Tests that the component rejects specs with embedded fixed features.
TEST_F(SequenceXlaDynamicComponentMixinTest, InvalidFixedFeatureIsEmbedded) {
  const string spec_text = "fixed_feature { embedding_dim:1 }";
  EXPECT_THAT(
      Run(spec_text),
      test::IsErrorWithSubstr("XLA requires non-embedded fixed features"));
}
// Tests that the component rejects a fixed feature channel that has no
// matching tensor in the graph.
TEST_F(SequenceXlaDynamicComponentMixinTest, InvalidFixedFeatureNotInGraph) {
  // Run() adds its own fixed feature channel, so the one here ends up with
  // channel index 1, which the test graph does not provide.
  const string expected_error = tensorflow::strings::StrCat(
      "No XLA tensor named '", MakeXlaInputFixedFeatureIdName(1, 0), "'");
  EXPECT_THAT(Run("fixed_feature { embedding_dim:-1 size:1 }"),
              test::IsErrorWithSubstr(expected_error));
}
// Tests that the component rejects specs with multiplied linked features.
TEST_F(SequenceXlaDynamicComponentMixinTest, InvalidLinkedFeatureIsMultiplied) {
  const string spec_text = "linked_feature { embedding_dim:1 }";
  EXPECT_THAT(
      Run(spec_text),
      test::IsErrorWithSubstr("XLA requires non-multiplied linked features"));
}
// Tests that the component rejects a linked feature channel that has no
// matching tensor in the graph.
TEST_F(SequenceXlaDynamicComponentMixinTest, InvalidLinkedFeatureNotInGraph) {
  const string spec_text = tensorflow::strings::StrCat(
      "linked_feature { source_component:'", kTestComponentName,
      "' source_layer:'logits' embedding_dim:-1 size:1 }");
  const string expected_error = tensorflow::strings::StrCat(
      "No XLA tensor named '", MakeXlaInputLinkedActivationVectorName(0), "'");
  EXPECT_THAT(Run(spec_text), test::IsErrorWithSubstr(expected_error));
}
// Tests that the component fails cleanly when the GraphDef file is missing.
TEST_F(SequenceXlaDynamicComponentMixinTest, InvalidPath) {
  const string bogus_path = "/invalid/path";
  EXPECT_THAT(Run("", bogus_path),
              test::IsErrorWithSubstr("No such file or directory"));
}
// Tests that the component rejects logits whose dimension disagrees with
// ComponentSpec.num_actions.
TEST_F(SequenceXlaDynamicComponentMixinTest, WrongLogitsDimension) {
  // Emit one more logit than num_actions expects.
  GraphDefOptions options;
  options.logits_dim = kLogitsDim + 1;
  const string graph_path = WriteFrozenGraphDef(options);
  EXPECT_THAT(Run("", graph_path),
              test::IsErrorWithSubstr(
                  "Dimension mismatch between classification logits"));
}
// Tests that the component fails when the graph provides no "logits" layer.
TEST_F(SequenceXlaDynamicComponentMixinTest, WrongLogitsName) {
  GraphDefOptions options;
  options.logits_name = "not_logits";
  const string graph_path = WriteFrozenGraphDef(options);
  EXPECT_THAT(Run("", graph_path),
              test::IsErrorWithSubstr("Unknown layer 'logits'"));
}
// Tests that XLA compilation fails when a tensor has the wrong type.
TEST_F(SequenceXlaDynamicComponentMixinTest, FailToCompile) {
  // A float-typed feature ID is invalid for the lookup op.
  GraphDefOptions options;
  options.id_type = xla::F32;
  const string graph_path = WriteFrozenGraphDef(options);
  EXPECT_THAT(
      Run("", graph_path),
      test::IsErrorWithSubstr("float is not in the list of allowed values"));
}
// Tests that the component rejects XLA tensors with non-vector-like shapes.
TEST_F(SequenceXlaDynamicComponentMixinTest, NotVectorLike) {
  GraphDefOptions options;
  options.id_dim = 2;
  const string graph_path = WriteFrozenGraphDef(options);
  EXPECT_THAT(Run("", graph_path),
              test::IsErrorWithSubstr("XLA tensor has non-vector-like shape"));
}
// Tests end-to-end evaluation of the default (non-deterministic) frozen
// GraphDef.
TEST_F(SequenceXlaDynamicComponentMixinTest, SimpleNonDeterministicFlow) {
  TF_ASSERT_OK(Run());
  const Matrix<float> logits(GetLayer(kTestComponentName, "logits"));
  ASSERT_EQ(logits.num_rows(), kNumSteps);
  ASSERT_EQ(logits.num_columns(), kLogitsDim);

  // Each embedding row is filled with its row index, so every logits row
  // should equal the even feature ID that selected it.
  for (int row = 0; row < kNumSteps; ++row) {
    ExpectVector(logits.row(row), kLogitsDim, 2 * row);
  }
}
// Tests end-to-end evaluation of a deterministic (1-dim logits) frozen
// GraphDef with a matching num_actions.
TEST_F(SequenceXlaDynamicComponentMixinTest, SimpleDeterministicFlow) {
  GraphDefOptions options;
  options.logits_dim = 1;
  const string graph_path = WriteFrozenGraphDef(options);
  TF_ASSERT_OK(Run("num_actions:1", graph_path));
}
// Tests end-to-end evaluation with tracing enabled.
TEST_F(SequenceXlaDynamicComponentMixinTest, SimpleFlowWithTracing) {
  ComponentTrace component_trace;
  TF_ASSERT_OK(Run("", WriteFrozenGraphDef(), &component_trace));

  // One step trace per element of the extracted sequence.
  ASSERT_EQ(component_trace.step_trace_size(), kNumSteps);
  for (int i = 0; i < component_trace.step_trace_size(); ++i) {
    const ComponentStepTrace &step_trace = component_trace.step_trace(i);
    // TODO(googleuser): Add once the JIT API supports this.
    EXPECT_EQ(step_trace.ExtensionSize(CellTrace::step_trace_extension), 0);
  }
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
name: "test_component"
fixed_feature {
embedding_dim: -1
size: 1
}
num_actions: 1
component_builder {
registered_name: "XlaAotDynamicComponent_model_v1_test_component"
}
[syntaxnet.dragnn.runtime.CompilationSpec.component_spec_extension] {
model_name: "model_v1"
cell_subgraph_spec {
input {
name: "fixed_channel_0_index_0_ids"
tensor: "cell/id:0"
type: TYPE_FEATURE
}
output {
name: "logits"
tensor: "cell/lookup:0"
}
}
}
feed {
id {
node_name: "cell/id"
}
shape {
dim {
size: 1
}
}
name: "INPUT__fixed_channel_0_index_0_ids"
}
fetch {
id {
node_name: "cell/lookup"
}
name: "OUTPUT__logits"
}
node {
name: "CellSubgraphSpec"
op: "Const"
attr {
key: "dtype"
value {
type: DT_STRING
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 1
}
}
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_STRING
tensor_shape {
dim {
size: 1
}
}
string_val: "\n*\n\033fixed_channel_0_index_0_ids\022\tcell/id:0\030\001\022\027\n\006logits\022\rcell/lookup:0"
}
}
}
}
node {
name: "cell/id"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 1
}
}
}
}
}
node {
name: "cell/embedding_matrix"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
dim {
size: 123
}
dim {
size: 1
}
}
tensor_content: "\000\000\000\000\000\000\200?\000\000\000@\000\000@@\000\000\200@\000\000\240@\000\000\300@\000\000\340@\000\000\000A\000\000\020A\000\000 A\000\0000A\000\000@A\000\000PA\000\000`A\000\000pA\000\000\200A\000\000\210A\000\000\220A\000\000\230A\000\000\240A\000\000\250A\000\000\260A\000\000\270A\000\000\300A\000\000\310A\000\000\320A\000\000\330A\000\000\340A\000\000\350A\000\000\360A\000\000\370A\000\000\000B\000\000\004B\000\000\010B\000\000\014B\000\000\020B\000\000\024B\000\000\030B\000\000\034B\000\000 B\000\000$B\000\000(B\000\000,B\000\0000B\000\0004B\000\0008B\000\000<B\000\000@B\000\000DB\000\000HB\000\000LB\000\000PB\000\000TB\000\000XB\000\000\\B\000\000`B\000\000dB\000\000hB\000\000lB\000\000pB\000\000tB\000\000xB\000\000|B\000\000\200B\000\000\202B\000\000\204B\000\000\206B\000\000\210B\000\000\212B\000\000\214B\000\000\216B\000\000\220B\000\000\222B\000\000\224B\000\000\226B\000\000\230B\000\000\232B\000\000\234B\000\000\236B\000\000\240B\000\000\242B\000\000\244B\000\000\246B\000\000\250B\000\000\252B\000\000\254B\000\000\256B\000\000\260B\000\000\262B\000\000\264B\000\000\266B\000\000\270B\000\000\272B\000\000\274B\000\000\276B\000\000\300B\000\000\302B\000\000\304B\000\000\306B\000\000\310B\000\000\312B\000\000\314B\000\000\316B\000\000\320B\000\000\322B\000\000\324B\000\000\326B\000\000\330B\000\000\332B\000\000\334B\000\000\336B\000\000\340B\000\000\342B\000\000\344B\000\000\346B\000\000\350B\000\000\352B\000\000\354B\000\000\356B\000\000\360B\000\000\362B\000\000\364B"
}
}
}
}
node {
name: "cell/lookup"
op: "Gather"
input: "cell/embedding_matrix"
input: "cell/id"
attr {
key: "Tindices"
value {
type: DT_INT32
}
}
attr {
key: "Tparams"
value {
type: DT_FLOAT
}
}
attr {
key: "validate_indices"
value {
b: true
}
}
}
component {
name: "rnn"
transition_system {
registered_name: "shift-only"
parameters {
key: "left_to_right"
value: "false"
}
parameters {
key: "parser_skip_deterministic"
value: "false"
}
}
resource {
name: "words-embedding-input"
part {
file_format: "tf-records"
record_format: "syntaxnet.TokenEmbedding"
}
}
resource {
name: "words-vocab-input"
part {
file_format: "text"
record_format: ""
}
}
resource {
name: "char-ngram-map"
part {
file_format: "text"
record_format: ""
}
}
resource {
name: "word-map"
part {
file_format: "text"
record_format: ""
}
}
resource {
name: "label-map"
part {
file_format: "text"
record_format: ""
}
}
resource {
name: "frozen-graph"
part {
file_format: "proto"
record_format: "tensorflow.GraphDef"
}
}
fixed_feature {
name: "char_ngrams"
fml: "input.token { offset(-1).char-ngram(min-length=1,max-length=3,mark-boundaries=true) offset(0).char-ngram(min-length=1,max-length=3,mark-boundaries=true) offset(1).char-ngram(min-length=1,max-length=3,mark-boundaries=true) }"
embedding_dim: -1
vocabulary_size: 25788
size: 3
}
fixed_feature {
name: "words"
fml: "input.token.word(min-freq=2)"
embedding_dim: -1
vocabulary_size: 23769
size: 1
}
network_unit {
registered_name: "LSTMNetwork"
parameters {
key: "hidden_layer_sizes"
value: "128"
}
parameters {
key: "omit_logits"
value: "true"
}
}
backend {
registered_name: "SyntaxNetComponent"
}
num_actions: 1
attention_component: ""
component_builder {
registered_name: "XlaDynamicComponent"
}
}
component {
name: "tagger"
transition_system {
registered_name: "tagger"
parameters {
key: "parser_skip_deterministic"
value: "false"
}
}
resource {
name: "tag-map"
part {
file_format: "text"
record_format: ""
}
}
resource {
name: "tag-to-category"
part {
file_format: "text"
record_format: ""
}
}
resource {
name: "label-map"
part {
file_format: "text"
record_format: ""
}
}
resource {
name: "frozen-graph"
part {
file_format: "proto"
record_format: "tensorflow.GraphDef"
}
}
linked_feature {
name: "recurrence"
fml: "bias(0)"
embedding_dim: -1
size: 1
source_component: "tagger"
source_translator: "history"
source_layer: "layer_0"
}
linked_feature {
name: "rnn"
fml: "input.focus"
embedding_dim: -1
size: 1
source_component: "rnn"
source_translator: "reverse-token"
source_layer: "layer_0"
}
network_unit {
registered_name: "FeedForwardNetwork"
parameters {
key: "hidden_layer_sizes"
value: "64,64"
}
}
backend {
registered_name: "SyntaxNetComponent"
}
num_actions: 45
attention_component: ""
component_builder {
registered_name: "XlaDynamicComponent"
}
}
component {
name: "rnn"
transition_system {
registered_name: "shift-only"
parameters {
key: "left_to_right"
value: "false"
}
parameters {
key: "parser_skip_deterministic"
value: "false"
}
}
resource {
name: "words-embedding-input"
part {
file_format: "tf-records"
record_format: "syntaxnet.TokenEmbedding"
}
}
resource {
name: "words-vocab-input"
part {
file_format: "text"
record_format: ""
}
}
resource {
name: "char-ngram-map"
part {
file_format: "text"
record_format: ""
}
}
resource {
name: "word-map"
part {
file_format: "text"
record_format: ""
}
}
resource {
name: "label-map"
part {
file_format: "text"
record_format: ""
}
}
resource {
name: "frozen-graph"
part {
file_format: "proto"
record_format: "tensorflow.GraphDef"
}
}
fixed_feature {
name: "char_ngrams"
fml: "input.token { offset(-1).char-ngram(min-length=1,max-length=3,mark-boundaries=true) offset(0).char-ngram(min-length=1,max-length=3,mark-boundaries=true) offset(1).char-ngram(min-length=1,max-length=3,mark-boundaries=true) }"
embedding_dim: -1
vocabulary_size: 25788
size: 3
}
fixed_feature {
name: "words"
fml: "input.token.word(min-freq=2)"
embedding_dim: -1
vocabulary_size: 23769
size: 1
}
network_unit {
registered_name: "LSTMNetwork"
parameters {
key: "hidden_layer_sizes"
value: "128"
}
parameters {
key: "omit_logits"
value: "true"
}
}
backend {
registered_name: "SyntaxNetComponent"
}
num_actions: 1
attention_component: ""
component_builder {
registered_name: "XlaDynamicComponent"
}
[syntaxnet.dragnn.runtime.CompilationSpec.component_spec_extension] {
model_name: "model_v1"
cell_subgraph_spec {
input {
name: "fixed_channel_0_index_0_ids"
tensor: "rnn/INPUT/fixed_channel_0_index_0_ids:0"
type: TYPE_FEATURE
}
input {
name: "fixed_channel_0_index_1_ids"
tensor: "rnn/INPUT/fixed_channel_0_index_1_ids:0"
type: TYPE_FEATURE
}
input {
name: "fixed_channel_0_index_2_ids"
tensor: "rnn/INPUT/fixed_channel_0_index_2_ids:0"
type: TYPE_FEATURE
}
input {
name: "fixed_channel_1_index_0_ids"
tensor: "rnn/INPUT/fixed_channel_1_index_0_ids:0"
type: TYPE_FEATURE
}
input {
name: "lstm_c"
tensor: "rnn/INPUT/lstm_c:0"
type: TYPE_RECURRENT
}
input {
name: "lstm_h"
tensor: "rnn/INPUT/lstm_h:0"
type: TYPE_RECURRENT
}
output {
name: "lstm_h"
tensor: "annotation/inference_rnn/rnn/lstm_h:0"
}
output {
name: "lstm_c"
tensor: "annotation/inference_rnn/rnn/lstm_c:0"
}
output {
name: "layer_0"
tensor: "annotation/inference_rnn/rnn/layer_0:0"
}
output {
name: "logits"
tensor: "annotation/inference_rnn/rnn/logits:0"
}
}
}
}
component {
name: "tagger"
transition_system {
registered_name: "tagger"
parameters {
key: "parser_skip_deterministic"
value: "false"
}
}
resource {
name: "tag-map"
part {
file_format: "text"
record_format: ""
}
}
resource {
name: "tag-to-category"
part {
file_format: "text"
record_format: ""
}
}
resource {
name: "label-map"
part {
file_format: "text"
record_format: ""
}
}
resource {
name: "frozen-graph"
part {
file_format: "proto"
record_format: "tensorflow.GraphDef"
}
}
linked_feature {
name: "recurrence"
fml: "bias(0)"
embedding_dim: -1
size: 1
source_component: "tagger"
source_translator: "history"
source_layer: "layer_0"
}
linked_feature {
name: "rnn"
fml: "input.focus"
embedding_dim: -1
size: 1
source_component: "rnn"
source_translator: "reverse-token"
source_layer: "layer_0"
}
network_unit {
registered_name: "FeedForwardNetwork"
parameters {
key: "hidden_layer_sizes"
value: "64,64"
}
}
backend {
registered_name: "SyntaxNetComponent"
}
num_actions: 45
attention_component: ""
component_builder {
registered_name: "XlaDynamicComponent"
}
[syntaxnet.dragnn.runtime.CompilationSpec.component_spec_extension] {
model_name: "model_v1"
cell_subgraph_spec {
input {
name: "linked_channel_0_activations"
tensor: "tagger/INPUT/linked_channel_0_activations:0"
type: TYPE_FEATURE
}
input {
name: "linked_channel_0_out_of_bounds"
tensor: "tagger/INPUT/linked_channel_0_out_of_bounds:0"
type: TYPE_FEATURE
}
input {
name: "linked_channel_1_activations"
tensor: "tagger/INPUT/linked_channel_1_activations:0"
type: TYPE_FEATURE
}
output {
name: "layer_0"
tensor: "annotation/inference_tagger/tagger/Relu:0"
}
output {
name: "layer_1"
tensor: "annotation/inference_tagger/tagger/Relu_1:0"
}
output {
name: "last_layer"
tensor: "annotation/inference_tagger/tagger/Relu_1:0"
}
output {
name: "logits"
tensor: "annotation/inference_tagger/tagger/logits:0"
}
}
}
}
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_XLA_XLA_AOT_DYNAMIC_COMPONENT_H_
#define DRAGNN_RUNTIME_XLA_XLA_AOT_DYNAMIC_COMPONENT_H_
#include <string>
#include "dragnn/protos/export.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/xla/sequence_xla_dynamic_component_mixin.h"
#include "dragnn/runtime/xla/xla_dynamic_component_base.h"
#include "dragnn/runtime/xla/xla_graph_utils.h"
#include "dragnn/runtime/xla/xla_spec_utils.h"
#include "syntaxnet/base.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// An XLA-based version of DynamicComponent using an XLA AOT compiled library.
//
// The class |AotCell| is generated by a tf_library build rule.
//
// The component class is instantiated in C++ code generated by a
// dragnn_xla_aot_components() build rule. The default constructor must set
// the model and component names to non-empty strings, and this must match
// the registered class name, as generated by RegisteredName().
//
// Example instantiation and registration:
//
// class XlaAotDynamicComponent_model_component
// : public XlaAotDynamicComponent<model::component> {
// public:
// XlaAotDynamicComponent_model_component()
// : XlaAotDynamicComponent<model::component>("model", "component") {}
// };
// DRAGNN_RUNTIME_REGISTER_COMPONENT(XlaAotDynamicComponent_model_component);
template <typename AotCell>
class XlaAotDynamicComponent : public XlaDynamicComponentBase {
 protected:
  // Creates a component bound to the |model_name| and |component_name| that
  // the AOT cell was compiled for.  Subclasses must pass non-empty names (see
  // the CHECK in InitializeFromComponentSpec()).
  XlaAotDynamicComponent(const string &model_name, const string &component_name)
      : model_name_(model_name), component_name_(component_name) {}

  // Unlike other specializations, this component will only be active if the
  // spec is explicitly modified to support XLA AOT.
  bool Supports(const ComponentSpec &spec,
                const string &normalized_builder_name) const override {
    // This must accept both the "base" XLA component and this one, based on how
    // Supports is called repeatedly.
    return (normalized_builder_name == "XlaDynamicComponent" ||
            normalized_builder_name == RegisteredName()) &&
           spec.name() == component_name_ &&
           ModelNameForComponent(spec) == model_name_ &&
           GetCellSubgraphSpecForComponent(spec, nullptr).ok();
  }

  // Unconditionally claims preference over any competing component.
  bool PreferredTo(const Component &other) const override {
    // AOT is preferred to JIT.
    return true;
  }

  // Gets the frozen GraphDef using the |component_spec| and compiles it.
  // The |cell_subgraph_spec| contained within it is filled in. On error,
  // returns non-OK.
  tensorflow::Status InitializeFromComponentSpec(
      const ComponentSpec &component_spec,
      CellSubgraphSpec *cell_subgraph_spec) override;

  // Returns the static data of the AOT-compiled cell produced by tf_library.
  const tensorflow::XlaCompiledCpuFunction::StaticData &XlaStaticData()
      const override {
    return AotCell::StaticData();
  }

 private:
  // Returns the registry name for this component, e.g.
  // "XlaAotDynamicComponent_<model>_<component>".
  const string RegisteredName() const {
    return tensorflow::strings::StrCat("XlaAotDynamicComponent_", model_name_,
                                       "_", component_name_);
  }

  // Names of the model and component that the AOT cell implements.
  const string model_name_;
  const string component_name_;
};
template <typename AotCell>
tensorflow::Status XlaAotDynamicComponent<AotCell>::InitializeFromComponentSpec(
    const ComponentSpec &component_spec, CellSubgraphSpec *cell_subgraph_spec) {
  LOG(INFO) << "Using XLA AOT library for model/component: " << model_name_
            << "/" << component_name_;
  // Generated subclasses must have bound real names in their constructors.
  CHECK(!model_name_.empty() && !component_name_.empty());
  // No compilation happens here: the cell is already AOT-compiled, so only
  // the CellSubgraphSpec needs to be extracted from the spec.
  return GetCellSubgraphSpecForComponent(component_spec, cell_subgraph_spec);
}
// Sequence-based version of the above, obtained by wrapping the AOT component
// in SequenceXlaDynamicComponentMixin.
template <typename AotCell>
using SequenceXlaAotDynamicComponent =
    SequenceXlaDynamicComponentMixin<XlaAotDynamicComponent<AotCell>>;
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_XLA_XLA_AOT_DYNAMIC_COMPONENT_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/xla/xla_aot_dynamic_component.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/export.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
using ::testing::_;
using ::testing::InSequence;
using ::testing::Invoke;
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Fake AOT class suitable for testing initialization.
class TestComponent {
 public:
  // Returns a default-constructed StaticData singleton.  The instance is
  // heap-allocated and never deleted, which avoids destruction-order issues
  // at process shutdown.
  static const tensorflow::XlaCompiledCpuFunction::StaticData &StaticData() {
    static tensorflow::XlaCompiledCpuFunction::StaticData *kStaticData =
        new tensorflow::XlaCompiledCpuFunction::StaticData;
    return *kStaticData;
  }
};
constexpr char kXlaModel[] = "TestModel";
constexpr char kXlaComponent[] = "TestComponent";
// Test subclass bound to kXlaModel/kXlaComponent, mirroring the subclasses
// generated by dragnn_xla_aot_components().  The using-declarations expose
// the protected methods so the tests can call them directly.
class XlaAotDynamicComponent_TestModel_TestComponent
    : public XlaAotDynamicComponent<TestComponent> {
 public:
  XlaAotDynamicComponent_TestModel_TestComponent()
      : XlaAotDynamicComponent<TestComponent>(kXlaModel, kXlaComponent) {}
  using XlaAotDynamicComponent<TestComponent>::Supports;
  using XlaAotDynamicComponent<TestComponent>::InitializeFromComponentSpec;
};
DRAGNN_RUNTIME_REGISTER_COMPONENT(
XlaAotDynamicComponent_TestModel_TestComponent);
class XlaAotDynamicComponentTest : public ::testing::Test {
 public:
  // Builds a ComponentSpec whose name is |component_name| (when non-empty).
  // A CompilationSpec extension is attached iff |model_name| is non-empty or
  // |include_subgraph_spec| is true; the extension carries |model_name| (when
  // non-empty) and, when |include_subgraph_spec| is set, an empty
  // CellSubgraphSpec.
  ComponentSpec BuildComponentSpec(const string &model_name,
                                   const string &component_name,
                                   bool include_subgraph_spec) {
    ComponentSpec spec;
    if (!component_name.empty()) spec.set_name(component_name);

    // Skip the extension entirely when it would carry no content.
    const bool extension_has_content =
        !model_name.empty() || include_subgraph_spec;
    if (extension_has_content) {
      auto *compilation_spec =
          spec.MutableExtension(CompilationSpec::component_spec_extension);
      if (!model_name.empty()) compilation_spec->set_model_name(model_name);
      if (include_subgraph_spec) {
        // Touching the sub-message is enough to attach an empty
        // CellSubgraphSpec to the extension.
        compilation_spec->mutable_cell_subgraph_spec();
      }
    }
    return spec;
  }

 protected:
  // Registered AOT component under test.
  XlaAotDynamicComponent_TestModel_TestComponent component_;
};
TEST_F(XlaAotDynamicComponentTest, Supports) {
  // A spec with matching model and component names plus a subgraph spec.
  ComponentSpec spec = BuildComponentSpec(kXlaModel, kXlaComponent, true);

  // Both the generic XLA builder name and the fully-qualified registered
  // name are accepted.
  EXPECT_TRUE(component_.Supports(spec, "XlaDynamicComponent"));
  EXPECT_TRUE(component_.Supports(
      spec, "XlaAotDynamicComponent_TestModel_TestComponent"));

  // Other builder names, including near-misses, are rejected.
  EXPECT_FALSE(component_.Supports(spec, "DynamicComponent"));
  EXPECT_FALSE(component_.Supports(spec, "XlaAotDynamicComponent"));
  EXPECT_FALSE(component_.Supports(
      spec, "XlaAotDynamicComponent_TestModel_OtherComponent"));
}
TEST_F(XlaAotDynamicComponentTest, SupportRequiresMatchingModelName) {
  // A mismatched or absent model name disqualifies the spec, even when the
  // component name and subgraph spec are valid.
  EXPECT_FALSE(
      component_.Supports(BuildComponentSpec("OtherModel", kXlaComponent, true),
                          "XlaDynamicComponent"));
  EXPECT_FALSE(component_.Supports(BuildComponentSpec("", kXlaComponent, true),
                                   "XlaDynamicComponent"));
}
TEST_F(XlaAotDynamicComponentTest, SupportRequiresSubgraph) {
  // Without an embedded CellSubgraphSpec there is nothing to run, so the
  // component declines the spec.
  EXPECT_FALSE(
      component_.Supports(BuildComponentSpec(kXlaModel, kXlaComponent, false),
                          "XlaDynamicComponent"));
}
TEST_F(XlaAotDynamicComponentTest, InitializeFromComponentSpec) {
  ComponentSpec component_spec;
  auto *compilation_spec = component_spec.MutableExtension(
      CompilationSpec::component_spec_extension);

  // Example spec.
  CellSubgraphSpec expected_cell_subgraph_spec;
  auto *input = expected_cell_subgraph_spec.add_input();
  input->set_name("fixed_channel_0_index_0_ids");
  input->set_tensor("cell/id:0");
  input->set_type(CellSubgraphSpec::Input::TYPE_FEATURE);
  auto *output = expected_cell_subgraph_spec.add_output();
  output->set_name("logits");
  output->set_tensor("cell/lookup:0");
  *compilation_spec->mutable_cell_subgraph_spec() = expected_cell_subgraph_spec;

  // Initialization should extract exactly the embedded CellSubgraphSpec.
  CellSubgraphSpec actual_cell_subgraph_spec;
  TF_ASSERT_OK(component_.InitializeFromComponentSpec(
      component_spec, &actual_cell_subgraph_spec));
  EXPECT_THAT(actual_cell_subgraph_spec,
              test::EqualsProto(expected_cell_subgraph_spec));
}
TEST_F(XlaAotDynamicComponentTest, InitializeFromComponentSpecNeedsSubgraph) {
  CellSubgraphSpec cell_subgraph_spec;

  // Succeeds when the spec carries a CellSubgraphSpec...
  TF_EXPECT_OK(component_.InitializeFromComponentSpec(
      BuildComponentSpec(kXlaModel, kXlaComponent, true), &cell_subgraph_spec));

  // ...and fails with a descriptive error when it does not.
  EXPECT_THAT(component_.InitializeFromComponentSpec(
                  BuildComponentSpec(kXlaModel, kXlaComponent, false),
                  &cell_subgraph_spec),
              test::IsErrorWithSubstr(
                  "Component TestComponent does not have a CellSubgraphSpec"));
}
// Tests using simple test AOT library.
constexpr int kNumSteps = 50;
constexpr int kVocabularySize = 123;
constexpr char kSimpleComponentSpecPath[] =
"dragnn/runtime/xla/testdata/simple-component-spec";
class XlaAotDynamicComponentRunTest : public NetworkTestBase {
 public:
  // Creates a component, initializes it based on the |component_spec|,
  // and evaluates it. On error, returns non-OK.
  tensorflow::Status Run(const ComponentSpec &component_spec) {
    AddComponent(kTestComponentName);
    // Instantiate via the registry, exactly as a client binary would.
    TF_RETURN_IF_ERROR(Component::CreateOrError(
        "XlaAotDynamicComponent_model_v1_test_component", &component_));
    TF_RETURN_IF_ERROR(component_->Initialize(component_spec, &variable_store_,
                                              &network_state_manager_,
                                              &extension_manager_));
    // NOTE(review): Initialize() runs before Reset()/StartComponent();
    // presumably the managers must be fully populated first — confirm
    // against NetworkTestBase before reordering.
    network_states_.Reset(&network_state_manager_);
    StartComponent(0);
    session_state_.extensions.Reset(&extension_manager_);
    TF_RETURN_IF_ERROR(
        component_->Evaluate(&session_state_, &compute_session_, nullptr));
    return tensorflow::Status::OK();
  }

 private:
  // Component under test; created by Run().
  std::unique_ptr<Component> component_;
};
// Test that runs a simple deterministic component.
TEST_F(XlaAotDynamicComponentRunTest, Simple) {
  SetupTransitionLoop(kNumSteps);
  EXPECT_CALL(compute_session_, AdvanceFromOracle(kTestComponentName))
      .Times(kNumSteps);

  {  // Extract a sequence of feature IDs equal to 2 * step_index.
    ASSERT_LE(2 * kNumSteps, kVocabularySize);  // IDs must stay in vocabulary
    InSequence scoped;
    for (int step_index = 0; step_index < kNumSteps; ++step_index) {
      EXPECT_CALL(compute_session_, GetInputFeatures(_, _, _, _, _))
          .WillOnce(Invoke(ExtractFeatures(0, {{2 * step_index, 1.0}})));
    }
  }

  // Load the checked-in spec that matches the AOT-compiled test cell.
  ComponentSpec component_spec;
  TF_ASSERT_OK(tensorflow::ReadTextProto(
      tensorflow::Env::Default(),
      tensorflow::io::JoinPath(test::GetTestDataPrefix(),
                               kSimpleComponentSpecPath),
      &component_spec));
  TF_ASSERT_OK(Run(component_spec));
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Build extension rules for XLA AOT compilation."""
load(
"//dragnn/runtime:multiarch.bzl",
"multiarch_name",
"MULTIARCH_CONFIGS",
)
load("@org_tensorflow//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
# Extra tfcompile flags per target architecture.  The keys double as the set
# of architectures iterated over by dragnn_xla_aot_components(), and must
# match the keys of MULTIARCH_CONFIGS; "generic" adds no target features.
MULTIARCH_TFCOMPILE_FLAGS = {
    "generic": [],
    "avx": ["--target_features=+avx,+sse4.2"],
    "avx2fma": ["--target_features=+avx,+avx2,+sse4.2,+fma"],
}
def _dragnn_xla_safe_name(name):
"""Generates a version of |name| is safe for use in C++."""
return name.replace('-','_').replace('.','_')
def _dragnn_xla_aot_library_name(arch, model, component):
  """Returns the AOT library name for the given model/component."""
  base_name = '%s_%s' % (model, component)
  return multiarch_name(base_name, arch)
def _dragnn_xla_aot_component_library_name(arch, model, component):
  """Returns the AOT component library name for the given model/component."""
  library_name = _dragnn_xla_aot_library_name(arch, model, component)
  return library_name + '_component'
def _dragnn_xla_config_proto(
    name, graph,
    config_tool = '//dragnn/runtime/xla:xla_extract_config'):
  """Extracts XLA Config from a frozen GraphDef for a DRAGNN component.

  Generates a build target called |name| which is a text file that contains
  a tensorflow.tf2xla.Config used in a tf_library build rule.  The output
  file is called "<name>.pbtxt".

  Args:
    name: The name of the build rule.
    graph: The frozen tensorflow.GraphDef binary proto built for a particular
        DRAGNN component by the runtime.
    config_tool: The binary used to extract the Config proto.  A non-default
        can be passed when necessary.
  """
  config_path = name + '.pbtxt'

  # Run |config_tool| on |graph| to produce |config_path|.
  extract_command = ' '.join([
      '$(location ' + config_tool + ')',
      '$(location ' + graph + ')',
      '$(location ' + config_path + ')',
  ])
  native.genrule(
      name=name,
      srcs=[graph],
      outs=[config_path],
      tools=[config_tool],
      cmd=extract_command,
  )
def _dragnn_xla_aot_component_cc_code(arch, model, component, target):
"""Generates C++ code for a component which wraps a particular AOT library.
Returns a string containing the generated C++ code that defines and registers
the DRAGNN component the implements a particular |model| and |component|,
targeted to a the given |arch|. The class name and registry name do not
include |arch|, which means only one can be linked in.
Args:
arch: The name of the target architecture.
model: The name of the DRAGNN model.
component: The name of the DRAGNN component that uses XLA AOT.
target: The directory that contains XLA AOT target.
Returns:
The string containing the generated C++ code.
"""
cc_template = """// GENERATED CODE.
#include "$TARGET/$MODEL_$COMPONENT_multiarch_$ARCH.h" // Generated by XLA.
#include "dragnn/runtime/xla/sequence_xla_dynamic_component_mixin.h"
#include "dragnn/runtime/xla/xla_aot_dynamic_component.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
class XlaAotDynamicComponent_$MODEL_$COMPONENT
: public XlaAotDynamicComponent<$MODEL::$COMPONENT> {
public:
XlaAotDynamicComponent_$MODEL_$COMPONENT()
: XlaAotDynamicComponent<$MODEL::$COMPONENT>("$MODEL", "$COMPONENT") {}
};
DRAGNN_RUNTIME_REGISTER_COMPONENT(XlaAotDynamicComponent_$MODEL_$COMPONENT);
using SequenceXlaAotDynamicComponent_$MODEL_$COMPONENT =
SequenceXlaDynamicComponentMixin<XlaAotDynamicComponent_$MODEL_$COMPONENT>;
DRAGNN_RUNTIME_REGISTER_COMPONENT(
SequenceXlaAotDynamicComponent_$MODEL_$COMPONENT);
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
"""
return cc_template.replace('$ARCH', arch).replace('$TARGET', target).replace(
'$MODEL', model).replace('$COMPONENT', component)
def _dragnn_xla_aot_component_library(arch, model, component,
                                      tags=None, testonly=0):
  """Generates and compiles the component library that wraps the AOT binary.

  Args:
    arch: The name of the target architecture.
    model: The name of the DRAGNN model.
    component: The name of the DRAGNN component that uses XLA AOT.
    tags: tags to apply to subsidiary build rules.
    testonly: If 1, only testonly targets can depend on this target.
  """
  xla_aot_library = _dragnn_xla_aot_library_name(arch, model, component)
  xla_aot_component_library = _dragnn_xla_aot_component_library_name(
      arch, model, component)
  xla_aot_component_src = xla_aot_component_library + '.cc'
  # Write the generated registration code into a .cc file via a shell heredoc.
  native.genrule(
      name=xla_aot_component_library + '_cc',
      outs=[xla_aot_component_src],
      cmd = "cat << 'EOF' >$@\n{}\nEOF\n".format(
          _dragnn_xla_aot_component_cc_code(
              arch, model, component, native.package_name())
      ),
      tags=tags,
      testonly=testonly,
  )
  # alwayslink=1 keeps the registration (a static initializer) in the final
  # binary even though nothing references the library's symbols directly.
  native.cc_library(
      name=xla_aot_component_library,
      srcs=[xla_aot_component_src],
      deps = [
          multiarch_name(
              '//dragnn/runtime/xla:sequence_xla_dynamic_component_mixin',
              arch),
          multiarch_name(
              '//dragnn/runtime/xla:xla_aot_dynamic_component',
              arch),
          ':' + xla_aot_library,
      ],
      testonly=testonly,
      alwayslink=1
  )
def _dragnn_xla_aot_library(name, arch, model, component, graph,
                            tags=None, testonly=0):
  """Runs tfcompile to AOT-compile a frozen GraphDef for a DRAGNN component.

  Generates a build target called |name| which is a cc_library containing
  the generated header and AOT-compiled function that implements a specific
  DRAGNN component.  For details on compilation see:
    @org_tensorflow//tensorflow/compiler/aot/tfcompile.bzl

  The generated library contains the following C++ class:
    syntaxnet::dragnn::runtime::<model>::<component>
  and the output file is called <name>.h

  There is also a build target called <name>-config which contains the
  Config proto used by XLA.

  Args:
    name: The name of the build rule.
    arch: The name of the target architecture.
    model: The name of the DRAGNN model that contains this component.
    component: The name of the DRAGNN component in the ComponentSpec.
    graph: The frozen tensorflow.GraphDef binary proto built for a particular
        DRAGNN component by the runtime.
    tags: tags to apply to subsidiary build rules.
    testonly: If 1, only testonly targets can depend on this target.
  """
  # Gets the Config proto needed by tfcompile.
  xla_config_name = name + '-config'
  _dragnn_xla_config_proto(
      name=xla_config_name,
      graph=graph
  )
  # Runs tfcompile to AOT-compile the GraphDef.
  tf_library(
      name=_dragnn_xla_aot_library_name(arch, model, component),
      graph=graph,
      config=xla_config_name,
      cpp_class='syntaxnet::dragnn::runtime::' + model + '::' + component,
      # NOTE(review): name-to-index and program-shape metadata are presumably
      # required by the runtime's tensor lookup; multi-threaded Eigen is
      # disabled -- confirm before changing these flags.
      tfcompile_flags = ' '.join([
          '--gen_name_to_index=true',
          '--gen_program_shape=true',
          '--xla_cpu_multi_thread_eigen=false',
      ] + MULTIARCH_TFCOMPILE_FLAGS[arch]),
      tags=tags,
      testonly=testonly,
  )
  # Generates the component library that wraps the AOT library.
  _dragnn_xla_aot_component_library(arch, model, component, tags, testonly)
def dragnn_xla_aot_components(name, component_data, tags=None, testonly=0):
  """Generates targets for all XLA AOT components in |component_data|.

  Every element in the list |component_data| is also a list, which contains:
    - name of the DRAGNN model;
    - name of the component;
    - relative path to the frozen GraphDef proto.

  If multiple models exist in the same binary, the model name must uniquely
  identify this specific model instance, e.g. 'parser_v20171101'.

  Args:
    name: The name of the build rule.
    component_data: A list of per-component-data that is necessary to build
        the AOT library and the component that wraps it.
    tags: tags to apply to subsidiary build rules; the arch-specific tags
        are included.
    testonly: If 1, only testonly targets can depend on this target.
  """
  # Sanitize model/component names so they can appear in C++ identifiers
  # and target names.
  safe_component_data = [
      [
          _dragnn_xla_safe_name(model),
          _dragnn_xla_safe_name(component),
          graph
      ]
      for [model, component, graph] in component_data]
  # Generates the AOT library and component targets, once per architecture.
  for arch in MULTIARCH_TFCOMPILE_FLAGS:
    for [model, component, graph_path] in safe_component_data:
      _dragnn_xla_aot_library(
          name=_dragnn_xla_aot_library_name(arch, model, component),
          arch=arch,
          model=model,
          component=component,
          graph=graph_path,
          tags=(tags if tags else []) + MULTIARCH_CONFIGS[arch]['tags'],
          testonly=testonly,
      )
  # Composes a library with all of the AOT library and component targets.
  for arch in MULTIARCH_TFCOMPILE_FLAGS:
    native.cc_library(
        name=multiarch_name(name, arch),
        deps = [
            ':' + _dragnn_xla_aot_component_library_name(
                arch, model, component)
            for [model, component, _] in safe_component_data
        ],
        tags=(tags if tags else []) + MULTIARCH_CONFIGS[arch]['tags'],
        testonly=testonly,
    )
def dragnn_xla_aot_bazel_test(name, srcs):
  """Verifies that generated bzl matches what is checked in.

  Passes when the generated file <name>-gen.bzl and the currently
  existing one in <name>.bzl match.

  Args:
    name: The name of the bzl to test (without .bzl)
    srcs: A set of MasterSpec files
  """
  generated_bzl = name + '-gen.bzl'
  # Regenerate the bzl from the MasterSpec files.
  native.genrule(
      name=name + '_gen',
      outs = [generated_bzl],
      cmd = ('$(location '+
             '//dragnn/runtime/xla:xla_extract_names_from_specs) ' +
             native.package_name() + ' $(SRCS) $(OUTS)'),
      tools=['//dragnn/runtime/xla:xla_extract_names_from_specs'],
      srcs=srcs)
  # Makes a copy of file_diff_test in this package.
  native.genrule(
      name = 'repackage_file_diff_test',
      srcs = ['//dragnn/python:file_diff_test.py'],
      outs = ['%s/file_diff_test.py' % native.package_name()],
      cmd = 'cp $< $@',
  )
  # Compare the generated file.
  expected_bzl = name + '.bzl'
  native.py_test(
      name = name,
      srcs = ['%s/file_diff_test.py' % native.package_name()],
      main = '%s/file_diff_test.py' % native.package_name(),
      deps = ['//dragnn/python:file_diff_test'],
      args = [
          '--actual_file=$(location ' + generated_bzl + ')',
          '--expected_file=$(location ' + expected_bzl + ')',
      ],
      data = [expected_bzl, generated_bzl],
  )
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/xla/xla_cell_converter.h"
#include <vector>
#include "dragnn/runtime/xla/xla_graph_utils.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Returns true if the |tensor_name| denotes a control dependency, i.e. it
// starts with '^' (as in "^node_name").  An empty name is not a control
// dependency.
bool IsControlDependency(const string &tensor_name) {
  return !tensor_name.empty() && tensor_name.front() == '^';
}
// Returns the name of the node that supplies the input called |input_name|.
// This strips off any '^' prefix on control dependencies and any ":<index>"
// suffix specifying the tensor output.  E.g., "^foo:1" => "foo".
const string GetNodeNameFromInput(const string &input_name) {
  // Skips the leading '^' of a control dependency, if present.
  const size_t start = (!input_name.empty() && input_name[0] == '^') ? 1 : 0;
  const size_t colon = input_name.rfind(':');
  // NB: substr() takes (position, length), not (position, end-position).  The
  // previous version passed rfind(':') directly as the length, which left a
  // stray ':' in names like "^foo:1"; compute the length relative to |start|.
  const size_t length = colon == string::npos ? string::npos : colon - start;
  return input_name.substr(start, length);
}
// Returns true if the |node| is a TF variable (op type "VariableV2"); such
// nodes are frozen into constants during conversion.
bool IsVariableNode(const tensorflow::NodeDef &node) {
  return node.op() == "VariableV2";
}
// Returns true if the |node| is skippable and can be changed to an Identity
// node.  Currently only "Enter" ops qualify; once the cell is extracted, a
// simple pass-through suffices in their place.
bool IsNodeConvertibleToIdentity(const tensorflow::NodeDef &node) {
  return node.op() == "Enter";
}
// Returns true if the node attribute with |name| is one that should always be
// retained when a node is being simplified or frozen.  These attributes
// describe the tensor type and shape.
bool AlwaysKeepAttribute(const string &name) {
  for (const char *keep : {"_output_shapes", "T", "dtype"}) {
    if (name == keep) return true;
  }
  return false;
}
// Generates the name of the node that contains the serialized CellSubgraphSpec
// for the component named |component_name| (see runtime_support.py, which
// plants the "/EXPORT/CellSubgraphSpec" node).
string MakeCellSubgraphSpecNodeName(const string &component_name) {
  return component_name + "/EXPORT/CellSubgraphSpec";
}
// Loads the CellSubgraphSpec for the component named |component_name| from the
// |trained_model| into the |spec|.  The spec is embedded in the trained graph
// as a string scalar in a node with a well-known name.  On error, returns
// non-OK.
tensorflow::Status LoadCellSubgraphSpec(const string &component_name,
                                        const TrainedModel &trained_model,
                                        CellSubgraphSpec *spec) {
  const string tensor_name = MakeCellSubgraphSpecNodeName(component_name);
  // Evaluates the node to obtain the serialized proto, then parses it.
  tensorflow::Tensor tensor;
  TF_RETURN_IF_ERROR(trained_model.EvaluateTensor(tensor_name, &tensor));
  if (!spec->ParseFromString(tensor.scalar<string>()())) {
    return tensorflow::errors::InvalidArgument(
        "Failed to parse CellSubgraphSpec for component ", component_name);
  }
  VLOG(1) << tensor_name << " = \n" << spec->DebugString();
  return tensorflow::Status::OK();
}
} // namespace
tensorflow::Status XlaCellConverter::FillNode(
    const tensorflow::NodeDef &src_node, tensorflow::NodeDef *dest_node) const {
  dest_node->set_name(src_node.name());
  dest_node->set_device(src_node.device());

  // Skippable ops (e.g., Enter) become pass-through Identity nodes, keeping
  // only the critical type/shape attributes; all other ops are copied as-is.
  const bool to_identity = IsNodeConvertibleToIdentity(src_node);
  dest_node->set_op(to_identity ? "Identity" : src_node.op());
  FillNodeAttributes(/*restrict_attributes=*/to_identity, src_node, dest_node);

  // Retains only the inputs that are produced inside the cell subgraph.
  for (const string &input : src_node.input()) {
    if (!IsNodeInSubgraph(GetNodeNameFromInput(input))) continue;
    dest_node->add_input(input);
  }
  return tensorflow::Status::OK();
}
tensorflow::Status XlaCellConverter::FreezeSpecNode(
    const tensorflow::NodeDef &src_node, tensorflow::NodeDef *dest_node) const {
  // Rewrites the spec node as a Const with a fixed, well-known name, so
  // downstream tooling can find it without knowing the component name.
  dest_node->set_name(kFrozenCellSubgraphSpecNodeName);
  dest_node->set_op("Const");
  FillNodeAttributes(true, src_node, dest_node);
  // Evaluates output 0 of the source node to obtain the serialized spec.
  tensorflow::Tensor tensor;
  TF_RETURN_IF_ERROR(trained_model_->EvaluateTensor(
      AsVariableName(TensorId(src_node.name(), 0)), &tensor));
  // Leaves constants directly accessible (AsProtoField stores typed repeated
  // fields rather than packed bytes), which allows for simple extraction of
  // the value.
  tensor.AsProtoField((*dest_node->mutable_attr())["value"].mutable_tensor());
  return tensorflow::Status::OK();
}
tensorflow::Status XlaCellConverter::FreezeNode(
    const tensorflow::NodeDef &src_node, tensorflow::NodeDef *dest_node) const {
  // Replaces the variable with a Const node of the same name, so existing
  // input references to it remain valid.
  dest_node->set_name(src_node.name());
  dest_node->set_op("Const");
  FillNodeAttributes(true, src_node, dest_node);
  // Evaluates output 0 of the variable to obtain its trained value.
  tensorflow::Tensor tensor;
  TF_RETURN_IF_ERROR(trained_model_->EvaluateTensor(
      AsVariableName(TensorId(src_node.name(), 0)), &tensor));
  // Compactly stores tensor constants (packed tensor_content bytes instead of
  // typed repeated fields).
  tensor.AsProtoTensorContent(
      (*dest_node->mutable_attr())["value"].mutable_tensor());
  return tensorflow::Status::OK();
}
void XlaCellConverter::FillNodeAttributes(bool restrict_attributes,
                                          const tensorflow::NodeDef &src_node,
                                          tensorflow::NodeDef *dest_node) {
  // Copies every attribute, unless restricted to the always-kept set of
  // tensor type/shape attributes (see AlwaysKeepAttribute()).
  for (const auto &attr : src_node.attr()) {
    if (restrict_attributes && !AlwaysKeepAttribute(attr.first)) continue;
    (*dest_node->mutable_attr())[attr.first] = attr.second;
  }
}
bool XlaCellConverter::IsNodeInSubgraph(const string &node_name) const {
  // A node is in the cell subgraph iff BuildOperations() recorded it.
  return operations_.count(node_name) > 0;
}
tensorflow::Status XlaCellConverter::Convert(const string &component_name,
                                             const TrainedModel &trained_model,
                                             tensorflow::GraphDef *graph,
                                             CellSubgraphSpec *spec) {
  // Delegates to a temporary converter instance, which holds all per-call
  // state; the public static API therefore stays stateless.
  return XlaCellConverter().ConvertImpl(component_name, trained_model, graph,
                                        spec);
}
tensorflow::Status XlaCellConverter::ConvertImpl(
    const string &component_name, const TrainedModel &trained_model,
    tensorflow::GraphDef *graph, CellSubgraphSpec *spec) {
  component_name_ = component_name;
  trained_model_ = &trained_model;

  // Reads the embedded CellSubgraphSpec, then uses it to determine the cell
  // boundaries (|inputs_|, |outputs_|) and member nodes (|operations_|).
  TF_RETURN_IF_ERROR(
      LoadCellSubgraphSpec(component_name_, *trained_model_, spec));
  TF_RETURN_IF_ERROR(BuildInputsAndOutputs(*spec));
  TF_RETURN_IF_ERROR(BuildOperations());

  graph->Clear();
  const tensorflow::GraphDef *input_graph;
  TF_RETURN_IF_ERROR(trained_model_->GraphDef(&input_graph));

  // Adds in the CellSubgraphSpec node for this component, renamed to a
  // well-known constant name.
  const tensorflow::NodeDef *cell_subgraph_spec_node = nullptr;
  TF_RETURN_IF_ERROR(trained_model_->LookupNode(
      MakeCellSubgraphSpecNodeName(component_name_), &cell_subgraph_spec_node));
  TF_RETURN_IF_ERROR(
      FreezeSpecNode(*cell_subgraph_spec_node, graph->add_node()));

  // Adds in frozen versions of the nodes needed for this cell: variables are
  // evaluated and baked in as constants; other nodes are copied (possibly
  // simplified to Identity).
  for (const tensorflow::NodeDef &node : input_graph->node()) {
    if (IsNodeInSubgraph(node.name())) {
      if (IsVariableNode(node)) {
        TF_RETURN_IF_ERROR(FreezeNode(node, graph->add_node()));
      } else {
        TF_RETURN_IF_ERROR(FillNode(node, graph->add_node()));
      }
    }
  }
  return tensorflow::Status::OK();
}
tensorflow::Status XlaCellConverter::BuildInputsAndOutputs(
    const CellSubgraphSpec &spec) {
  // Gathers input tensors, requiring unique names and unique tensors.
  std::set<string> unique_input_names;
  for (const CellSubgraphSpec::Input &input : spec.input()) {
    if (!unique_input_names.insert(input.name()).second) {
      return tensorflow::errors::InvalidArgument(
          "Duplicate input name { ", input.ShortDebugString(), " }");
    }
    TensorId tensor_id;
    TF_RETURN_IF_ERROR(ParseTensorId(input.tensor(), &tensor_id));
    if (!inputs_.insert(tensor_id).second) {
      return tensorflow::errors::InvalidArgument(
          "Duplicate input variable { ", input.ShortDebugString(), " }");
    }
  }

  // Gathers output tensors.  Output names must be unique, but the insertion
  // into |outputs_| is deliberately unchecked — presumably to allow two output
  // names to alias the same tensor (TODO confirm against callers).
  std::set<string> unique_output_names;
  for (const CellSubgraphSpec::Output &output : spec.output()) {
    if (!unique_output_names.insert(output.name()).second) {
      return tensorflow::errors::InvalidArgument(
          "Duplicate output name { ", output.ShortDebugString(), " }");
    }
    TensorId tensor_id;
    TF_RETURN_IF_ERROR(ParseTensorId(output.tensor(), &tensor_id));
    outputs_.insert(tensor_id);
  }

  // Check that recurrent inputs match the name of an output.
  for (const CellSubgraphSpec::Input &input : spec.input()) {
    if (input.type() != CellSubgraphSpec::Input::TYPE_RECURRENT) continue;
    if (unique_output_names.find(input.name()) == unique_output_names.end()) {
      return tensorflow::errors::InvalidArgument(
          "Recurrent input does not match any output { ",
          input.ShortDebugString(), " }");
    }
  }
  return tensorflow::Status::OK();
}
// Walks backwards from the cell outputs to the cell inputs, recording every
// reachable node in |operations_|.  Requires BuildInputsAndOutputs().
tensorflow::Status XlaCellConverter::BuildOperations() {
  // Extract sets of input and output node names.
  std::set<string> input_node_names;
  std::set<string> output_node_names;
  for (const TensorId &id : inputs_) input_node_names.insert(id.first);
  for (const TensorId &id : outputs_) output_node_names.insert(id.first);

  // Set of nodes that have already been visited by the DFS.
  std::set<string> visited;

  // DFS backwards from output nodes to input nodes and collect operations.
  std::vector<string> stack(output_node_names.begin(), output_node_names.end());
  while (!stack.empty()) {
    const string name = stack.back();
    stack.pop_back();
    if (!visited.insert(name).second) continue;  // already visited; skip
    const tensorflow::NodeDef *node = nullptr;
    TF_RETURN_IF_ERROR(trained_model_->LookupNode(name, &node));
    // Defensive consistency check; with the |visited| guard above, each name
    // should only be populated once.
    Operation &operation = operations_[name];
    if (operation.node != nullptr && operation.node != node) {
      // Bug fix: the message previously lacked the closing parenthesis.
      return tensorflow::errors::Internal("Inconsistent nodes for operation ",
                                          name, " (", operation.node->name(),
                                          " vs ", node->name(), ")");
    }
    operation.node = node;

    // Function inputs bound the search; don't expand them.
    if (input_node_names.find(name) != input_node_names.end()) continue;

    // Expand (non-control) inputs.
    for (const string &input_name : node->input()) {
      if (IsControlDependency(input_name)) continue;
      VLOG(1) << name << " has input " << input_name;
      TensorId tensor_id;
      TF_RETURN_IF_ERROR(ParseTensorId(input_name, &tensor_id));
      stack.push_back(tensor_id.first);
    }
  }
  return tensorflow::Status::OK();
}
// Delegates to ParseTensorName() (xla_graph_utils), which splits a name like
// "foo/bar:1" into ("foo/bar", 1) and defaults the output index to 0.
tensorflow::Status XlaCellConverter::ParseTensorId(const string &tensor_name,
                                                   TensorId *tensor_id) {
  return ParseTensorName(tensor_name, &tensor_id->first, &tensor_id->second);
}
string XlaCellConverter::AsVariableName(const TensorId &tensor_id) {
if (tensor_id.second == 0) return tensor_id.first;
return tensorflow::strings::StrCat(tensor_id.first, ":", tensor_id.second);
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_XLA_XLA_CELL_CONVERTER_H_
#define DRAGNN_RUNTIME_XLA_XLA_CELL_CONVERTER_H_
#include <map>
#include <set>
#include <string>
#include <utility>
#include "dragnn/protos/export.pb.h"
#include "dragnn/runtime/trained_model.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Converter that extracts the cell computation from a DRAGNN component and
// writes it as a frozen TF GraphDef.
//
// The trained model that contains the DRAGNN component must also contain a
// CellSubgraphSpec proto embedded into the TF graph as a specifically-named
// constant node (see runtime_support.py). The CellSubgraphSpec defines the
// boundaries of the cell computation.
//
// Each frozen GraphDef contains a single function that runs the cell and
// is named after the component. The function inputs are reference
// variables, so they can be pointed at externally-managed pieces of memory,
// provided sufficient size and alignment. Output storage is managed by XLA.
// The function inputs and outputs are marked with special names, namely:
// INPUT__<CellSubgraphSpec.Input.name>
// OUTPUT__<CellSubgraphSpec.Output.name>
class XlaCellConverter {
 public:
  // Extracts the cell of the DRAGNN component named |component_name| from the
  // |trained_model| and overwrites |graph| with an equivalent frozen TF
  // GraphDef (i.e., one that encapsulates Variables as constants).  The
  // CellSubgraphSpec stored in the graph is copied into |spec|.  On error,
  // returns non-OK.
  static tensorflow::Status Convert(const string &component_name,
                                    const TrainedModel &trained_model,
                                    tensorflow::GraphDef *graph,
                                    CellSubgraphSpec *spec);

 private:
  // A (node_name, output_index) pair denoting a tensor.
  using TensorId = std::pair<string, uint32>;

  // A TF operation that makes up the cell.
  struct Operation {
    // The TF graph node represented by this operation.  Not owned.
    const tensorflow::NodeDef *node = nullptr;
  };

  // Creates an empty converter.  Instances only exist transiently inside the
  // static Convert() entry point.
  XlaCellConverter() = default;

  // Populates |dest_node| with the contents of |src_node|.  For most nodes
  // this is a complete copy.  The exception is for nodes converted to Identity
  // ops (e.g., Enter nodes): the op is changed to "Identity" and only critical
  // attributes (for tensor type and shape) are retained.
  tensorflow::Status FillNode(const tensorflow::NodeDef &src_node,
                              tensorflow::NodeDef *dest_node) const;

  // Populates |dest_node| with the frozen contents of |src_node|, which
  // evaluates to a CellSubgraphSpec.  The serialized contents are stored in
  // value.tensor.string_val, which makes extraction and development cleaner.
  tensorflow::Status FreezeSpecNode(const tensorflow::NodeDef &src_node,
                                    tensorflow::NodeDef *dest_node) const;

  // Populates |dest_node| with the frozen contents of |src_node|.  The output
  // tensor of |src_node| is evaluated and embedded as a constant in
  // |dest_node|.  On error, returns non-OK.
  tensorflow::Status FreezeNode(const tensorflow::NodeDef &src_node,
                                tensorflow::NodeDef *dest_node) const;

  // Copies node attributes from |src_node| to |dest_node|.  When
  // |restrict_attributes| is true, strips out those which don't apply
  // generally (keeping only tensor type and shape attributes).
  static void FillNodeAttributes(bool restrict_attributes,
                                 const tensorflow::NodeDef &src_node,
                                 tensorflow::NodeDef *dest_node);

  // Returns true if a node called |node_name| is in the subgraph required
  // for evaluating the cell.
  bool IsNodeInSubgraph(const string &node_name) const;

  // Implements the static Convert() method.
  tensorflow::Status ConvertImpl(const string &component_name,
                                 const TrainedModel &trained_model,
                                 tensorflow::GraphDef *graph,
                                 CellSubgraphSpec *spec);

  // Populates the |inputs_| and |outputs_| based on the |spec|.  On error,
  // returns non-OK.
  tensorflow::Status BuildInputsAndOutputs(const CellSubgraphSpec &spec);

  // Walks from the |outputs_| to the |inputs_| in the |trained_model_|, adding
  // to |operations_| along the way.  Requires that BuildInputsAndOutputs() was
  // called.  On error, returns non-OK.
  tensorflow::Status BuildOperations();

  // Parses a |tensor_name| into a |tensor_id|.  E.g.,
  //   "foo/bar:1" => ("foo/bar", 1)
  //   "baz" => ("baz", 0)
  // On error, returns non-OK.  It is an error if the |tensor_name| denotes a
  // control dependency.
  static tensorflow::Status ParseTensorId(const string &tensor_name,
                                          TensorId *tensor_id);

  // Returns the canonically-formatted name of the graph variable associated
  // with the |tensor_id|.
  static string AsVariableName(const TensorId &tensor_id);

  // Name of the component being converted.
  string component_name_;

  // Trained model that contains the DRAGNN model.  Not owned.
  const TrainedModel *trained_model_ = nullptr;

  // Tensor ids that serve as cell inputs and outputs.
  std::set<TensorId> inputs_;
  std::set<TensorId> outputs_;

  // Mapping from node name to the Operation in the cell subgraph.
  std::map<string, Operation> operations_;
};
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_XLA_XLA_CELL_CONVERTER_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/xla/xla_cell_converter.h"
#include <string.h>
#include <iostream>
#include <string>
#include <utility>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/export.pb.h"
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/trained_model.h"
#include "dragnn/runtime/xla/xla_graph_utils.h"
#include "dragnn/runtime/xla/xla_spec_utils.h"
#include "syntaxnet/base.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Relative path to a saved model used as the test fixture.
constexpr char kSavedModelDir[] = "dragnn/runtime/testdata/rnn_tagger";

// Names of components in the saved model.
const char *kComponentNames[] = {"rnn", "tagger"};

// Returns a valid saved model directory, resolved against the test data root.
string GetSavedModelDir() {
  return tensorflow::io::JoinPath(test::GetTestDataPrefix(), kSavedModelDir);
}
// Loads a trained model, converts each component to a frozen graph,
// compiles, and then runs the cell.
TEST(XlaCellConverterTest, LoadAndConvertAndRun) {
  TrainedModel trained_model;
  TF_ASSERT_OK(trained_model.Reset(GetSavedModelDir()));
  for (const string component_name : kComponentNames) {
    LOG(INFO) << "Component: " << component_name;

    // Freezes the graph.
    tensorflow::GraphDef graph_def;
    CellSubgraphSpec spec_from_convert;
    TF_ASSERT_OK(XlaCellConverter::Convert(component_name, trained_model,
                                           &graph_def, &spec_from_convert));
    LOG(INFO) << component_name << " graph nodes = " << graph_def.node_size();

    // Extracts the CellSubgraphSpec and Config, then compiles.  The spec
    // recovered from the frozen graph must match the one Convert() returned.
    CellSubgraphSpec cell_subgraph_spec;
    tensorflow::tf2xla::Config xla_config;
    TF_ASSERT_OK(
        GetSpecAndMakeXlaConfig(graph_def, &cell_subgraph_spec, &xla_config));
    EXPECT_THAT(cell_subgraph_spec, test::EqualsProto(spec_from_convert));
    LOG(INFO) << component_name
              << " CellSubgraphSpec = " << cell_subgraph_spec.DebugString();
    LOG(INFO) << component_name << " Config = " << xla_config.DebugString();
    TF_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<tensorflow::XlaJitCompiledCpuFunction> jit,
        tensorflow::XlaJitCompiledCpuFunction::Compile(
            graph_def, xla_config, xla::ExecutableBuildOptions()));

    // Creates an instance which also allocates inputs.
    tensorflow::XlaCompiledCpuFunction instance(jit->StaticData());

    // Zeros out the inputs, so the cell runs on well-defined (if
    // meaningless) data.  Opaque parameters have no addressable bytes.
    const auto *program_shape = instance.ProgramShape();
    ASSERT_NE(nullptr, program_shape);
    for (int i = 0; i < program_shape->parameters_size(); i++) {
      const auto &shape = program_shape->parameters(i);
      if (shape.element_type() != xla::OPAQUE) {
        std::memset(instance.arg_data(i), 0, xla::ShapeUtil::ByteSizeOf(shape));
      }
    }

    // This is just a "don't crash" test.  XLA behavior will be exercised
    // more thoroughly in regression tests.
    LOG(INFO) << "Running " << component_name;
    ASSERT_TRUE(instance.Run());
  }
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment