Commit a4bb31d0 authored by Terry Koo
Browse files

Export @195097388.

parent dea7ecf6
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_ARRAY_VARIABLE_STORE_H_
#define DRAGNN_RUNTIME_ARRAY_VARIABLE_STORE_H_
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "dragnn/protos/runtime.pb.h"
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// A variable store that groups all variables into a single byte array. This
// class and its subclasses are intended for use in production.
//
// Each variable occupies a sub-array of the main byte array. The mapping from
// the name and format of a variable to the sub-array containing its content is
// defined in ArrayVariableStoreSpec. The variables may appear in any order.
//
// This format allows variables to be mapped directly into memory, which reduces
// initialization time and supports usage on-device, where mmap() is effectively
// obligatory for large data resources.
class ArrayVariableStore : public VariableStore {
 public:
  // Creates an uninitialized store.  Reset() must succeed before Lookup() or
  // Close() can be used.
  ArrayVariableStore() = default;

  // Resets this to represent the variables defined by the |spec| and |data|.
  // The |data| must remain valid until this is destroyed or Reset().  (Note
  // that subclasses have simpler lifetime requirements).  On error, returns
  // non-OK and modifies nothing.
  tensorflow::Status Reset(const ArrayVariableStoreSpec &spec,
                           AlignedView data);

  // Implements VariableStore.
  using VariableStore::Lookup;  // import Lookup<T>() convenience methods

  // Retrieves the variable with the given |name| and |format|, writing its
  // |dimensions| and backing |area|.  On error, returns non-OK.
  tensorflow::Status Lookup(const string &name, VariableSpec::Format format,
                            std::vector<size_t> *dimensions,
                            AlignedArea *area) override;
  tensorflow::Status Close() override;

 private:
  friend class ArrayVariableStoreBuilder;  // for access to kVersion

  // The current version of the serialized format.  Reset() rejects specs whose
  // version does not match this value.
  static const uint32 kVersion;

  // A (name,format) key associated with a variable.
  using Key = std::pair<string, VariableSpec::Format>;

  // Dimension vector and aligned area.  The dimensions are const because they
  // never change after Reset().
  using Value = std::pair<const std::vector<size_t>, AlignedArea>;

  // Mapping from variable key to variable content.  Initially null, filled in
  // Reset(), and deleted in Close().  Wrapped in std::unique_ptr so the entire
  // mapping can be deleted; a null pointer marks the store as uninitialized.
  std::unique_ptr<std::map<Key, Value>> variables_;
};
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_ARRAY_VARIABLE_STORE_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/array_variable_store_builder.h"
#include <stddef.h>
#include <tuple>
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/array_variable_store.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/cpu_info.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Appends the content of the |view| to the |data|, followed by zero-padding to
// the next alignment boundary.
// Appends the bytes of the |view| to the |data|, then appends zero bytes until
// the size of the |data| reaches the next alignment boundary.
void Append(AlignedView view, string *data) {
  DCHECK_EQ(PadToAlignment(data->size()), data->size());
  const size_t padded_size = PadToAlignment(view.size());
  data->append(view.data(), view.size());
  data->append(padded_size - view.size(), '\0');
}
// As above, but for an aligned |area|.
// As above, but appends every sub-view of the aligned |area| in order.
void Append(AlignedArea area, string *data) {
  DCHECK_EQ(PadToAlignment(data->size()), data->size());
  const size_t size_before = data->size();
  for (size_t view_index = 0; view_index < area.num_views(); ++view_index) {
    Append(area.view(view_index), data);
  }

  // Sanity-check: the number of appended bytes matches the area's footprint.
  DCHECK_EQ(data->size() - size_before,
            ComputeAlignedAreaSize(area.num_views(), area.view_size()));
}
} // namespace
// Serializes the |variables| into the |spec| and |data|.  On error, returns
// non-OK; the |spec| and |data| may be partially written in that case.
tensorflow::Status ArrayVariableStoreBuilder::Build(
    const Variables &variables, ArrayVariableStoreSpec *spec, string *data) {
  data->clear();
  spec->Clear();
  spec->set_version(ArrayVariableStore::kVersion);
  spec->set_alignment_bytes(internal::kAlignmentBytes);
  spec->set_is_little_endian(tensorflow::port::kLittleEndian);
  for (const auto &variable : variables) {
    // Bind const references instead of copying via std::tie(); the name,
    // dimension vector, and area can be arbitrarily large, and the originals
    // in |variables| outlive this loop iteration.
    const string &name = variable.first.first;
    const VariableSpec::Format format = variable.first.second;
    const std::vector<size_t> &dimensions = variable.second.first;
    const AlignedArea &area = variable.second.second;
    if (format == VariableSpec::FORMAT_FLAT && area.num_views() != 1) {
      return tensorflow::errors::InvalidArgument(
          "Flat variables must have 1 view, but '", name, "' has ",
          area.num_views());
    }
    VariableSpec *variable_spec = spec->add_variable();
    variable_spec->set_name(name);
    variable_spec->set_format(format);
    variable_spec->set_num_views(area.num_views());
    variable_spec->set_view_size(area.view_size());
    for (const size_t dimension : dimensions) {
      variable_spec->add_dimension(dimension);
    }
    // The variable's content starts at the current end of |data|; offsets are
    // implied by the order of the |variables|.
    Append(area, data);
  }
  return tensorflow::Status::OK();
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_ARRAY_VARIABLE_STORE_BUILDER_H_
#define DRAGNN_RUNTIME_ARRAY_VARIABLE_STORE_BUILDER_H_
#include <map>
#include <string>
#include "dragnn/protos/runtime.pb.h"
#include "dragnn/runtime/variable_store_wrappers.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Utils for converting a set of variables into a byte array that can be loaded
// by ArrayVariableStore. See that class for details on the required format.
class ArrayVariableStoreBuilder {
 public:
  // Mapping from (name,format) to (dimensions,area); see
  // variable_store_wrappers.h for the exact type.
  using Variables = CaptureUsedVariableStoreWrapper::Variables;

  // Forbids instantiation; pure static class.
  ArrayVariableStoreBuilder() = delete;
  ~ArrayVariableStoreBuilder() = delete;

  // Overwrites the |data| with a byte array that represents the |variables|,
  // and overwrites the |spec| with the associated configuration.  On error,
  // returns non-OK.
  static tensorflow::Status Build(const Variables &variables,
                                  ArrayVariableStoreSpec *spec, string *data);
};
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_ARRAY_VARIABLE_STORE_BUILDER_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/array_variable_store_builder.h"
#include <stddef.h>
#include <map>
#include <string>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/runtime.pb.h"
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/test/helpers.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Tests that the builder rejects invalid flat variables.
// Tests that the builder rejects FORMAT_FLAT variables whose area does not
// have exactly one sub-view.
TEST(ArrayVariableStoreBuilderTest, InvalidFlatVariable) {
  AlignedView view;
  TF_ASSERT_OK(view.Reset(nullptr, 2 * internal::kAlignmentBytes));

  // A flat variable backed by an empty area.
  const std::pair<string, VariableSpec::Format> foo_key(
      "foo", VariableSpec::FORMAT_FLAT);
  AlignedArea area;
  TF_ASSERT_OK(area.Reset(view, 0, 0));
  const std::pair<std::vector<size_t>, AlignedArea> foo_value({1}, area);
  ArrayVariableStoreBuilder::Variables variables;
  variables.push_back(std::make_pair(foo_key, foo_value));

  ArrayVariableStoreSpec spec;
  string data;
  EXPECT_THAT(ArrayVariableStoreBuilder::Build(variables, &spec, &data),
              test::IsErrorWithSubstr(
                  "Flat variables must have 1 view, but 'foo' has 0"));

  // The same variable backed by an area with more than 1 sub-view.
  TF_ASSERT_OK(area.Reset(view, 2, 0));
  variables[0].second.second = area;
  EXPECT_THAT(ArrayVariableStoreBuilder::Build(variables, &spec, &data),
              test::IsErrorWithSubstr(
                  "Flat variables must have 1 view, but 'foo' has 2"));
}
// Tests that the builder succeeds on good inputs and reproduces an expected
// byte array.
//
// NB: Since this test directly compares the byte array, it implicitly requires
// that the builder lays out the variables in a particular order. If that order
// changes, the test expectations must be updated.
TEST(ArrayVariableStoreBuilderTest, RegressionTest) {
  // Repo-relative paths, used when (re)writing the checked-in test data.
  const string kLocalSpecPath =
      "dragnn/runtime/testdata/array_variable_store_spec";
  const string kLocalDataPath =
      "dragnn/runtime/testdata/array_variable_store_data";
  // Runfile paths, used when comparing against the checked-in test data.
  const string kExpectedSpecPath = tensorflow::io::JoinPath(
      test::GetTestDataPrefix(),
      "dragnn/runtime/testdata/array_variable_store_spec");
  const string kExpectedDataPath = tensorflow::io::JoinPath(
      test::GetTestDataPrefix(),
      "dragnn/runtime/testdata/array_variable_store_data");

  // If these values are changed, make sure to rewrite the test data and update
  // array_variable_store_test.cc.
  UniqueMatrix<float> foo({{0.0, 0.5, 1.0},  //
                           {1.5, 2.0, 2.5},  //
                           {3.0, 3.5, 4.0},  //
                           {4.5, 5.0, 5.5}});
  UniqueMatrix<double> baz_data({{1.0, 2.0, 2.0, 2.0},  //
                                 {3.0, 4.0, 4.0, 4.0},  //
                                 {5.0, 6.0, 6.0, 6.0},  //
                                 {7.0, 8.0, 8.0, 8.0}});

  // "foo": a plain row-major matrix with dimensions (rows, columns).
  ArrayVariableStoreBuilder::Variables variables;
  std::pair<string, VariableSpec::Format> foo_key(
      "foo", VariableSpec::FORMAT_ROW_MAJOR_MATRIX);
  std::pair<std::vector<size_t>, AlignedArea> foo_value(
      {foo->num_rows(), foo->num_columns()}, AlignedArea(foo.area()));
  variables.push_back(std::make_pair(foo_key, foo_value));

  // "baz": a column-blocked matrix; array_variable_store_test.cc reads back
  // its dimensions {2, 8, 4} as (num_rows, num_columns, block_size).
  std::pair<string, VariableSpec::Format> baz_key(
      "baz", VariableSpec::FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX);
  std::pair<std::vector<size_t>, AlignedArea> baz_value(
      {2, 8, 4}, AlignedArea(baz_data.area()));
  variables.push_back(std::make_pair(baz_key, baz_value));

  // Seed the outputs with garbage that Build() must overwrite.
  ArrayVariableStoreSpec actual_spec;
  actual_spec.set_version(999);
  string actual_data = "garbage to be overwritten";
  TF_ASSERT_OK(
      ArrayVariableStoreBuilder::Build(variables, &actual_spec, &actual_data));

  // Developer toggle: flip to true (locally) to regenerate the test data.
  if (false) {
    // Rewrite the test data.
    TF_CHECK_OK(tensorflow::WriteTextProto(tensorflow::Env::Default(),
                                           kLocalSpecPath, actual_spec));
    TF_CHECK_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(),
                                              kLocalDataPath, actual_data));
  } else {
    // Compare to the test data.
    ArrayVariableStoreSpec expected_spec;
    string expected_data;
    TF_CHECK_OK(tensorflow::ReadTextProto(tensorflow::Env::Default(),
                                          kExpectedSpecPath, &expected_spec));
    TF_CHECK_OK(tensorflow::ReadFileToString(
        tensorflow::Env::Default(), kExpectedDataPath, &expected_data));
    EXPECT_THAT(actual_spec, test::EqualsProto(expected_spec));
    EXPECT_EQ(actual_data, expected_data);
  }
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/array_variable_store.h"
#include <string.h>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/runtime.pb.h"
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/file_array_variable_store.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/mmap_array_variable_store.h"
#include "dragnn/runtime/test/helpers.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Expects that the vectors of the blocked |matrix| match the |data|, where
// data[i][j] is the j'th element of the i'th vector.
template <class T>
void ExpectBlockedData(BlockedMatrix<T> matrix,
                       const std::vector<std::vector<T>> &data) {
  EXPECT_EQ(matrix.num_vectors(), data.size());
  // The indices don't really have semantic names, so we just use `i` and `j`.
  // See BlockedMatrixFormat for details.  Use size_t indices to avoid
  // signed/unsigned comparisons against the container sizes.
  for (size_t i = 0; i < matrix.num_vectors(); ++i) {
    EXPECT_EQ(matrix.block_size(), data[i].size());
    for (size_t j = 0; j < data[i].size(); ++j) {
      EXPECT_EQ(matrix.vector(i)[j], data[i][j]);
    }
  }
}
// Returns an ArrayVariableStoreSpec parsed from the |text|.
// Parses the |text| as an ArrayVariableStoreSpec and returns it.  Check-fails
// if the |text| is malformed.
ArrayVariableStoreSpec MakeSpec(const string &text) {
  ArrayVariableStoreSpec parsed_spec;
  CHECK(TextFormat::ParseFromString(text, &parsed_spec));
  return parsed_spec;
}
// Returns an ArrayVariableStoreSpec that has proper top-level settings and
// whose variables are parsed from the |variables_text|.
// Returns an ArrayVariableStoreSpec whose top-level settings match this binary
// and whose variables are parsed from the |variables_text|.
ArrayVariableStoreSpec MakeSpecWithVariables(const string &variables_text) {
  const string header = tensorflow::strings::StrCat(
      "version: 0 alignment_bytes: ", internal::kAlignmentBytes,
      " is_little_endian: ", tensorflow::port::kLittleEndian);
  return MakeSpec(tensorflow::strings::StrCat(header, " ", variables_text));
}
// Tests that kLittleEndian actually means little-endian.
// Tests that kLittleEndian actually means little-endian by inspecting the byte
// layout of a known 32-bit pattern.
TEST(ArrayVariableStoreTest, EndianDetection) {
  static_assert(sizeof(uint32) == 4 * sizeof(uint8), "Unexpected int sizes");
  const uint32 kPattern = 0xdeadbeef;
  uint8 pattern_bytes[4];
  memcpy(pattern_bytes, &kPattern, sizeof(pattern_bytes));

  // Big-endian stores the most significant byte first; little-endian reverses.
  const uint8 kBigEndianOrder[4] = {0xde, 0xad, 0xbe, 0xef};
  for (int i = 0; i < 4; ++i) {
    if (tensorflow::port::kLittleEndian) {
      EXPECT_EQ(pattern_bytes[3 - i], kBigEndianOrder[i]);
    } else {
      EXPECT_EQ(pattern_bytes[i], kBigEndianOrder[i]);
    }
  }
}
// Tests that the store checks for missing fields.
// Tests that the store rejects specs that omit any of the three required
// top-level fields.
TEST(ArrayVariableStoreTest, MissingRequiredField) {
  const std::vector<string> kIncompleteSpecs = {
      "version: 0 alignment_bytes: 0",
      "version: 0 is_little_endian: true",
      "alignment_bytes: 0 is_little_endian: true"};
  for (const string &spec_text : kIncompleteSpecs) {
    ArrayVariableStore store;
    EXPECT_THAT(store.Reset(MakeSpec(spec_text), AlignedView()),
                test::IsErrorWithSubstr(
                    "ArrayVariableStoreSpec is missing a required field"));
  }
}
// Tests that the store checks for a matching version number.
// Tests that the store rejects specs whose version differs from the binary's.
TEST(ArrayVariableStoreTest, VersionMismatch) {
  ArrayVariableStore store;
  const ArrayVariableStoreSpec bad_version_spec =
      MakeSpec("version: 999 alignment_bytes: 0 is_little_endian: true");
  EXPECT_THAT(store.Reset(bad_version_spec, AlignedView()),
              test::IsErrorWithSubstr("ArrayVariableStoreSpec.version (999) "
                                      "does not match the binary (0)"));
}
// Tests that the store checks for a matching alignment requirement.
// Tests that the store rejects specs whose alignment differs from the
// binary's.
TEST(ArrayVariableStoreTest, AlignmentMismatch) {
  ArrayVariableStore store;
  const string expected_error = tensorflow::strings::StrCat(
      "ArrayVariableStoreSpec.alignment_bytes (1) does not match "
      "the binary (", internal::kAlignmentBytes, ")");
  EXPECT_THAT(
      store.Reset(
          MakeSpec("version: 0 alignment_bytes: 1 is_little_endian: true"),
          AlignedView()),
      test::IsErrorWithSubstr(expected_error));
}
// Tests that the store checks for matching endian-ness.
// Tests that the store rejects specs whose endian-ness differs from the
// binary's.
TEST(ArrayVariableStoreTest, EndiannessMismatch) {
  const bool wrong_endianness = !tensorflow::port::kLittleEndian;
  const string spec_text = tensorflow::strings::StrCat(
      "version: 0 alignment_bytes: ", internal::kAlignmentBytes,
      " is_little_endian: ", wrong_endianness);
  const string expected_error = tensorflow::strings::StrCat(
      "ArrayVariableStoreSpec.is_little_endian (", wrong_endianness,
      ") does not match the binary (", tensorflow::port::kLittleEndian, ")");
  ArrayVariableStore store;
  EXPECT_THAT(store.Reset(MakeSpec(spec_text), AlignedView()),
              test::IsErrorWithSubstr(expected_error));
}
// Tests that the store rejects FORMAT_UNKNOWN variables.
// Tests that the store rejects FORMAT_UNKNOWN variables.
TEST(ArrayVariableStoreTest, RejectFormatUnknown) {
  ArrayVariableStore store;
  EXPECT_THAT(
      store.Reset(MakeSpecWithVariables("variable { format: FORMAT_UNKNOWN }"),
                  AlignedView()),
      test::IsErrorWithSubstr("Unknown variable format"));
}
// Tests that the store rejects FORMAT_FLAT variables with too few sub-views.
// Tests that the store rejects FORMAT_FLAT variables with zero sub-views.
TEST(ArrayVariableStoreTest, TooFewViewsForFlatVariable) {
  const ArrayVariableStoreSpec spec = MakeSpecWithVariables(
      "variable { format: FORMAT_FLAT num_views: 0 }");
  ArrayVariableStore store;
  EXPECT_THAT(store.Reset(spec, AlignedView()),
              test::IsErrorWithSubstr("Flat variables must have 1 view"));
}
// Tests that the store rejects FORMAT_FLAT variables with too many sub-views.
// Tests that the store rejects FORMAT_FLAT variables with multiple sub-views.
TEST(ArrayVariableStoreTest, TooManyViewsForFlatVariable) {
  const ArrayVariableStoreSpec spec = MakeSpecWithVariables(
      "variable { format: FORMAT_FLAT num_views: 2 }");
  ArrayVariableStore store;
  EXPECT_THAT(store.Reset(spec, AlignedView()),
              test::IsErrorWithSubstr("Flat variables must have 1 view"));
}
// Tests that the store accepts FORMAT_ROW_MAJOR_MATRIX variables with one
// sub-view.
// Tests that the store accepts FORMAT_ROW_MAJOR_MATRIX variables with one
// sub-view.
TEST(ArrayVariableStoreTest, MatrixWithOneRow) {
  const ArrayVariableStoreSpec spec = MakeSpecWithVariables(
      "variable { format: FORMAT_ROW_MAJOR_MATRIX num_views: 1 view_size: 0 }");
  ArrayVariableStore store;
  TF_EXPECT_OK(store.Reset(spec, AlignedView()));
}
// Tests that the store rejects variables that overrun the main byte array.
// Tests that the store rejects variables that overrun the main byte array.
TEST(ArrayVariableStoreTest, VariableOverrunsMainByteArray) {
  // A 1023-byte array cannot hold a 1024-byte variable.
  AlignedView data;
  TF_ASSERT_OK(data.Reset(nullptr, 1023));
  const ArrayVariableStoreSpec spec = MakeSpecWithVariables(
      "variable { format: FORMAT_FLAT num_views: 1 view_size: 1024 }");
  ArrayVariableStore store;
  EXPECT_THAT(
      store.Reset(spec, data),
      test::IsErrorWithSubstr("Variable would overrun main byte array"));
}
// Tests that the store rejects duplicate variables.
// Tests that the store rejects duplicate variables: two variables named 'x'
// with the same format below.
TEST(ArrayVariableStoreTest, DuplicateVariables) {
  const string kVariables = R"(
variable { name: 'x' format: FORMAT_FLAT num_views: 1 view_size: 1024 }
variable { name: 'y' format: FORMAT_FLAT num_views: 1 view_size: 2048 }
variable { name: 'x' format: FORMAT_FLAT num_views: 1 view_size: 4096 }
)";
  AlignedView data;
  TF_ASSERT_OK(data.Reset(nullptr, 1 << 20));  // 1MB
  ArrayVariableStore duplicate_store;
  EXPECT_THAT(duplicate_store.Reset(MakeSpecWithVariables(kVariables), data),
              test::IsErrorWithSubstr("Duplicate variable"));
}
// Tests that the store rejects sets of variables that do not completely cover
// the main byte array.
// Tests that the store rejects sets of variables that do not completely cover
// the main byte array: three small variables in a 1MB array below.
TEST(ArrayVariableStoreTest, LeftoverBytesInMainByteArray) {
  const string kVariables = R"(
variable { name: 'x' format: FORMAT_FLAT num_views: 1 view_size: 1024 }
variable { name: 'y' format: FORMAT_FLAT num_views: 1 view_size: 2048 }
variable { name: 'z' format: FORMAT_FLAT num_views: 1 view_size: 4096 }
)";
  AlignedView data;
  TF_ASSERT_OK(data.Reset(nullptr, 1 << 20));  // 1MB
  ArrayVariableStore partial_store;
  EXPECT_THAT(partial_store.Reset(MakeSpecWithVariables(kVariables), data),
              test::IsErrorWithSubstr(
                  "Variables do not completely cover main byte array"));
}
// The fast matrix-vector routines do not support padding.
// Tests that the store rejects blocked matrices whose views contain alignment
// padding, which the fast matrix-vector routines do not support.
TEST(ArrayVariableStoreTest, PaddingInBlockedMatrix) {
  const string kVariables = R"(
variable {
name: "baz"
format: FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX
num_views: 4
view_size: 16
dimension: 2
dimension: 4
dimension: 2
}
)";
  AlignedView data;
  TF_ASSERT_OK(data.Reset(nullptr, 1 << 20));  // 1MB
  ArrayVariableStore blocked_store;
  EXPECT_THAT(blocked_store.Reset(MakeSpecWithVariables(kVariables), data),
              test::IsErrorWithSubstr(
                  "Currently, fast matrix-vector operations do not support "
                  "padded blocked matrices"));
}
// Tests that the store cannot retrieve variables when it is uninitialized.
// Tests that the store cannot retrieve variables when it is uninitialized.
TEST(ArrayVariableStoreTest, LookupWhenUninitialized) {
  ArrayVariableStore uninitialized_store;
  Vector<float> unused_vector;
  EXPECT_THAT(uninitialized_store.Lookup("foo", &unused_vector),
              test::IsErrorWithSubstr("ArrayVariableStore not initialized"));
}
// Tests that the store can use an empty byte array when there are no variables.
// Tests that the store can use an empty byte array when there are no variables.
TEST(ArrayVariableStoreTest, EmptyByteArrayWorksIfNoVariables) {
  ArrayVariableStore empty_store;
  TF_EXPECT_OK(empty_store.Reset(MakeSpecWithVariables(""), AlignedView()));

  // Any lookup fails, since the store contains nothing.
  Vector<float> vector;
  EXPECT_THAT(
      empty_store.Lookup("foo", &vector),
      test::IsErrorWithSubstr("ArrayVariableStore has no variable with name "
                              "'foo' and format FORMAT_FLAT"));
}
// Tests that the store fails if it is closed before it has been initialized.
// Tests that the store fails if it is closed before it has been initialized.
TEST(ArrayVariableStoreTest, CloseBeforeReset) {
  ArrayVariableStore uninitialized_store;
  EXPECT_THAT(uninitialized_store.Close(),
              test::IsErrorWithSubstr("ArrayVariableStore not initialized"));
}
// Tests that the store can be closed (once) after it has been initialized.
// Tests that the store can be closed exactly once after it has been
// initialized.
TEST(ArrayVariableStoreTest, CloseAfterReset) {
  ArrayVariableStore store;
  TF_ASSERT_OK(store.Reset(MakeSpecWithVariables(""), AlignedView()));
  TF_EXPECT_OK(store.Close());

  // A second Close() fails, as if the store were never initialized.
  EXPECT_THAT(store.Close(),
              test::IsErrorWithSubstr("ArrayVariableStore not initialized"));
}
// Templated on an ArrayVariableStore subclass.
// Typed test fixture templated on an ArrayVariableStore subclass, so that each
// TYPED_TEST below runs once per subclass.
template <class Subclass>
class ArrayVariableStoreSubclassTest : public ::testing::Test {};

// The subclasses under test; presumably file-reading and mmap()-based stores —
// see file_array_variable_store.h and mmap_array_variable_store.h.
typedef ::testing::Types<FileArrayVariableStore, MmapArrayVariableStore>
    Subclasses;
TYPED_TEST_CASE(ArrayVariableStoreSubclassTest, Subclasses);
// Tests that the store fails to load a non-existent file.
// Tests that the store fails to load a non-existent file.
TYPED_TEST(ArrayVariableStoreSubclassTest, NonExistentFile) {
  const string kMissingPath = tensorflow::io::JoinPath(
      test::GetTestDataPrefix(), "dragnn/runtime/testdata/non_existent_file");
  TypeParam store;
  // The empty substring matches any error; only failure itself is required.
  EXPECT_THAT(store.Reset(MakeSpecWithVariables(""), kMissingPath),
              test::IsErrorWithSubstr(""));
}
// Tests that the store can load an empty file if there are no variables.
// Tests that the store can load an empty file if there are no variables.
TYPED_TEST(ArrayVariableStoreSubclassTest, EmptyFile) {
  const string kEmptyFilePath = tensorflow::io::JoinPath(
      test::GetTestDataPrefix(), "dragnn/runtime/testdata/empty_file");
  TypeParam store;
  TF_ASSERT_OK(store.Reset(MakeSpecWithVariables(""), kEmptyFilePath));

  // The store loads successfully but contains no variables.
  Vector<float> vector;
  EXPECT_THAT(store.Lookup("foo", &vector),
              test::IsErrorWithSubstr("ArrayVariableStore has no variable with "
                                      "name 'foo' and format FORMAT_FLAT"));
  Matrix<float> row_major_matrix;
  EXPECT_THAT(
      store.Lookup("bar", &row_major_matrix),
      test::IsErrorWithSubstr("ArrayVariableStore has no variable with name "
                              "'bar' and format FORMAT_ROW_MAJOR_MATRIX"));
}
// Tests that the store, when loading a pre-built byte array, produces the same
// variables that the builder converted.
TYPED_TEST(ArrayVariableStoreSubclassTest, RegressionTest) {
  // Paths to the spec and data produced by array_variable_store_builder_test.
  const string kSpecPath = tensorflow::io::JoinPath(
      test::GetTestDataPrefix(),
      "dragnn/runtime/testdata/array_variable_store_spec");
  const string kDataPath = tensorflow::io::JoinPath(
      test::GetTestDataPrefix(),
      "dragnn/runtime/testdata/array_variable_store_data");
  ArrayVariableStoreSpec spec;
  TF_CHECK_OK(
      tensorflow::ReadTextProto(tensorflow::Env::Default(), kSpecPath, &spec));
  TypeParam store;
  TF_ASSERT_OK(store.Reset(spec, kDataPath));

  // Row-major matrix "foo".
  Matrix<float> foo;
  TF_ASSERT_OK(store.Lookup("foo", &foo));
  // NB: These assertions must be kept in sync with the variables defined in
  // array_variable_store_builder_test.cc.
  ExpectMatrix(foo, {{0.0, 0.5, 1.0},  //
                     {1.5, 2.0, 2.5},  //
                     {3.0, 3.5, 4.0},  //
                     {4.5, 5.0, 5.5}});

  // Blocked formats.  "baz" was built with dimensions {2, 8, 4}, which are
  // read back as (num_rows, num_columns, block_size).
  BlockedMatrix<double> baz;
  TF_ASSERT_OK(store.Lookup("baz", &baz));
  EXPECT_EQ(baz.num_rows(), 2);
  EXPECT_EQ(baz.num_columns(), 8);
  EXPECT_EQ(baz.block_size(), 4);
  ExpectBlockedData(baz, {{1.0, 2.0, 2.0, 2.0},  //
                          {3.0, 4.0, 4.0, 4.0},  //
                          {5.0, 6.0, 6.0, 6.0},  //
                          {7.0, 8.0, 8.0, 8.0}});

  // Try versions of "foo" and "baz" with the wrong format; lookups must match
  // on both name and format.
  Vector<float> vector;
  Matrix<float> row_major_matrix;
  EXPECT_THAT(store.Lookup("foo", &vector),
              test::IsErrorWithSubstr("ArrayVariableStore has no variable with "
                                      "name 'foo' and format FORMAT_FLAT"));
  EXPECT_THAT(store.Lookup("baz", &vector),
              test::IsErrorWithSubstr("ArrayVariableStore has no variable with "
                                      "name 'baz' and format FORMAT_FLAT"));
  EXPECT_THAT(
      store.Lookup("baz", &row_major_matrix),
      test::IsErrorWithSubstr("ArrayVariableStore has no variable with name "
                              "'baz' and format FORMAT_ROW_MAJOR_MATRIX"));

  // Try totally unknown variables.
  EXPECT_THAT(store.Lookup("missing", &vector),
              test::IsErrorWithSubstr("ArrayVariableStore has no variable with "
                                      "name 'missing' and format FORMAT_FLAT"));
  EXPECT_THAT(
      store.Lookup("missing", &row_major_matrix),
      test::IsErrorWithSubstr("ArrayVariableStore has no variable with name "
                              "'missing' and format FORMAT_ROW_MAJOR_MATRIX"));
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/attributes.h"
#include <set>
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logging.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Parses the registered attributes from the |mapping|.  Fails on unknown
// names, malformed values, or missing mandatory attributes.
tensorflow::Status Attributes::Reset(
    const tensorflow::protobuf::Map<string, string> &mapping) {
  // First pass: Parse every (name,value) entry in the |mapping| into the
  // attribute registered under that name.
  for (const auto &name_value : mapping) {
    const auto it = attributes_.find(name_value.first);
    if (it == attributes_.end()) {
      return tensorflow::errors::InvalidArgument("Unknown attribute: ",
                                                 name_value.first);
    }
    TF_RETURN_IF_ERROR(it->second->Parse(name_value.second));
  }

  // Second pass: Collect the names of mandatory attributes that the |mapping|
  // omits.  A std::set keeps the names sorted for a deterministic message.
  std::set<string> missing_mandatory_attributes;
  for (const auto &name_and_attribute : attributes_) {
    if (name_and_attribute.second->IsMandatory() &&
        mapping.find(name_and_attribute.first) == mapping.end()) {
      missing_mandatory_attributes.insert(name_and_attribute.first);
    }
  }

  if (missing_mandatory_attributes.empty()) return tensorflow::Status::OK();
  return tensorflow::errors::InvalidArgument(
      "Missing mandatory attributes: ",
      tensorflow::str_util::Join(missing_mandatory_attributes, " "));
}
// Registers the |attribute| under the |name|; duplicate names are a bug.
void Attributes::Register(const string &name, Attribute *attribute) {
  const auto insertion = attributes_.emplace(name, attribute);
  DCHECK(insertion.second) << "Duplicate attribute '" << name << "'";
}
// Strings need no parsing; copy the |str| through.
tensorflow::Status Attributes::ParseValue(const string &str, string *value) {
  value->assign(str);
  return tensorflow::Status::OK();
}
// Parses the |str| as a bool; accepts "true" and "false" in any case.
tensorflow::Status Attributes::ParseValue(const string &str, bool *value) {
  const string lowercased = tensorflow::str_util::Lowercase(str);
  if (lowercased == "true") {
    *value = true;
    return tensorflow::Status::OK();
  }
  if (lowercased == "false") {
    *value = false;
    return tensorflow::Status::OK();
  }
  return tensorflow::errors::InvalidArgument(
      "Attribute can't be parsed as bool: ", str);
}
// Parses the |str| as an int32.
tensorflow::Status Attributes::ParseValue(const string &str, int32 *value) {
  if (tensorflow::strings::safe_strto32(str, value)) {
    return tensorflow::Status::OK();
  }
  return tensorflow::errors::InvalidArgument(
      "Attribute can't be parsed as int32: ", str);
}
// Parses the |str| as an int64.
tensorflow::Status Attributes::ParseValue(const string &str, int64 *value) {
  if (tensorflow::strings::safe_strto64(str, value)) {
    return tensorflow::Status::OK();
  }
  return tensorflow::errors::InvalidArgument(
      "Attribute can't be parsed as int64: ", str);
}
// Parses the |str| as a size_t by parsing a non-negative int64 and converting.
tensorflow::Status Attributes::ParseValue(const string &str, size_t *value) {
  int64 parsed = 0;
  const bool valid =
      tensorflow::strings::safe_strto64(str, &parsed) && parsed >= 0;
  if (!valid) {
    return tensorflow::errors::InvalidArgument(
        "Attribute can't be parsed as size_t: ", str);
  }
  *value = parsed;
  return tensorflow::Status::OK();
}
// Parses the |str| as a float.
tensorflow::Status Attributes::ParseValue(const string &str, float *value) {
  if (tensorflow::strings::safe_strtof(str.c_str(), value)) {
    return tensorflow::Status::OK();
  }
  return tensorflow::errors::InvalidArgument(
      "Attribute can't be parsed as float: ", str);
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Utils for parsing configuration attributes from (name,value) string pairs as
// typed values. Intended for parsing RegisteredModuleSpec.parameters, similar
// to get_attrs_with_defaults() in network_units.py. Example usage:
//
// // Create a subclass of Attributes.
// struct MyComponentAttributes : public Attributes {
// // Mandatory attribute with type and name. The "this" allows the attribute
// // to register itself in its container---i.e., MyComponentAttributes.
// Mandatory<float> coefficient{"coefficient", this};
//
// // Optional attributes with type, name, and default value.
// Optional<bool> ignore_case{"ignore_case", true, this};
// Optional<std::vector<int32>> layer_sizes{"layer_sizes", {1, 2, 3}, this};
//
// // Ignored attribute, which does not parse any value.
// Ignored dropout_keep_prob{"dropout_keep_prob", this};
// };
//
// // Initialize an instance of the subclass from a string-to-string mapping.
// RegisteredModuleSpec spec;
// MyComponentAttributes attributes;
// TF_RETURN_IF_ERROR(attributes.Reset(spec.parameters()));
//
// // Access the attributes as accessors.
// bool ignore_case = attributes.ignore_case();
// float coefficient = attributes.coefficient();
// const std::vector<int32> &layer_sizes = attributes.layer_sizes();
//
// See the unit test for additional usage examples.
//
// TODO(googleuser): Build typed attributes into the RegisteredModuleSpec and
// get rid of this module.
#ifndef DRAGNN_RUNTIME_ATTRIBUTES_H_
#define DRAGNN_RUNTIME_ATTRIBUTES_H_
#include <functional>
#include <map>
#include <string>
#include <vector>
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/protobuf.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Base class for sets of attributes. Use as indicated in the file comment:
// subclass Attributes, declare Mandatory/Optional/Ignored members that
// register themselves via the "this" pointer, then call Reset() with a
// name-to-value mapping to parse all of them at once.
class Attributes {
 public:
  // Untyped mapping from which typed attributes are parsed.
  using Mapping = tensorflow::protobuf::Map<string, string>;
  // Forbids copying, which would invalidate the pointers in |attributes_|.
  Attributes(const Attributes &that) = delete;
  Attributes &operator=(const Attributes &that) = delete;
  // Parses registered attributes from the name-to-value |mapping|. On error,
  // returns non-OK. Errors include unknown names in |mapping|, string-to-value
  // parsing failures, and missing mandatory attributes.
  tensorflow::Status Reset(const Mapping &mapping);
 protected:
  // Implementations of the supported kinds of attributes, defined below.
  class Ignored;
  template <class T>
  class Optional;
  template <class T>
  class Mandatory;
  // Forbids lifecycle management except via subclasses.
  Attributes() = default;
  virtual ~Attributes() = default;
 private:
  // Base class for an individual attribute, defined below.
  class Attribute;
  // Registers the |attribute| with the |name|, which must be unique among all
  // attributes in this container.
  void Register(const string &name, Attribute *attribute);
  // Parses the string |str| into the |value| object. One overload per
  // supported attribute type; each returns non-OK on malformed input.
  static tensorflow::Status ParseValue(const string &str, string *value);
  static tensorflow::Status ParseValue(const string &str, bool *value);
  static tensorflow::Status ParseValue(const string &str, int32 *value);
  static tensorflow::Status ParseValue(const string &str, int64 *value);
  static tensorflow::Status ParseValue(const string &str, size_t *value);
  static tensorflow::Status ParseValue(const string &str, float *value);
  template <class Element>
  static tensorflow::Status ParseValue(const string &str,
                                       std::vector<Element> *value);
  // Registered attributes, keyed by name. Pointers are non-owning; they point
  // at attribute members of the subclass instance.
  std::map<string, Attribute *> attributes_;
};
// Implementation details below.
// Base class for individual attributes. Concrete kinds (Ignored, Optional,
// Mandatory) override Parse() and possibly IsMandatory().
class Attributes::Attribute {
 public:
  Attribute() = default;
  // Non-copyable: the Attributes container holds a pointer to this object.
  Attribute(const Attribute &that) = delete;
  Attribute &operator=(const Attribute &that) = delete;
  virtual ~Attribute() = default;
  // Parses the |value| string into a typed object. On error, returns non-OK.
  virtual tensorflow::Status Parse(const string &value) = 0;
  // Returns true if this is a mandatory attribute. Defaults to optional.
  virtual bool IsMandatory() const { return false; }
};
// An attribute whose value is accepted but never parsed or stored. Useful for
// consuming configuration entries that the runtime deliberately ignores.
class Attributes::Ignored : public Attribute {
 public:
  // Adds this attribute to the |attributes| container under |name|.
  Ignored(const string &name, Attributes *attributes) {
    attributes->Register(name, this);
  }

  // Accepts any |unused_value| and always succeeds.
  tensorflow::Status Parse(const string &unused_value) override {
    return tensorflow::Status::OK();
  }
};
// An attribute with a fallback: if the attribute never appears in the input
// mapping, it retains its default value.
template <class T>
class Attributes::Optional : public Attribute {
 public:
  // Adds this attribute to the |attributes| container under |name|, with the
  // |default_value| used when no explicit value is provided.
  Optional(const string &name, const T &default_value, Attributes *attributes)
      : parsed_value_(default_value) {
    attributes->Register(name, this);
  }

  // Overwrites the stored value with the parsed |value| string.
  tensorflow::Status Parse(const string &value) override {
    return ParseValue(value, &parsed_value_);
  }

  // Accessor for the current value. Overloading operator() lets a struct
  // member named "foo" be read as "attributes.foo()".
  const T &operator()() const { return parsed_value_; }

 private:
  // Holds the default until Parse() replaces it with an explicit value.
  T parsed_value_;
};
// Implements a mandatory attribute. Attributes::Reset() fails with an error
// if a mandatory attribute is absent from the input mapping, so the default
// value passed to the Optional base is never observed by users.
//
// Fix: removed the redundant `T value_;` member. The parsed value lives in
// the (private) Optional<T> base; the duplicate member here was never read or
// written and only wasted storage and invited confusion.
template <class T>
class Attributes::Mandatory : public Optional<T> {
 public:
  // Registers this in the |attributes| with the |name|.
  Mandatory(const string &name, Attributes *attributes)
      : Optional<T>(name, T(), attributes) {}

  // Returns true since this is mandatory.
  bool IsMandatory() const override { return true; }
};
// Parses |str| as a comma-delimited list, converting each piece with the
// element-level ParseValue() overload. An empty string yields an empty list.
template <class Element>
tensorflow::Status Attributes::ParseValue(const string &str,
                                          std::vector<Element> *value) {
  value->clear();
  if (str.empty()) return tensorflow::Status::OK();
  for (const string &piece : tensorflow::str_util::Split(str, ",")) {
    // Append a default-constructed element, then parse in place.
    value->emplace_back();
    TF_RETURN_IF_ERROR(ParseValue(piece, &value->back()));
  }
  return tensorflow::Status::OK();
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_ATTRIBUTES_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/attributes.h"
#include <map>
#include <set>
#include <string>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Converts the |std_map| into the protobuf-based attribute mapping type.
Attributes::Mapping MakeMapping(const std::map<string, string> &std_map) {
  Attributes::Mapping mapping;
  for (const auto &name_and_value : std_map) {
    mapping[name_and_value.first] = name_and_value.second;
  }
  return mapping;
}
// Returns a mapping that explicitly sets every attribute used in these tests.
Attributes::Mapping GetFullySpecifiedMapping() {
  const std::map<string, string> values = {{"some_string", "explicit"},
                                           {"some_bool", "true"},
                                           {"some_int32", "987"},
                                           {"some_int64", "654321"},
                                           {"some_size_t", "7777777"},
                                           {"some_float", "0.25"},
                                           {"some_intvec", "2,3,5,7,11,13"},
                                           {"some_strvec", "a,bc,def"}};
  return MakeMapping(values);
}
// A set of optional attributes covering every supported value type, used to
// exercise both default values and explicit overrides.
struct OptionalAttributes : public Attributes {
  Optional<string> some_string{"some_string", "default", this};
  Optional<bool> some_bool{"some_bool", false, this};
  Optional<int32> some_int32{"some_int32", 32, this};
  Optional<int64> some_int64{"some_int64", 64, this};
  Optional<size_t> some_size_t{"some_size_t", 999, this};
  Optional<float> some_float{"some_float", -1.5, this};
  Optional<std::vector<int32>> some_intvec{"some_intvec", {}, this};
  Optional<std::vector<string>> some_strvec{"some_strvec", {"x", "y"}, this};
};
// Tests that attributes take their default values when they are not explicitly
// specified in the input mapping.
TEST(OptionalAttributesTest, Defaulted) {
  Attributes::Mapping mapping;  // empty: nothing is explicitly specified
  OptionalAttributes attributes;
  TF_ASSERT_OK(attributes.Reset(mapping));
  EXPECT_EQ(attributes.some_string(), "default");
  EXPECT_FALSE(attributes.some_bool());
  EXPECT_EQ(attributes.some_int32(), 32);
  EXPECT_EQ(attributes.some_int64(), 64);
  EXPECT_EQ(attributes.some_size_t(), 999);
  EXPECT_EQ(attributes.some_float(), -1.5);
  EXPECT_EQ(attributes.some_intvec(), std::vector<int32>());
  EXPECT_EQ(attributes.some_strvec(), std::vector<string>({"x", "y"}));
}
// Tests that attributes can be overridden to explicitly-specified values,
// replacing every default from OptionalAttributes.
TEST(OptionalAttributesTest, FullySpecified) {
  OptionalAttributes attributes;
  TF_ASSERT_OK(attributes.Reset(GetFullySpecifiedMapping()));
  EXPECT_EQ(attributes.some_string(), "explicit");
  EXPECT_TRUE(attributes.some_bool());
  EXPECT_EQ(attributes.some_int32(), 987);
  EXPECT_EQ(attributes.some_int64(), 654321);
  EXPECT_EQ(attributes.some_size_t(), 7777777);
  EXPECT_EQ(attributes.some_float(), 0.25);
  EXPECT_EQ(attributes.some_intvec(), std::vector<int32>({2, 3, 5, 7, 11, 13}));
  EXPECT_EQ(attributes.some_strvec(), std::vector<string>({"a", "bc", "def"}));
}
// Tests that attribute parsing fails for an unknown name: Reset() rejects
// mapping entries that no registered attribute claims.
TEST(OptionalAttributesTest, UnknownName) {
  const Attributes::Mapping mapping = MakeMapping({{"unknown", "##BAD##"}});
  OptionalAttributes attributes;
  EXPECT_THAT(attributes.Reset(mapping),
              test::IsErrorWithSubstr("Unknown attribute"));
}
// Tests that attribute parsing fails for malformed bool values, including
// whitespace-padded, numeric, and single-letter spellings.
TEST(OptionalAttributesTest, BadBool) {
  const std::vector<string> bad_values = {
      " true", "true ", "tr ue", "arst", "1", "t", "y", "yes", " false",
      "false ", "fa lse", "oien", "0", "f", "n", "no"};
  for (const string &bad_value : bad_values) {
    OptionalAttributes attributes;
    EXPECT_THAT(attributes.Reset(MakeMapping({{"some_bool", bad_value}})),
                test::IsErrorWithSubstr("Attribute can't be parsed as bool"));
  }
}
// Tests that attribute parsing works for well-formed bool values; boolean
// parsing accepts any capitalization of "true" and "false".
TEST(OptionalAttributesTest, GoodBool) {
  for (const string &true_value : {"true", "TRUE", "True", "tRuE"}) {
    OptionalAttributes attributes;
    TF_ASSERT_OK(attributes.Reset(MakeMapping({{"some_bool", true_value}})));
    EXPECT_TRUE(attributes.some_bool());
  }
  for (const string &false_value : {"false", "FALSE", "False", "fAlSe"}) {
    OptionalAttributes attributes;
    TF_ASSERT_OK(attributes.Reset(MakeMapping({{"some_bool", false_value}})));
    EXPECT_FALSE(attributes.some_bool());
  }
}
// Tests that attribute parsing fails for malformed int32 values.
TEST(OptionalAttributesTest, BadInt32) {
  const std::vector<string> bad_values = {"hello", "true", "1.0", "inf", "nan"};
  for (const string &bad_value : bad_values) {
    OptionalAttributes attributes;
    EXPECT_THAT(attributes.Reset(MakeMapping({{"some_int32", bad_value}})),
                test::IsErrorWithSubstr("Attribute can't be parsed as int32"));
  }
}
// Tests that attribute parsing fails for malformed int64 values.
TEST(OptionalAttributesTest, BadInt64) {
  const std::vector<string> bad_values = {"hello", "true", "1.0", "inf", "nan"};
  for (const string &bad_value : bad_values) {
    OptionalAttributes attributes;
    EXPECT_THAT(attributes.Reset(MakeMapping({{"some_int64", bad_value}})),
                test::IsErrorWithSubstr("Attribute can't be parsed as int64"));
  }
}
// Tests that attribute parsing fails for malformed size_t values; negative
// numbers are rejected because size_t is unsigned.
TEST(OptionalAttributesTest, BadSizeT) {
  const std::vector<string> bad_values = {"hello", "true", "1.0", "inf",
                                          "nan",   "-1.0", "-123"};
  for (const string &bad_value : bad_values) {
    OptionalAttributes attributes;
    EXPECT_THAT(attributes.Reset(MakeMapping({{"some_size_t", bad_value}})),
                test::IsErrorWithSubstr("Attribute can't be parsed as size_t"));
  }
}
// Tests that attribute parsing fails for malformed floats.
TEST(OptionalAttributesTest, BadFloat) {
  const std::vector<string> bad_values = {"hello", "true"};
  for (const string &bad_value : bad_values) {
    OptionalAttributes attributes;
    EXPECT_THAT(attributes.Reset(MakeMapping({{"some_float", bad_value}})),
                test::IsErrorWithSubstr("Attribute can't be parsed as float"));
  }
}
// Tests that attribute parsing fails for malformed std::vector<int32> values;
// the error surfaces from the per-element int32 parser.
TEST(OptionalAttributesTest, BadIntVector) {
  const std::vector<string> bad_values = {"hello", "true",       "1.0",
                                          "inf",   "nan",        "true,false",
                                          "foo,bar,baz"};
  for (const string &bad_value : bad_values) {
    OptionalAttributes attributes;
    EXPECT_THAT(attributes.Reset(MakeMapping({{"some_intvec", bad_value}})),
                test::IsErrorWithSubstr("Attribute can't be parsed as int32"));
  }
}
// A set of mandatory attributes covering every supported value type; Reset()
// must fail if any of these is missing from the input mapping.
struct MandatoryAttributes : public Attributes {
  Mandatory<string> some_string{"some_string", this};
  Mandatory<bool> some_bool{"some_bool", this};
  Mandatory<int32> some_int32{"some_int32", this};
  Mandatory<int64> some_int64{"some_int64", this};
  Mandatory<size_t> some_size_t{"some_size_t", this};
  Mandatory<float> some_float{"some_float", this};
  Mandatory<std::vector<int32>> some_intvec{"some_intvec", this};
  Mandatory<std::vector<string>> some_strvec{"some_strvec", this};
};
// Tests that attribute parsing works when all mandatory attributes are
// explicitly specified; expected values mirror GetFullySpecifiedMapping().
TEST(MandatoryAttributesTest, FullySpecified) {
  MandatoryAttributes attributes;
  TF_ASSERT_OK(attributes.Reset(GetFullySpecifiedMapping()));
  EXPECT_EQ(attributes.some_string(), "explicit");
  EXPECT_TRUE(attributes.some_bool());
  EXPECT_EQ(attributes.some_int32(), 987);
  EXPECT_EQ(attributes.some_int64(), 654321);
  EXPECT_EQ(attributes.some_size_t(), 7777777);
  EXPECT_EQ(attributes.some_float(), 0.25);
  EXPECT_EQ(attributes.some_intvec(), std::vector<int32>({2, 3, 5, 7, 11, 13}));
  EXPECT_EQ(attributes.some_strvec(), std::vector<string>({"a", "bc", "def"}));
}
// Tests that attribute parsing fails when even one mandatory attribute is not
// explicitly specified, for every choice of omitted attribute.
TEST(MandatoryAttributesTest, MissingAttribute) {
  const Attributes::Mapping full_mapping = GetFullySpecifiedMapping();
  for (const auto &name_and_value : full_mapping) {
    // Remove a single attribute from an otherwise complete mapping.
    Attributes::Mapping partial_mapping = full_mapping;
    CHECK_EQ(partial_mapping.erase(name_and_value.first), 1);
    MandatoryAttributes attributes;
    EXPECT_THAT(attributes.Reset(partial_mapping),
                test::IsErrorWithSubstr("Missing mandatory attributes"));
  }
}
// A set of ignored attributes: names are registered but values are discarded.
struct IgnoredAttributes : public Attributes {
  Ignored foo{"foo", this};
  Ignored bar{"bar", this};
  Ignored baz{"baz", this};
};
// Tests that ignored attributes are not mandatory: an empty mapping parses.
TEST(IgnoredAttributesTest, NotMandatory) {
  const Attributes::Mapping mapping;
  IgnoredAttributes attributes;
  TF_ASSERT_OK(attributes.Reset(mapping));
}
// Tests that attribute parsing consumes ignored names without attempting to
// parse their values.
TEST(IgnoredAttributesTest, IgnoredName) {
  const Attributes::Mapping mapping =
      MakeMapping({{"foo", "blah"}, {"bar", "123"}, {"baz", " "}});
  IgnoredAttributes attributes;
  TF_ASSERT_OK(attributes.Reset(mapping));
}
// Tests that attribute parsing still fails for unknown names even when other
// names in the mapping are ignored.
TEST(IgnoredAttributesTest, UnknownName) {
  const Attributes::Mapping mapping = MakeMapping(
      {{"foo", "blah"}, {"bar", "123"}, {"baz", " "}, {"unknown", ""}});
  IgnoredAttributes attributes;
  EXPECT_THAT(attributes.Reset(mapping),
              test::IsErrorWithSubstr("Unknown attribute"));
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include "dragnn/core/compute_session.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/math/eigen.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/network_unit.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Produces pairwise activations via a biaffine product between source and
// target token activations, as in the Dozat parser. This is the runtime
// version of the BiaffineDigraphNetwork, but is implemented as a Component
// instead of a NetworkUnit so it can control operand allocation.
class BiaffineDigraphComponent : public Component {
 public:
  // Implements Component. Initialize() loads and cross-checks the variables
  // below and resolves the "sources"/"targets" links; Evaluate() fills the
  // pairwise "adjacency" output layer.
  tensorflow::Status Initialize(const ComponentSpec &component_spec,
                                VariableStore *variable_store,
                                NetworkStateManager *network_state_manager,
                                ExtensionManager *extension_manager) override;
  tensorflow::Status Evaluate(SessionState *session_state,
                              ComputeSession *compute_session,
                              ComponentTrace *component_trace) const override;
  bool Supports(const ComponentSpec &component_spec,
                const string &normalized_builder_name) const override;
  // Never preferred over another supporting component.
  bool PreferredTo(const Component &other) const override { return false; }

 private:
  // Weights for computing source-target arc potentials.
  Matrix<float> arc_weights_;
  // Weights for computing source-selection potentials.
  Vector<float> source_weights_;
  // Weights and bias for root-target arc potentials.
  Vector<float> root_weights_;
  float root_bias_ = 0.0;
  // Source and target token activation inputs, linked from other components.
  LayerHandle<float> sources_handle_;
  LayerHandle<float> targets_handle_;
  // Directed adjacency matrix output.
  PairwiseLayerHandle<float> adjacency_handle_;
  // Handles for intermediate computations (see Evaluate()).
  LocalMatrixHandle<float> target_product_handle_;
};
bool BiaffineDigraphComponent::Supports(
const ComponentSpec &component_spec,
const string &normalized_builder_name) const {
const string network_unit = NetworkUnit::GetClassName(component_spec);
return (normalized_builder_name == "BulkFeatureExtractorComponent" ||
normalized_builder_name == "BiaffineDigraphComponent") &&
network_unit == "BiaffineDigraphNetwork";
}
// Finds the link named |name| in the |component_spec| and points the |handle|
// at the corresponding layer in the |network_state_manager|. The layer must
// also match the |required_dimension|. Returns non-OK on error.
tensorflow::Status FindAndValidateLink(
    const ComponentSpec &component_spec,
    const NetworkStateManager &network_state_manager, const string &name,
    size_t required_dimension, LayerHandle<float> *handle) {
  // Locate the channel with the requested name.
  const LinkedFeatureChannel *found = nullptr;
  for (const LinkedFeatureChannel &candidate :
       component_spec.linked_feature()) {
    if (candidate.name() == name) {
      found = &candidate;
      break;
    }
  }
  if (found == nullptr) {
    return tensorflow::errors::InvalidArgument(
        component_spec.name(), ": link '", name, "' does not exist");
  }

  // Appended to every error below to aid debugging.
  const string error_suffix = tensorflow::strings::StrCat(
      " in link { ", found->ShortDebugString(), " }");

  // Only untransformed, single-embedding, non-recurrent links with trivial
  // translation are supported.
  if (found->embedding_dim() != -1) {
    return tensorflow::errors::InvalidArgument(
        component_spec.name(), ": transformed links are forbidden",
        error_suffix);
  }
  if (found->size() != 1) {
    return tensorflow::errors::InvalidArgument(
        component_spec.name(), ": multi-embedding links are forbidden",
        error_suffix);
  }
  if (found->source_component() == component_spec.name()) {
    return tensorflow::errors::InvalidArgument(
        component_spec.name(), ": recurrent links are forbidden", error_suffix);
  }
  if (found->fml() != "input.focus" ||
      found->source_translator() != "identity") {
    return tensorflow::errors::InvalidArgument(
        component_spec.name(), ": non-trivial link translation is forbidden",
        error_suffix);
  }

  // Resolve the linked layer and check its dimension.
  size_t dimension = 0;
  TF_RETURN_IF_ERROR(network_state_manager.LookupLayer(
      found->source_component(), found->source_layer(), &dimension, handle));
  if (dimension != required_dimension) {
    return tensorflow::errors::InvalidArgument(
        component_spec.name(), ": link '", name, "' has dimension ", dimension,
        " instead of ", required_dimension, error_suffix);
  }
  return tensorflow::Status::OK();
}
// Loads the component's variables, validates their dimensions against each
// other, resolves the "sources"/"targets" links, and registers the pairwise
// output layer plus a scratch matrix. Returns non-OK on any mismatch.
tensorflow::Status BiaffineDigraphComponent::Initialize(
    const ComponentSpec &component_spec, VariableStore *variable_store,
    NetworkStateManager *network_state_manager,
    ExtensionManager *extension_manager) {
  // The arc weight matrix determines the required source (rows) and target
  // (columns) activation dimensions; later lookups are checked against these.
  TF_RETURN_IF_ERROR(variable_store->Lookup(
      tensorflow::strings::StrCat(component_spec.name(), "/weights_arc"),
      &arc_weights_));
  const size_t source_dimension = arc_weights_.num_rows();
  const size_t target_dimension = arc_weights_.num_columns();
  // Source-selection weights must match the source dimension.
  TF_RETURN_IF_ERROR(variable_store->Lookup(
      tensorflow::strings::StrCat(component_spec.name(), "/weights_source"),
      &source_weights_));
  if (source_weights_.size() != source_dimension) {
    return tensorflow::errors::InvalidArgument(
        component_spec.name(), ": dimension mismatch between weights_arc [",
        source_dimension, ",", target_dimension, "] and weights_source [",
        source_weights_.size(), "]");
  }
  // Root-selection weights must match the target dimension.
  TF_RETURN_IF_ERROR(variable_store->Lookup(
      tensorflow::strings::StrCat(component_spec.name(), "/root_weights"),
      &root_weights_));
  if (root_weights_.size() != target_dimension) {
    return tensorflow::errors::InvalidArgument(
        component_spec.name(), ": dimension mismatch between weights_arc [",
        source_dimension, ",", target_dimension, "] and root_weights [",
        root_weights_.size(), "]");
  }
  // The root bias is stored as a singleton vector; unpack it to a scalar.
  Vector<float> root_bias;
  TF_RETURN_IF_ERROR(variable_store->Lookup(
      tensorflow::strings::StrCat(component_spec.name(), "/root_bias"),
      &root_bias));
  if (root_bias.size() != 1) {
    return tensorflow::errors::InvalidArgument(
        component_spec.name(), ": root_bias must be a singleton");
  }
  root_bias_ = root_bias[0];
  // This component consumes exactly two links ("sources" and "targets") and
  // no fixed features.
  if (component_spec.fixed_feature_size() != 0) {
    return tensorflow::errors::InvalidArgument(
        component_spec.name(), ": fixed features are forbidden");
  }
  if (component_spec.linked_feature_size() != 2) {
    return tensorflow::errors::InvalidArgument(
        component_spec.name(), ": two linked features are required");
  }
  TF_RETURN_IF_ERROR(FindAndValidateLink(component_spec, *network_state_manager,
                                         "sources", source_dimension,
                                         &sources_handle_));
  TF_RETURN_IF_ERROR(FindAndValidateLink(component_spec, *network_state_manager,
                                         "targets", target_dimension,
                                         &targets_handle_));
  // Outputs: the pairwise adjacency matrix, plus a local scratch matrix that
  // holds an intermediate product in Evaluate().
  TF_RETURN_IF_ERROR(
      network_state_manager->AddLayer("adjacency", 1, &adjacency_handle_));
  TF_RETURN_IF_ERROR(network_state_manager->AddLocal(source_dimension,
                                                     &target_product_handle_));
  return tensorflow::Status::OK();
}
// Computes the pairwise adjacency matrix for one batch: off-diagonal entries
// combine biaffine arc scores with source-selection scores, and the diagonal
// is overwritten with root-selection scores.
tensorflow::Status BiaffineDigraphComponent::Evaluate(
    SessionState *session_state, ComputeSession *compute_session,
    ComponentTrace *component_trace) const {
  NetworkStates &network_states = session_state->network_states;
  // Infer the number of steps from the source and target activations.
  EigenMatrixMap<float> sources =
      AsEigenMap(Matrix<float>(network_states.GetLayer(sources_handle_)));
  EigenMatrixMap<float> targets =
      AsEigenMap(Matrix<float>(network_states.GetLayer(targets_handle_)));
  const size_t num_steps = sources.rows();
  if (targets.rows() != num_steps) {
    return tensorflow::errors::InvalidArgument(
        "step count mismatch between sources (", num_steps, ") and targets (",
        targets.rows(), ")");
  }
  // Since this component has a pairwise layer, allocate steps in one shot.
  network_states.AddSteps(num_steps);
  MutableEigenMatrixMap<float> adjacency =
      AsEigenMap(network_states.GetLayer(adjacency_handle_));
  MutableEigenMatrixMap<float> target_product =
      AsEigenMap(network_states.GetLocal(target_product_handle_));
  // First compute the adjacency matrix of combined arc and source scores:
  // target_product = targets * W^T + source_weights (broadcast per row), then
  // adjacency = target_product * sources^T.
  // NOTE(review): rows appear to index target steps and columns source steps
  // in |adjacency| — confirm against consumers of the "adjacency" layer.
  // Note: .noalias() ensures that the RHS is assigned directly to the LHS;
  // otherwise, Eigen may allocate a temp matrix to hold the result of the
  // matmul on the RHS and then copy that to the LHS. See
  // http://eigen.tuxfamily.org/dox/TopicLazyEvaluation.html
  target_product.noalias() = targets * AsEigenMap(arc_weights_).transpose();
  target_product.rowwise() += AsEigenMap(source_weights_);
  adjacency.noalias() = target_product * sources.transpose();
  // Now overwrite the diagonal with root-selection scores.
  // Note: .array() allows the scalar addition of |root_bias_| to broadcast
  // across the diagonal. See
  // https://eigen.tuxfamily.org/dox/group__TutorialArrayClass.html
  adjacency.diagonal().noalias() =
      AsEigenMap(root_weights_) * targets.transpose();
  adjacency.diagonal().array() += root_bias_;
  return tensorflow::Status::OK();
}
// Registers this component with the runtime's component registry so it can be
// selected by name via Component::CreateOrError()/Select().
DRAGNN_RUNTIME_REGISTER_COMPONENT(BiaffineDigraphComponent);
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <memory>
#include <string>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "syntaxnet/base.h"
#include <gmock/gmock.h>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
using ::testing::Return;
constexpr size_t kNumSteps = 33;
constexpr size_t kSourceDim = 44;
constexpr size_t kTargetDim = 55;
constexpr size_t kBadDim = 11;
constexpr float kArcWeight = 1.0;
constexpr float kSourceWeight = 2.0;
constexpr float kRootWeight = 4.0;
constexpr float kRootBias = 8.0;
constexpr float kSourceValue = -0.5;
constexpr float kTargetValue = 1.5;
constexpr char kSourcesComponentName[] = "sources";
constexpr char kTargetsComponentName[] = "targets";
constexpr char kSourcesLayerName[] = "sources";
constexpr char kTargetsLayerName[] = "targets";
constexpr char kBadDimLayerName[] = "bad";
// Configuration for the Run() method. This makes it easier for tests to
// manipulate breakages; defaults describe a fully-consistent setup.
struct RunConfig {
  // Number of steps in the preceding components.
  size_t sources_num_steps = kNumSteps;
  size_t targets_num_steps = kNumSteps;
  // Dimensions of the variables; mismatches should make Initialize() fail.
  size_t weights_source_dim = kSourceDim;
  size_t root_weights_dim = kTargetDim;
  size_t root_bias_dim = 1;
};
// Test fixture that wires up mock inputs, variables, and preceding components
// for BiaffineDigraphComponent, then runs Initialize() and Evaluate().
class BiaffineDigraphComponentTest : public NetworkTestBase {
 protected:
  BiaffineDigraphComponentTest() {
    // All batch-cache requests return the fixture-owned |input_|.
    EXPECT_CALL(compute_session_, GetInputBatchCache())
        .WillRepeatedly(Return(&input_));
  }

  // Returns a working spec: supported builder and network unit, plus trivial
  // "sources" and "targets" links.
  static ComponentSpec MakeGoodSpec() {
    ComponentSpec component_spec;
    component_spec.set_name(kTestComponentName);
    component_spec.mutable_component_builder()->set_registered_name(
        "bulk_component.BulkFeatureExtractorComponentBuilder");
    component_spec.mutable_network_unit()->set_registered_name(
        "biaffine_units.BiaffineDigraphNetwork");
    for (const string &name : {kSourcesLayerName, kTargetsLayerName}) {
      LinkedFeatureChannel *link = component_spec.add_linked_feature();
      link->set_name(name);
      link->set_embedding_dim(-1);
      link->set_size(1);
      link->set_source_component(name);
      link->set_source_layer(name);
      link->set_source_translator("identity");
      link->set_fml("input.focus");
    }
    return component_spec;
  }

  // Creates a component, initializes it based on the |component_spec|, and
  // evaluates it. On error, returns non-OK. On success, stores the resulting
  // pairwise output in |adjacency_|.
  tensorflow::Status Run(const ComponentSpec &component_spec,
                         const RunConfig &config = RunConfig()) {
    // Preceding components provide the linked source/target activations.
    AddComponent(kSourcesComponentName);
    AddLayer(kSourcesLayerName, kSourceDim);
    AddComponent(kTargetsComponentName);
    AddLayer(kTargetsLayerName, kTargetDim);
    AddLayer(kBadDimLayerName, kBadDim);
    AddComponent(kTestComponentName);
    // Variables are filled with constants so expected scores are computable.
    AddMatrixVariable(
        tensorflow::strings::StrCat(kTestComponentName, "/weights_arc"),
        kSourceDim, kTargetDim, kArcWeight);
    AddVectorVariable(
        tensorflow::strings::StrCat(kTestComponentName, "/weights_source"),
        config.weights_source_dim, kSourceWeight);
    AddVectorVariable(
        tensorflow::strings::StrCat(kTestComponentName, "/root_weights"),
        config.root_weights_dim, kRootWeight);
    AddVectorVariable(
        tensorflow::strings::StrCat(kTestComponentName, "/root_bias"),
        config.root_bias_dim, kRootBias);
    TF_RETURN_IF_ERROR(
        Component::CreateOrError("BiaffineDigraphComponent", &component_));
    TF_RETURN_IF_ERROR(component_->Initialize(component_spec, &variable_store_,
                                              &network_state_manager_,
                                              &extension_manager_));
    network_states_.Reset(&network_state_manager_);
    StartComponent(config.sources_num_steps);
    FillLayer(kSourcesComponentName, kSourcesLayerName, kSourceValue);
    StartComponent(config.targets_num_steps);
    FillLayer(kTargetsComponentName, kTargetsLayerName, kTargetValue);
    StartComponent(0);  // BiaffineDigraphComponent will add steps
    session_state_.extensions.Reset(&extension_manager_);
    TF_RETURN_IF_ERROR(
        component_->Evaluate(&session_state_, &compute_session_, nullptr));
    adjacency_ = GetPairwiseLayer(kTestComponentName, "adjacency");
    return tensorflow::Status::OK();
  }

  InputBatchCache input_;
  std::unique_ptr<Component> component_;
  // Pairwise output of the most recent successful Run().
  Matrix<float> adjacency_;
};
// Tests that the good spec works properly: every diagonal entry of the
// adjacency matrix holds the root-selection score and every off-diagonal
// entry holds the arc score, computed by hand from the constant fills.
TEST_F(BiaffineDigraphComponentTest, GoodSpec) {
  TF_ASSERT_OK(Run(MakeGoodSpec()));
  // root_weights . targets + root_bias, with constant-valued activations.
  constexpr float kExpectedRootScore =
      kRootWeight * kTargetValue * kTargetDim + kRootBias;
  // targets^T W sources + source_weights . sources, likewise constant-valued.
  constexpr float kExpectedArcScore =
      kSourceDim * kSourceValue * kArcWeight * kTargetValue * kTargetDim +
      kSourceWeight * kSourceValue * kSourceDim;
  ASSERT_EQ(adjacency_.num_rows(), kNumSteps);
  ASSERT_EQ(adjacency_.num_columns(), kNumSteps);
  for (size_t row = 0; row < kNumSteps; ++row) {
    for (size_t column = 0; column < kNumSteps; ++column) {
      if (row == column) {
        ASSERT_EQ(adjacency_.row(row)[column], kExpectedRootScore);
      } else {
        ASSERT_EQ(adjacency_.row(row)[column], kExpectedArcScore);
      }
    }
  }
}
// Tests the set of supported components: both the bulk-extractor builder and
// the component's own name are accepted, but only with the biaffine network
// unit, and unknown builders or network units are rejected.
TEST_F(BiaffineDigraphComponentTest, Supports) {
  ComponentSpec component_spec = MakeGoodSpec();
  string component_name;
  TF_ASSERT_OK(Component::Select(component_spec, &component_name));
  EXPECT_EQ(component_name, "BiaffineDigraphComponent");
  component_spec.mutable_network_unit()->set_registered_name("bad");
  EXPECT_THAT(Component::Select(component_spec, &component_name),
              test::IsErrorWithSubstr("Could not find a best spec"));
  component_spec = MakeGoodSpec();
  component_spec.mutable_component_builder()->set_registered_name(
      "BiaffineDigraphComponent");
  TF_ASSERT_OK(Component::Select(component_spec, &component_name));
  EXPECT_EQ(component_name, "BiaffineDigraphComponent");
  component_spec.mutable_component_builder()->set_registered_name("bad");
  EXPECT_THAT(Component::Select(component_spec, &component_name),
              test::IsErrorWithSubstr("Could not find a best spec"));
}
// Tests that fixed features are rejected by Initialize().
TEST_F(BiaffineDigraphComponentTest, FixedFeatures) {
  ComponentSpec component_spec = MakeGoodSpec();
  component_spec.add_fixed_feature();
  EXPECT_THAT(Run(component_spec),
              test::IsErrorWithSubstr("fixed features are forbidden"));
}
// Tests that too few linked features are rejected: exactly two are required.
TEST_F(BiaffineDigraphComponentTest, TooFewLinkedFeatures) {
  ComponentSpec component_spec = MakeGoodSpec();
  component_spec.mutable_linked_feature()->RemoveLast();
  EXPECT_THAT(Run(component_spec),
              test::IsErrorWithSubstr("two linked features are required"));
}
// Verifies that adding a third link causes rejection.
TEST_F(BiaffineDigraphComponentTest, TooManyLinkedFeatures) {
  ComponentSpec spec = MakeGoodSpec();
  spec.add_linked_feature();
  EXPECT_THAT(Run(spec),
              test::IsErrorWithSubstr("two linked features are required"));
}
// Tests that a spec with no "sources" link is rejected.
TEST_F(BiaffineDigraphComponentTest, MissingSources) {
ComponentSpec component_spec = MakeGoodSpec();
component_spec.mutable_linked_feature(0)->set_name("bad");
EXPECT_THAT(Run(component_spec),
test::IsErrorWithSubstr("link 'sources' does not exist"));
}
// Tests that a spec with no "targets" link is rejected.
TEST_F(BiaffineDigraphComponentTest, MissingTargets) {
ComponentSpec component_spec = MakeGoodSpec();
component_spec.mutable_linked_feature(1)->set_name("bad");
EXPECT_THAT(Run(component_spec),
test::IsErrorWithSubstr("link 'targets' does not exist"));
}
// Verifies that a non-negative embedding_dim (a transformed link) is rejected.
TEST_F(BiaffineDigraphComponentTest, TransformedLink) {
  ComponentSpec spec = MakeGoodSpec();
  spec.mutable_linked_feature(1)->set_embedding_dim(123);
  EXPECT_THAT(Run(spec),
              test::IsErrorWithSubstr("transformed links are forbidden"));
}
// Verifies that a link of size > 1 (multi-embedding) is rejected.
TEST_F(BiaffineDigraphComponentTest, MultiEmbeddingLink) {
  ComponentSpec spec = MakeGoodSpec();
  spec.mutable_linked_feature(1)->set_size(2);
  EXPECT_THAT(Run(spec),
              test::IsErrorWithSubstr("multi-embedding links are forbidden"));
}
// Verifies that a link into this same component (recurrence) is rejected.
TEST_F(BiaffineDigraphComponentTest, RecurrentLink) {
  ComponentSpec spec = MakeGoodSpec();
  spec.mutable_linked_feature(1)->set_source_component(kTestComponentName);
  EXPECT_THAT(Run(spec),
              test::IsErrorWithSubstr("recurrent links are forbidden"));
}
// Verifies that a link whose FML requires non-trivial translation is rejected.
TEST_F(BiaffineDigraphComponentTest, BadFML) {
  ComponentSpec spec = MakeGoodSpec();
  spec.mutable_linked_feature(1)->set_fml("bad");
  EXPECT_THAT(
      Run(spec),
      test::IsErrorWithSubstr("non-trivial link translation is forbidden"));
}
// Verifies that a non-identity source translator is rejected.
TEST_F(BiaffineDigraphComponentTest, NonIdentityLink) {
  ComponentSpec spec = MakeGoodSpec();
  spec.mutable_linked_feature(1)->set_source_translator("bad");
  EXPECT_THAT(
      Run(spec),
      test::IsErrorWithSubstr("non-trivial link translation is forbidden"));
}
// Verifies that pointing "targets" at a layer of the wrong dimension fails.
TEST_F(BiaffineDigraphComponentTest, WrongLinkDimension) {
  ComponentSpec spec = MakeGoodSpec();
  spec.mutable_linked_feature(1)->set_source_layer(kBadDimLayerName);
  EXPECT_THAT(
      Run(spec),
      test::IsErrorWithSubstr("link 'targets' has dimension 11 instead of 55"));
}
// Verifies that an inconsistent weights_source dimension is rejected.
TEST_F(BiaffineDigraphComponentTest, WeightsSourceDimensionMismatch) {
  RunConfig run_config;
  run_config.weights_source_dim = 999;
  EXPECT_THAT(Run(MakeGoodSpec(), run_config),
              test::IsErrorWithSubstr("dimension mismatch between weights_arc "
                                      "[44,55] and weights_source [999]"));
}
// Verifies that an inconsistent root_weights dimension is rejected.
TEST_F(BiaffineDigraphComponentTest, RootWeightsDimensionMismatch) {
  RunConfig run_config;
  run_config.root_weights_dim = 999;
  EXPECT_THAT(Run(MakeGoodSpec(), run_config),
              test::IsErrorWithSubstr("dimension mismatch between weights_arc "
                                      "[44,55] and root_weights [999]"));
}
// Verifies that a non-singleton root_bias is rejected.
TEST_F(BiaffineDigraphComponentTest, RootBiasDimensionMismatch) {
  RunConfig run_config;
  run_config.root_bias_dim = 999;
  EXPECT_THAT(Run(MakeGoodSpec(), run_config),
              test::IsErrorWithSubstr("root_bias must be a singleton"));
}
// Verifies that differing step counts between the two links are rejected.
TEST_F(BiaffineDigraphComponentTest, StepCountMismatch) {
  RunConfig run_config;
  run_config.targets_num_steps = 999;
  EXPECT_THAT(
      Run(MakeGoodSpec(), run_config),
      test::IsErrorWithSubstr(
          "step count mismatch between sources (33) and targets (999)"));
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <memory>
#include <string>
#include "dragnn/protos/spec.pb.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/bulk_network_unit.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/network_unit_base.h"
#include "dragnn/runtime/transition_system_traits.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Network unit that allows us to make calls to NetworkUnitBase and extract
// features.  We may want to provide more optimized versions of this class.
//
// This unit produces no logits of its own; it writes the concatenated fixed
// and linked feature embeddings of each step into one row of a local matrix,
// which a bulk network unit then consumes (see EvaluateInputs()).
class BulkFeatureExtractorNetwork : public NetworkUnitBase {
 public:
  // Returns true if this supports the |component_spec|. Requires:
  // * A deterministic transition system, which can be advanced from the oracle.
  // * No recurrent linked features (i.e. from this system).
  static bool Supports(const ComponentSpec &component_spec);

  // Implements NetworkUnit.
  tensorflow::Status Initialize(const ComponentSpec &component_spec,
                                VariableStore *variable_store,
                                NetworkStateManager *network_state_manager,
                                ExtensionManager *extension_manager) override;

  // Advances the |compute_session| through all oracle transitions and extracts
  // fixed and linked embeddings, concatenates them into an input matrix stored
  // in the NetworkStates in the |session_state|, and points the |inputs| at it.
  // Also adds steps to the NetworkStates. On error, returns non-OK.
  tensorflow::Status EvaluateInputs(SessionState *session_state,
                                    ComputeSession *compute_session,
                                    Matrix<float> *inputs) const;

 private:
  // Implements NetworkUnit.  This unit has no logits layer, so the logits
  // name is empty.  Evaluate() is "final" to encourage inlining.
  string GetLogitsName() const override { return ""; }
  tensorflow::Status Evaluate(size_t step_index, SessionState *session_state,
                              ComputeSession *compute_session) const final;

  // Name of the containing component.
  string name_;

  // Concatenated input matrix; one row per step.
  LocalMatrixHandle<float> inputs_handle_;
};
bool BulkFeatureExtractorNetwork::Supports(
    const ComponentSpec &component_spec) {
  // Only deterministic systems can be advanced from the oracle.
  if (!TransitionSystemTraits(component_spec).is_deterministic) return false;

  // Recurrent links (i.e., links into this same component) are forbidden.
  for (const LinkedFeatureChannel &link : component_spec.linked_feature()) {
    if (link.source_component() == component_spec.name()) return false;
  }
  return true;
}
tensorflow::Status BulkFeatureExtractorNetwork::Initialize(
    const ComponentSpec &component_spec, VariableStore *variable_store,
    NetworkStateManager *network_state_manager,
    ExtensionManager *extension_manager) {
  name_ = component_spec.name();
  if (!Supports(component_spec)) {
    return tensorflow::errors::InvalidArgument(
        "BulkFeatureExtractorNetwork does not support component '", name_, "'");
  }

  // Request the concatenated input vector from the base class, then allocate
  // a local matrix whose rows are one such vector per step.
  TF_RETURN_IF_ERROR(InitializeBase(/*use_concatenated_input=*/true,
                                    component_spec, variable_store,
                                    network_state_manager, extension_manager));
  return network_state_manager->AddLocal(concatenated_input_dim(),
                                         &inputs_handle_);
}
tensorflow::Status BulkFeatureExtractorNetwork::EvaluateInputs(
    SessionState *session_state, ComputeSession *compute_session,
    Matrix<float> *inputs) const {
  // TODO(googleuser): Try the ComputeSession's bulk feature extraction API?
  size_t step = 0;
  while (!compute_session->IsTerminal(name_)) {
    session_state->network_states.AddStep();
    TF_RETURN_IF_ERROR(Evaluate(step, session_state, compute_session));
    compute_session->AdvanceFromOracle(name_);
    ++step;
  }
  *inputs = session_state->network_states.GetLocal(inputs_handle_);
  return tensorflow::Status::OK();
}
tensorflow::Status BulkFeatureExtractorNetwork::Evaluate(
    size_t step_index, SessionState *session_state,
    ComputeSession *compute_session) const {
  // Extract the concatenated feature vector for the current step.
  Vector<float> input;
  TF_RETURN_IF_ERROR(EvaluateBase(session_state, compute_session, &input));
  MutableMatrix<float> all_inputs =
      session_state->network_states.GetLocal(inputs_handle_);
  // TODO(googleuser): Punch a hole in EvaluateBase so it writes directly to
  // all_inputs.row(step_index).
  //
  // In the future, we could entirely eliminate copying, by providing a variant
  // of LstmCellFunction::RunInputComputation that adds a partial vector of
  // inputs, e.g. instead of RunInputComputation(x), we compute
  //
  //   RunInputComputation(x[0:32]) + RunInputComputation(x[32:64])
  //
  // where perhaps x[0:32] points directly at a fixed word feature vector, and
  // x[32:64] points directly at the previous layer's outputs (as a linked
  // feature).
  MutableVector<float> output = all_inputs.row(step_index);
  DCHECK_EQ(input.size(), output.size());
  // Copy the features into this step's row of the bulk input matrix.  Use a
  // size_t index to avoid a signed/unsigned comparison with input.size().
  // TODO(googleuser): Try memcpy() or a custom vectorized copy.
  for (size_t i = 0; i < input.size(); ++i) {
    output[i] = input[i];
  }
  return tensorflow::Status::OK();
}
// Bulk version of a DynamicComponent---i.e., a component that was originally
// dynamic but can be automatically upgraded to a bulk version.
class BulkDynamicComponent : public Component {
 protected:
  // Implements Component.
  tensorflow::Status Initialize(const ComponentSpec &component_spec,
                                VariableStore *variable_store,
                                NetworkStateManager *network_state_manager,
                                ExtensionManager *extension_manager) override;
  tensorflow::Status Evaluate(SessionState *session_state,
                              ComputeSession *compute_session,
                              ComponentTrace *component_trace) const override;
  bool Supports(const ComponentSpec &component_spec,
                const string &normalized_builder_name) const override;

  // Always claims preference, so this bulk version wins over the dynamic one.
  bool PreferredTo(const Component &other) const override { return true; }

 private:
  // Feature extractor that builds the input activation matrix.
  BulkFeatureExtractorNetwork bulk_feature_extractor_;

  // Network unit for bulk computation.
  std::unique_ptr<BulkNetworkUnit> bulk_network_unit_;
};
// In addition to the BulkFeatureExtractorNetwork requirements, the bulk LSTM
// requires no attention (the runtime doesn't support attention yet).
bool BulkDynamicComponent::Supports(
    const ComponentSpec &component_spec,
    const string &normalized_builder_name) const {
  if (!BulkFeatureExtractorNetwork::Supports(component_spec)) return false;
  if (!component_spec.attention_component().empty()) return false;
  return normalized_builder_name == "DynamicComponent" ||
         normalized_builder_name == "BulkDynamicComponent";
}
tensorflow::Status BulkDynamicComponent::Initialize(
    const ComponentSpec &component_spec, VariableStore *variable_store,
    NetworkStateManager *network_state_manager,
    ExtensionManager *extension_manager) {
  // Initialize the feature extractor first, so its concatenated input
  // dimension is available for validation below.
  TF_RETURN_IF_ERROR(bulk_feature_extractor_.Initialize(
      component_spec, variable_store, network_state_manager,
      extension_manager));
  // Instantiate and initialize the bulk network unit named by the spec.
  TF_RETURN_IF_ERROR(BulkNetworkUnit::CreateOrError(
      BulkNetworkUnit::GetClassName(component_spec), &bulk_network_unit_));
  TF_RETURN_IF_ERROR(
      bulk_network_unit_->Initialize(component_spec, variable_store,
                                     network_state_manager, extension_manager));
  // Check that the network unit accepts the extractor's input dimension.
  return bulk_network_unit_->ValidateInputDimension(
      bulk_feature_extractor_.concatenated_input_dim());
}
tensorflow::Status BulkDynamicComponent::Evaluate(
    SessionState *session_state, ComputeSession *compute_session,
    ComponentTrace *component_trace) const {
  // Build the bulk input matrix from the oracle traversal, then run the
  // network unit over all steps at once.
  Matrix<float> input_matrix;
  TF_RETURN_IF_ERROR(bulk_feature_extractor_.EvaluateInputs(
      session_state, compute_session, &input_matrix));
  return bulk_network_unit_->Evaluate(input_matrix, session_state);
}

DRAGNN_RUNTIME_REGISTER_COMPONENT(BulkDynamicComponent);
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <memory>
#include <string>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/bulk_network_unit.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include <gmock/gmock.h>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
using ::testing::_;
using ::testing::Invoke;
using ::testing::Return;

// Number of steps taken by the mocked transition loop.
constexpr size_t kNumSteps = 50;

// Dimension, vocabulary size, and embedding fill value of the fixed feature.
constexpr size_t kFixedDim = 11;
constexpr size_t kFixedVocabularySize = 123;
constexpr float kFixedValue = 0.5;

// Dimension and fill value of the linked feature's source layer.
constexpr size_t kLinkedDim = 13;
constexpr float kLinkedValue = 1.25;

// Names of the preceding component and the layer that the link targets.
constexpr char kPreviousComponentName[] = "previous_component";
constexpr char kPreviousLayerName[] = "previous_layer";

// Name and dimension of the layer produced by the BulkAddOne network unit.
constexpr char kOutputsName[] = "outputs";
constexpr size_t kOutputsDim = kFixedDim + kLinkedDim;
// Bulk network unit that writes each input value plus one to its output layer.
class BulkAddOne : public BulkNetworkUnit {
 public:
  // Implements BulkNetworkUnit.
  tensorflow::Status Initialize(const ComponentSpec &component_spec,
                                VariableStore *variable_store,
                                NetworkStateManager *network_state_manager,
                                ExtensionManager *extension_manager) override {
    return network_state_manager->AddLayer(kOutputsName, kOutputsDim,
                                           &outputs_handle_);
  }

  // Accepts any input dimension.
  tensorflow::Status ValidateInputDimension(size_t dimension) const override {
    return tensorflow::Status::OK();
  }

  // Produces no logits.
  string GetLogitsName() const override { return ""; }

  // Copies |inputs| + 1 into the output layer, which must match its shape.
  tensorflow::Status Evaluate(Matrix<float> inputs,
                              SessionState *session_state) const override {
    const MutableMatrix<float> outputs =
        session_state->network_states.GetLayer(outputs_handle_);
    const bool shapes_match = outputs.num_rows() == inputs.num_rows() &&
                              outputs.num_columns() == inputs.num_columns();
    if (!shapes_match) {
      return tensorflow::errors::InvalidArgument("Dimension mismatch");
    }
    for (size_t i = 0; i < inputs.num_rows(); ++i) {
      for (size_t j = 0; j < inputs.num_columns(); ++j) {
        outputs.row(i)[j] = inputs.row(i)[j] + 1.0;
      }
    }
    return tensorflow::Status::OK();
  }

 private:
  // Handle of the output layer written by Evaluate().
  LayerHandle<float> outputs_handle_;
};

DRAGNN_RUNTIME_REGISTER_BULK_NETWORK_UNIT(BulkAddOne);
// A component that also prefers itself but is triggered on a certain backend.
// This can be used to cause a component selection conflict.
class ImTheBest : public Component {
public:
// Implements Component.
tensorflow::Status Initialize(const ComponentSpec &component_spec,
VariableStore *variable_store,
NetworkStateManager *network_state_manager,
ExtensionManager *extension_manager) override {
return tensorflow::Status::OK();
}
tensorflow::Status Evaluate(SessionState *session_state,
ComputeSession *compute_session,
ComponentTrace *component_trace) const override {
return tensorflow::Status::OK();
}
bool Supports(const ComponentSpec &component_spec,
const string &normalized_builder_name) const override {
return component_spec.backend().registered_name() == "CauseConflict";
}
bool PreferredTo(const Component &other) const override { return true; }
};
DRAGNN_RUNTIME_REGISTER_COMPONENT(ImTheBest);
class BulkDynamicComponentTest : public NetworkTestBase {
 protected:
  // Returns a spec that the network supports: a deterministic transition
  // system with one fixed feature and one non-recurrent linked feature.
  ComponentSpec GetSupportedSpec() {
    ComponentSpec component_spec;
    component_spec.set_name(kTestComponentName);
    component_spec.set_num_actions(1);
    component_spec.mutable_network_unit()->set_registered_name("AddOne");
    component_spec.mutable_component_builder()->set_registered_name(
        "DynamicComponent");

    FixedFeatureChannel *fixed_feature = component_spec.add_fixed_feature();
    fixed_feature->set_size(1);
    fixed_feature->set_embedding_dim(kFixedDim);
    fixed_feature->set_vocabulary_size(kFixedVocabularySize);

    // embedding_dim of -1 marks the link as untransformed.
    LinkedFeatureChannel *linked_feature = component_spec.add_linked_feature();
    linked_feature->set_size(1);
    linked_feature->set_embedding_dim(-1);
    linked_feature->set_source_component(kPreviousComponentName);
    linked_feature->set_source_layer(kPreviousLayerName);
    return component_spec;
  }

  // Adds mock call expectations to the |compute_session_| for the transition
  // system traversal and feature extraction.
  void AddComputeSessionMocks() {
    SetupTransitionLoop(kNumSteps);
    EXPECT_CALL(compute_session_, AdvanceFromOracle(_)).Times(kNumSteps);
    EXPECT_CALL(compute_session_, GetInputFeatures(_, _, _, _, _))
        .Times(kNumSteps)
        .WillRepeatedly(Invoke(ExtractFeatures(0, {{0, 1.0}})));
    EXPECT_CALL(compute_session_, GetTranslatedLinkFeatures(_, _))
        .Times(kNumSteps)
        .WillRepeatedly(Invoke(ExtractLinks(0, {"step_idx: 0"})));
    EXPECT_CALL(compute_session_, SourceComponentBeamSize(_, _))
        .Times(kNumSteps)
        .WillRepeatedly(Return(1));
  }

  // Creates a BulkDynamicComponent, initializes it based on the
  // |component_spec|, and evaluates it.  On success, points |outputs_| at the
  // component's output layer.  On error, returns non-OK.
  tensorflow::Status Run(const ComponentSpec &component_spec) {
    AddComponent(kPreviousComponentName);
    AddLayer(kPreviousLayerName, kLinkedDim);
    AddComponent(kTestComponentName);
    AddFixedEmbeddingMatrix(0, kFixedVocabularySize, kFixedDim, kFixedValue);

    TF_RETURN_IF_ERROR(
        Component::CreateOrError("BulkDynamicComponent", &component_));
    TF_RETURN_IF_ERROR(component_->Initialize(component_spec, &variable_store_,
                                              &network_state_manager_,
                                              &extension_manager_));

    // Allocates network states for a few steps.
    network_states_.Reset(&network_state_manager_);
    StartComponent(kNumSteps);
    FillLayer(kPreviousComponentName, kPreviousLayerName, kLinkedValue);
    // The component under test starts with no steps; the bulk feature
    // extractor adds one step per oracle transition.
    StartComponent(0);
    session_state_.extensions.Reset(&extension_manager_);

    TF_RETURN_IF_ERROR(
        component_->Evaluate(&session_state_, &compute_session_, nullptr));
    outputs_ = GetLayer(kTestComponentName, kOutputsName);
    return tensorflow::Status::OK();
  }

  // Component under test.
  std::unique_ptr<Component> component_;

  // Output layer of the component, captured after a successful Run().
  Matrix<float> outputs_;
};
// Verifies that the supported spec is selected and evaluated correctly.
TEST_F(BulkDynamicComponentTest, Supported) {
  const ComponentSpec spec = GetSupportedSpec();
  string selected;
  TF_ASSERT_OK(Component::Select(spec, &selected));
  EXPECT_EQ(selected, "BulkDynamicComponent");

  AddComputeSessionMocks();
  TF_ASSERT_OK(Run(spec));

  // Each row holds the fixed embedding followed by the linked embedding, with
  // every value incremented by the BulkAddOne network unit.
  ASSERT_EQ(outputs_.num_rows(), kNumSteps);
  ASSERT_EQ(outputs_.num_columns(), kFixedDim + kLinkedDim);
  for (size_t row = 0; row < kNumSteps; ++row) {
    for (size_t column = 0; column < kFixedDim + kLinkedDim; ++column) {
      const float expected =
          column < kFixedDim ? kFixedValue + 1.0 : kLinkedValue + 1.0;
      EXPECT_EQ(outputs_.row(row)[column], expected);
    }
  }
}
// Verifies that the BulkDynamicComponent also matches its own builder name.
TEST_F(BulkDynamicComponentTest, SupportsBulkName) {
  ComponentSpec spec = GetSupportedSpec();
  spec.mutable_component_builder()->set_registered_name("BulkDynamicComponent");
  string selected;
  TF_ASSERT_OK(Component::Select(spec, &selected));
  EXPECT_EQ(selected, "BulkDynamicComponent");
}
// Verifies that a non-deterministic transition system is rejected, both at
// selection time and at initialization time.
TEST_F(BulkDynamicComponentTest, ForbidNonDeterminism) {
  ComponentSpec spec = GetSupportedSpec();
  spec.set_num_actions(100);

  string selected;
  EXPECT_THAT(
      Component::Select(spec, &selected),
      test::IsErrorWithSubstr("Could not find a best spec for component"));
  EXPECT_THAT(Run(spec),
              test::IsErrorWithSubstr(
                  "BulkFeatureExtractorNetwork does not support component"));
}
// Verifies that a recurrent link (into this same component) is rejected, both
// at selection time and at initialization time.
TEST_F(BulkDynamicComponentTest, ForbidRecurrences) {
  ComponentSpec spec = GetSupportedSpec();
  spec.mutable_linked_feature(0)->set_source_component(kTestComponentName);

  string selected;
  EXPECT_THAT(
      Component::Select(spec, &selected),
      test::IsErrorWithSubstr("Could not find a best spec for component"));
  EXPECT_THAT(Run(spec),
              test::IsErrorWithSubstr(
                  "BulkFeatureExtractorNetwork does not support component"));
}
// Verifies that the component prefers itself.
TEST_F(BulkDynamicComponentTest, PrefersItself) {
  ComponentSpec spec = GetSupportedSpec();
  // The "CauseConflict" backend triggers the ImTheBest component, which also
  // prefers itself and leads to a selection conflict.
  spec.mutable_backend()->set_registered_name("CauseConflict");
  string selected;
  EXPECT_THAT(Component::Select(spec, &selected),
              test::IsErrorWithSubstr("both think they should be preferred"));
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <string>
#include "dragnn/core/compute_session.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/bulk_network_unit.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/feed_forward_network_kernel.h"
#include "dragnn/runtime/feed_forward_network_layer.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// A network unit that evaluates a feed-forward multi-layer perceptron over a
// bulk input matrix (one row per step).
class BulkFeedForwardNetwork : public BulkNetworkUnit {
 public:
  // Implements BulkNetworkUnit.
  tensorflow::Status Initialize(const ComponentSpec &component_spec,
                                VariableStore *variable_store,
                                NetworkStateManager *network_state_manager,
                                ExtensionManager *extension_manager) override;
  tensorflow::Status ValidateInputDimension(size_t dimension) const override;

  // Returns the name of the kernel's logits layer.
  string GetLogitsName() const override { return kernel_.logits_name(); }

  // Applies each layer of the kernel to the |inputs| in sequence.
  tensorflow::Status Evaluate(Matrix<float> inputs,
                              SessionState *session_state) const override;

 private:
  // Kernel that implements the feed-forward network.
  FeedForwardNetworkKernel kernel_;
};
tensorflow::Status BulkFeedForwardNetwork::Initialize(
    const ComponentSpec &component_spec, VariableStore *variable_store,
    NetworkStateManager *network_state_manager,
    ExtensionManager *extension_manager) {
  // Reject recurrent links, i.e. links into this same component.
  const string &component_name = component_spec.name();
  for (const LinkedFeatureChannel &link : component_spec.linked_feature()) {
    if (link.source_component() == component_name) {
      return tensorflow::errors::InvalidArgument(
          "BulkFeedForwardNetwork forbids recurrent links");
    }
  }
  return kernel_.Initialize(component_spec, variable_store,
                            network_state_manager);
}
// Delegates input-dimension validation to the underlying kernel.
tensorflow::Status BulkFeedForwardNetwork::ValidateInputDimension(
    size_t dimension) const {
  return kernel_.ValidateInputDimension(dimension);
}
tensorflow::Status BulkFeedForwardNetwork::Evaluate(
    Matrix<float> inputs, SessionState *session_state) const {
  // Feed the activations forward: each layer consumes the previous layer's
  // output matrix, starting from the |inputs|.
  Matrix<float> activations = inputs;
  for (const FeedForwardNetworkLayer &layer : kernel_.layers()) {
    activations = layer.Apply(activations, session_state->network_states);
  }
  return tensorflow::Status::OK();
}

DRAGNN_RUNTIME_REGISTER_BULK_NETWORK_UNIT(BulkFeedForwardNetwork);
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <algorithm>
#include <memory>
#include <string>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/bulk_network_unit.h"
#include "dragnn/runtime/flexible_matrix_kernel.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Dimensions of the network input and logits, and the number of rows (steps)
// in the bulk input matrix.
constexpr size_t kInputDim = 5;
constexpr size_t kLogitsDim = 3;
constexpr size_t kNumSteps = 4;

// Fill value for every cell of the input matrix.
constexpr float kEmbedding = 1.25;
// Applies the ReLU activation to the |value|, i.e. max(0, value).
float Relu(float value) { return 0.0f < value ? value : 0.0f; }
class BulkFeedForwardNetworkTest : public NetworkTestBase {
 protected:
  // Adds a weight matrix with the |name_suffix| with the given dimensions and
  // |fill_value|.  NB: the variable is added with its dimensions swapped
  // (|num_columns| x |num_rows|).
  void AddWeights(const string &name_suffix, size_t num_rows,
                  size_t num_columns, float fill_value) {
    const string weights_name =
        tensorflow::strings::StrCat(kTestComponentName, "/weights_",
                                    name_suffix, FlexibleMatrixKernel::kSuffix);
    AddMatrixVariable(weights_name, num_columns, num_rows, fill_value);
  }

  // Adds a bias vector with the |name_suffix| with the given dimensions and
  // |fill_value|.
  void AddBiases(const string &name_suffix, size_t dimension,
                 float fill_value) {
    const string biases_name =
        tensorflow::strings::StrCat(kTestComponentName, "/bias_", name_suffix);
    AddVectorVariable(biases_name, dimension, fill_value);
  }

  // Creates a network unit, initializes it based on the |component_spec_text|,
  // and evaluates it on a kNumSteps x kInputDim input matrix whose cells all
  // hold kEmbedding.  On error, returns non-OK.
  tensorflow::Status Run(const string &component_spec_text) {
    ComponentSpec component_spec;
    CHECK(TextFormat::ParseFromString(component_spec_text, &component_spec));
    component_spec.set_name(kTestComponentName);

    AddComponent(kTestComponentName);
    TF_CHECK_OK(BulkNetworkUnit::CreateOrError("BulkFeedForwardNetwork",
                                               &bulk_network_unit_));
    TF_RETURN_IF_ERROR(bulk_network_unit_->Initialize(
        component_spec, &variable_store_, &network_state_manager_,
        &extension_manager_));

    // The input dimension is the sum of the fixed feature embedding dims.
    size_t input_dimension = 0;
    for (const FixedFeatureChannel &channel : component_spec.fixed_feature()) {
      input_dimension += channel.embedding_dim();
    }
    TF_RETURN_IF_ERROR(
        bulk_network_unit_->ValidateInputDimension(input_dimension));

    network_states_.Reset(&network_state_manager_);
    StartComponent(kNumSteps);
    session_state_.extensions.Reset(&extension_manager_);

    const std::vector<float> row(kInputDim, kEmbedding);
    UniqueMatrix<float> input(std::vector<std::vector<float>>(kNumSteps, row));
    return bulk_network_unit_->Evaluate(Matrix<float>(*input), &session_state_);
  }

  // Returns the layer named |layer_name| in the current component.
  Matrix<float> GetActivations(const string &layer_name) const {
    return Matrix<float>(GetLayer(kTestComponentName, layer_name));
  }

  // Network unit under test.
  std::unique_ptr<BulkNetworkUnit> bulk_network_unit_;
};
// Verifies that a weight matrix whose output dimension does not match its
// output activations is rejected.
TEST_F(BulkFeedForwardNetworkTest, BadWeightRows) {
  const string kSpec = R"(fixed_feature {
                            vocabulary_size: 50
                            embedding_dim: 5
                            size: 1
                          }
                          num_actions: 3)";
  // One too few output columns in the softmax weights.
  AddWeights("softmax", kInputDim, kLogitsDim - 1, 1.0);
  AddBiases("softmax", kLogitsDim, 1.0);
  EXPECT_THAT(
      Run(kSpec),
      test::IsErrorWithSubstr(
          "Weight matrix shape should be output dimension plus padding"));
}
// Verifies that a weight matrix whose input dimension does not match its
// input activations is rejected.
TEST_F(BulkFeedForwardNetworkTest, BadWeightColumns) {
  const string kSpec = R"(fixed_feature {
                            vocabulary_size: 50
                            embedding_dim: 5
                            size: 1
                          }
                          num_actions: 3)";
  // One too many input rows in the softmax weights.
  AddWeights("softmax", kInputDim + 1, kLogitsDim, 1.0);
  AddBiases("softmax", kLogitsDim, 1.0);
  EXPECT_THAT(Run(kSpec),
              test::IsErrorWithSubstr(
                  "Weight matrix shape does not match input dimension"));
}
// Verifies that a bias vector whose dimension does not match its output
// activations is rejected.
TEST_F(BulkFeedForwardNetworkTest, BadBiasDimension) {
  const string kSpec = R"(fixed_feature {
                            vocabulary_size: 50
                            embedding_dim: 5
                            size: 1
                          }
                          num_actions: 3)";
  AddWeights("softmax", kInputDim, kLogitsDim, 1.0);
  // One too many entries in the softmax biases.
  AddBiases("softmax", kLogitsDim + 1, 1.0);
  EXPECT_THAT(Run(kSpec),
              test::IsErrorWithSubstr(
                  "Bias vector shape does not match output dimension"));
}
// Verifies that enabling the "layer_norm_input" option is rejected.
TEST_F(BulkFeedForwardNetworkTest, UnsupportedLayerNormInputOption) {
  const string kSpec = R"(network_unit {
                            parameters {
                              key: 'layer_norm_input'
                              value: 'true'
                            }
                          })";
  EXPECT_THAT(Run(kSpec),
              test::IsErrorWithSubstr("Layer norm is not supported"));
}
// Verifies that enabling the "layer_norm_hidden" option is rejected.
TEST_F(BulkFeedForwardNetworkTest, UnsupportedLayerNormHiddenOption) {
  const string kSpec = R"(network_unit {
                            parameters {
                              key: 'layer_norm_hidden'
                              value: 'true'
                            }
                          })";
  EXPECT_THAT(Run(kSpec),
              test::IsErrorWithSubstr("Layer norm is not supported"));
}
// Verifies that a "nonlinearity" option other than "relu" is rejected.
TEST_F(BulkFeedForwardNetworkTest, UnsupportedNonlinearityOption) {
  const string kSpec = R"(network_unit {
                            parameters {
                              key: 'nonlinearity'
                              value: 'elu'
                            }
                          })";
  EXPECT_THAT(Run(kSpec),
              test::IsErrorWithSubstr("Non-linearity is not supported"));
}
// Verifies that a link into this same component (recurrence) is rejected.
TEST_F(BulkFeedForwardNetworkTest, UnsupportedRecurrentLink) {
  const string kSpec = R"(linked_feature {
                            source_component: 'test_component'
                          })";
  EXPECT_THAT(Run(kSpec),
              test::IsErrorWithSubstr(
                  "BulkFeedForwardNetwork forbids recurrent links"));
}
// Verifies the network with no hidden layers: just a softmax computing logits.
TEST_F(BulkFeedForwardNetworkTest, JustLogits) {
  const string kSpec = R"(fixed_feature {
                            vocabulary_size: 50
                            embedding_dim: 5
                            size: 1
                          }
                          num_actions: 3)";
  const float kWeight = 1.5;
  const float kBias = 0.75;
  AddWeights("softmax", kInputDim, kLogitsDim, kWeight);
  AddBiases("softmax", kLogitsDim, kBias);

  TF_ASSERT_OK(Run(kSpec));
  EXPECT_EQ("logits", bulk_network_unit_->GetLogitsName());

  // Each logit is the dot product of a constant input row and a constant
  // weight column, plus the bias.
  ExpectMatrix(GetActivations("logits"), kNumSteps, kLogitsDim,
               kInputDim * kEmbedding * kWeight + kBias);
}
// Tests that the BulkFeedForwardNetwork works with multiple hidden layers as
// well as a softmax that computes logits.
TEST_F(BulkFeedForwardNetworkTest, MultiLayer) {
  // Dimensions of each layer: input, hidden "0", hidden "1", and softmax.
  const size_t kDims[] = {kInputDim, 4, 3, 2};
  const string kSpec = R"(fixed_feature {
vocabulary_size: 50
embedding_dim: 5
size: 1
}
network_unit {
parameters {
key: 'hidden_layer_sizes'
value: '4,3'
}
}
num_actions: 2)";
  // Per-layer uniform fill values for weights and biases.
  const float kWeights[] = {-1.5, 1.0, 0.5};
  const float kBiases[] = {0.75, -0.5, -1.0};
  AddWeights("0", kDims[0], kDims[1], kWeights[0]);
  AddBiases("0", kDims[1], kBiases[0]);
  AddWeights("1", kDims[1], kDims[2], kWeights[1]);
  AddBiases("1", kDims[2], kBiases[1]);
  AddWeights("softmax", kDims[2], kDims[3], kWeights[2]);
  AddBiases("softmax", kDims[3], kBiases[2]);
  TF_ASSERT_OK(Run(kSpec));
  EXPECT_EQ("logits", bulk_network_unit_->GetLogitsName());
  // With uniform fills, each activation in a layer equals the previous
  // layer's (uniform) activation times fan-in times weight, plus bias, with
  // ReLU applied to the hidden layers.
  // NOTE(review): unlike JustLogits, the expected value below omits the
  // kEmbedding factor — presumably kEmbedding is 1.0 in this fixture;
  // confirm against the fixture's definition.
  float expected = Relu(kDims[0] * kWeights[0] + kBiases[0]);
  ExpectMatrix(GetActivations("layer_0"), kNumSteps, kDims[1], expected);
  expected = Relu(kDims[1] * expected * kWeights[1] + kBiases[1]);
  ExpectMatrix(GetActivations("layer_1"), kNumSteps, kDims[2], expected);
  // "last_layer" aliases the final hidden layer.
  ExpectMatrix(GetActivations("last_layer"), kNumSteps, kDims[2], expected);
  // The softmax produces raw logits; no ReLU is applied.
  expected = kDims[2] * expected * kWeights[2] + kBiases[2];
  ExpectMatrix(GetActivations("logits"), kNumSteps, kDims[3], expected);
}
// Tests that the BulkFeedForwardNetwork does not produce logits and does not
// use the softmax variables when the component is deterministic (the spec
// declares num_actions: 1).
TEST_F(BulkFeedForwardNetworkTest, NoLogitsOrSoftmaxWhenDeterministic) {
  // Dimensions of the input and the single hidden layer.
  const size_t kDims[] = {kInputDim, 4};
  const string kSpec = R"(num_actions: 1
fixed_feature {
vocabulary_size: 50
embedding_dim: 5
size: 1
}
network_unit {
parameters {
key: 'hidden_layer_sizes'
value: '4'
}
})";
  const float kWeight = -1.5;
  const float kBias = 0.75;
  // No "softmax" weights or biases.  Initialization must succeed without them.
  AddWeights("0", kDims[0], kDims[1], kWeight);
  AddBiases("0", kDims[1], kBias);
  TF_ASSERT_OK(Run(kSpec));
  // No specified logits layer.
  EXPECT_TRUE(bulk_network_unit_->GetLogitsName().empty());
  // No "logits" layer.
  size_t unused_dimension = 0;
  LayerHandle<float> unused_handle;
  EXPECT_THAT(
      network_state_manager_.LookupLayer(kTestComponentName, "logits",
                                         &unused_dimension, &unused_handle),
      test::IsErrorWithSubstr(
          "Unknown layer 'logits' in component 'test_component'"));
  // Hidden layer is still produced.  The negative kWeight makes the ReLU
  // clamp significant here.
  const float kExpected = Relu(kDims[0] * kEmbedding * kWeight + kBias);
  ExpectMatrix(GetActivations("layer_0"), kNumSteps, kDims[1], kExpected);
  ExpectMatrix(GetActivations("last_layer"), kNumSteps, kDims[1], kExpected);
}
// Tests that the BulkFeedForwardNetwork does not produce logits when
// omit_logits is true, even if there are actions (num_actions: 10).
TEST_F(BulkFeedForwardNetworkTest, NoLogitsOrSoftmaxWhenOmitLogitsTrue) {
  // Dimensions of the input and the single hidden layer.
  const size_t kDims[] = {kInputDim, 4};
  const string kSpec = R"(fixed_feature {
vocabulary_size: 50
embedding_dim: 5
size: 1
}
network_unit {
parameters {
key: 'hidden_layer_sizes'
value: '4'
}
parameters {
key: 'omit_logits'
value: 'true'
}
}
num_actions: 10)";
  const float kWeight = 1.5;
  const float kBias = 0.75;
  // No "softmax" weights or biases.  Initialization must succeed without them.
  AddWeights("0", kDims[0], kDims[1], kWeight);
  AddBiases("0", kDims[1], kBias);
  TF_ASSERT_OK(Run(kSpec));
  // No specified logits layer.
  EXPECT_TRUE(bulk_network_unit_->GetLogitsName().empty());
  // No "logits" layer.
  size_t unused_dimension = 0;
  LayerHandle<float> unused_handle;
  EXPECT_THAT(
      network_state_manager_.LookupLayer(kTestComponentName, "logits",
                                         &unused_dimension, &unused_handle),
      test::IsErrorWithSubstr(
          "Unknown layer 'logits' in component 'test_component'"));
  // Hidden layer is still produced.  Relu() is applied for consistency with
  // NoLogitsOrSoftmaxWhenDeterministic; it is a no-op here because the
  // pre-activation value is positive.
  const float kExpected = Relu(kDims[0] * kEmbedding * kWeight + kBias);
  ExpectMatrix(GetActivations("layer_0"), kNumSteps, kDims[1], kExpected);
  ExpectMatrix(GetActivations("last_layer"), kNumSteps, kDims[1], kExpected);
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <string>
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/bulk_network_unit.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/lstm_network_kernel.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// A network unit that evaluates an LSTM over all steps of a component in one
// bulk pass.  All work is delegated to an LSTMNetworkKernel constructed in
// bulk mode.
class BulkLSTMNetwork : public BulkNetworkUnit {
 public:
  // Implements BulkNetworkUnit.  Forwards configuration to the kernel, which
  // retrieves its variables and registers its layers.
  tensorflow::Status Initialize(const ComponentSpec &component_spec,
                                VariableStore *variable_store,
                                NetworkStateManager *network_state_manager,
                                ExtensionManager *extension_manager) override {
    return kernel_.Initialize(component_spec, variable_store,
                              network_state_manager, extension_manager);
  }
  // Accepts any input dimension; presumably the kernel validates dimensions
  // itself when initialized or applied — confirm against LSTMNetworkKernel.
  tensorflow::Status ValidateInputDimension(size_t dimension) const override {
    return tensorflow::Status::OK();
  }
  // Forwards to the kernel, which knows whether logits are produced.
  string GetLogitsName() const override { return kernel_.GetLogitsName(); }
  // Runs the LSTM kernel on the full matrix of |inputs|.
  tensorflow::Status Evaluate(Matrix<float> inputs,
                              SessionState *session_state) const override {
    return kernel_.Apply(inputs, session_state);
  }
 private:
  // Kernel that implements the LSTM, configured for bulk operation.
  LSTMNetworkKernel kernel_{/*bulk=*/true};
};
DRAGNN_RUNTIME_REGISTER_BULK_NETWORK_UNIT(BulkLSTMNetwork);
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/runtime.pb.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/bulk_network_unit.h"
#include "dragnn/runtime/flexible_matrix_kernel.h"
#include "dragnn/runtime/lstm_cell/cell_function.h"
#include "dragnn/runtime/test/helpers.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Test dimensions: number of transition steps, number of actions (logits),
// input feature dimension, and LSTM hidden dimension.
constexpr size_t kNumSteps = 20;
constexpr size_t kNumActions = 10;
constexpr size_t kInputDim = 32;
constexpr size_t kHiddenDim = 8;
class BulkLSTMNetworkTest : public NetworkTestBase {
 protected:
  // Adds a blocked weight matrix with the |name| with the given dimensions and
  // |fill_value|. If |is_flexible_matrix| is true, the variable is set up for
  // use by the FlexibleMatrixKernel.
  void AddWeights(const string &name, size_t input_dim, size_t output_dim,
                  float fill_value, bool is_flexible_matrix = false) {
    constexpr int kBatchSize = LstmCellFunction<>::kBatchSize;
    // Round the output dimension up to the next multiple of the cell's batch
    // (block) size.
    size_t output_padded =
        kBatchSize * ((output_dim + kBatchSize - 1) / kBatchSize);
    // One kBatchSize-sized block per (column block, input row) pair.
    size_t num_views = (output_padded / kBatchSize) * input_dim;
    string var_name = tensorflow::strings::StrCat(
        kTestComponentName, "/", name,
        is_flexible_matrix ? FlexibleMatrixKernel::kSuffix
                           : "/matrix/blocked48");
    // Fill every block uniformly with |fill_value|.
    const std::vector<float> block(kBatchSize, fill_value);
    const std::vector<std::vector<float>> blocks(num_views, block);
    variable_store_.AddOrDie(
        var_name, blocks, VariableSpec::FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX);
    variable_store_.SetBlockedDimensionOverride(
        var_name, {input_dim, output_padded, kBatchSize});
  }
  // Adds a bias vector with the |name| with the given |dimension| and
  // |fill_value|.
  void AddBiases(const string &name, size_t dimension, float fill_value) {
    const string biases_name =
        tensorflow::strings::StrCat(kTestComponentName, "/", name);
    AddVectorVariable(biases_name, dimension, fill_value);
  }
  // Initializes the |bulk_network_unit_| from the |component_spec_text|. On
  // error, returns non-OK.
  tensorflow::Status Initialize(const string &component_spec_text) {
    ComponentSpec component_spec;
    CHECK(TextFormat::ParseFromString(component_spec_text, &component_spec));
    component_spec.set_name(kTestComponentName);
    AddComponent(kTestComponentName);
    TF_RETURN_IF_ERROR(
        BulkNetworkUnit::CreateOrError("BulkLSTMNetwork", &bulk_network_unit_));
    TF_RETURN_IF_ERROR(bulk_network_unit_->Initialize(
        component_spec, &variable_store_, &network_state_manager_,
        &extension_manager_));
    TF_RETURN_IF_ERROR(bulk_network_unit_->ValidateInputDimension(kInputDim));
    // Set up network states and extensions for a kNumSteps-step component.
    network_states_.Reset(&network_state_manager_);
    StartComponent(kNumSteps);
    session_state_.extensions.Reset(&extension_manager_);
    return tensorflow::Status::OK();
  }
  // Evaluates the |bulk_network_unit_| on the |inputs|, one row per step.
  void Apply(const std::vector<std::vector<float>> &inputs) {
    UniqueMatrix<float> input_matrix(inputs);
    TF_ASSERT_OK(bulk_network_unit_->Evaluate(Matrix<float>(*input_matrix),
                                              &session_state_));
  }
  // Returns the logits matrix produced by the network unit.
  Matrix<float> GetLogits() const {
    return Matrix<float>(GetLayer(kTestComponentName, "logits"));
  }
  // The bulk network unit under test.
  std::unique_ptr<BulkNetworkUnit> bulk_network_unit_;
};
// Tests that BulkLSTMNetwork initializes and produces logits of the right
// shape on a uniform input, with all logits equal.
TEST_F(BulkLSTMNetworkTest, NormalOperation) {
  const string kSpec = R"(fixed_feature {
vocabulary_size: 50
embedding_dim: 32
size: 1
}
network_unit {
parameters {
key: 'hidden_layer_sizes'
value: '8'
}
}
num_actions: 10)";
  constexpr float kEmbedding = 1.25;
  constexpr float kWeight = 1.5;
  // LSTM cell variables, plus "softmax" weights and biases for the logits.
  AddWeights("x_to_ico", kInputDim, 3 * kHiddenDim, kWeight);
  AddWeights("h_to_ico", kHiddenDim, 3 * kHiddenDim, kWeight);
  AddWeights("c2i", kHiddenDim, kHiddenDim, kWeight);
  AddWeights("c2o", kHiddenDim, kHiddenDim, kWeight);
  AddWeights("weights_softmax", kHiddenDim, kNumActions, kWeight,
             /*is_flexible_matrix=*/true);
  AddBiases("ico_bias", 3 * kHiddenDim, kWeight);
  AddBiases("bias_softmax", kNumActions, kWeight);
  TF_EXPECT_OK(Initialize(kSpec));
  // Logits should exist.
  EXPECT_EQ(bulk_network_unit_->GetLogitsName(), "logits");
  // Feed a uniform input matrix through the network.  Renamed from "row" to
  // avoid shadowing by the loop variable below.
  const std::vector<float> input_row(kInputDim, kEmbedding);
  const std::vector<std::vector<float>> input_rows(kNumSteps, input_row);
  Apply(input_rows);
  // Logits dimension matches "num_actions" above. We don't test the values very
  // precisely here, and feel free to update if the cell function changes. Most
  // value tests should be in lstm_cell/cell_function_test.cc.
  Matrix<float> logits = GetLogits();
  EXPECT_EQ(logits.num_rows(), kNumSteps);
  EXPECT_EQ(logits.num_columns(), kNumActions);
  EXPECT_NEAR(logits.row(0)[0], 10.6391, 0.1);
  // Use size_t to match Matrix<>::num_rows() and avoid a signed/unsigned
  // comparison.
  for (size_t row_index = 0; row_index < logits.num_rows(); ++row_index) {
    for (const float value : logits.row(row_index)) {
      EXPECT_EQ(value, logits.row(0)[0])
          << "With uniform weights, all logits should be equal.";
    }
  }
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/bulk_network_unit.h"
#include <vector>
#include "dragnn/runtime/network_unit.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
string BulkNetworkUnit::GetClassName(
    const ComponentSpec &component_spec) {
  // The |component_spec| names the Python-registered network unit (e.g.,
  // "some.module.FooNetwork"), which cannot be passed directly to the C++
  // registry.  NetworkUnit::GetClassName() extracts the C++ registered name
  // (e.g., "FooNetwork"); the "Bulk" prefix then selects the bulk variant.
  const string non_bulk_name = NetworkUnit::GetClassName(component_spec);
  return tensorflow::strings::StrCat("Bulk", non_bulk_name);
}
} // namespace runtime
} // namespace dragnn
REGISTER_SYNTAXNET_CLASS_REGISTRY("DRAGNN Runtime Bulk Network Unit",
dragnn::runtime::BulkNetworkUnit);
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_BULK_NETWORK_UNIT_H_
#define DRAGNN_RUNTIME_BULK_NETWORK_UNIT_H_
#include <stddef.h>
#include <functional>
#include <memory>
#include <string>
#include "dragnn/core/compute_session.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "syntaxnet/registry.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Interface for network units for bulk inference.
//
// TODO(googleuser): The current approach assumes that fixed and
// linked embeddings are computed and concatenated outside the network unit,
// which is simple and composable. However, it could be more efficient to,
// e.g., pass the fixed and linked embeddings individually or compute them
// internally. That would elide the concatenation and could increase cache
// coherency.
class BulkNetworkUnit : public RegisterableClass<BulkNetworkUnit> {
 public:
  // Not copyable or assignable.
  BulkNetworkUnit(const BulkNetworkUnit &that) = delete;
  BulkNetworkUnit &operator=(const BulkNetworkUnit &that) = delete;
  virtual ~BulkNetworkUnit() = default;
  // Returns the bulk network unit class name specified in the |component_spec|.
  static string GetClassName(const ComponentSpec &component_spec);
  // Initializes this to the configuration in the |component_spec|. Retrieves
  // pre-trained variables from the |variable_store|, which must outlive this.
  // Adds layers and local operands to the |network_state_manager|, which must
  // be positioned at the current component. Requests SessionState extensions
  // from the |extension_manager|. On error, returns non-OK.
  virtual tensorflow::Status Initialize(
      const ComponentSpec &component_spec, VariableStore *variable_store,
      NetworkStateManager *network_state_manager,
      ExtensionManager *extension_manager) = 0;
  // Returns OK iff this is compatible with the input |dimension|.
  virtual tensorflow::Status ValidateInputDimension(size_t dimension) const = 0;
  // Returns the name of the layer that contains classification logits, or an
  // empty string if this does not produce logits. Requires that Initialize()
  // was called.
  virtual string GetLogitsName() const = 0;
  // Evaluates this network on the bulk |inputs|, using intermediate operands
  // and output layers in the |session_state|. On error, returns non-OK.
  virtual tensorflow::Status Evaluate(Matrix<float> inputs,
                                      SessionState *session_state) const = 0;
 protected:
  // Subclasses are constructed via the registry.
  BulkNetworkUnit() = default;
 private:
  // Helps prevent use of the Create() method; use CreateOrError() instead.
  using RegisterableClass<BulkNetworkUnit>::Create;
};
} // namespace runtime
} // namespace dragnn
DECLARE_SYNTAXNET_CLASS_REGISTRY("DRAGNN Runtime Bulk Network Unit",
dragnn::runtime::BulkNetworkUnit);
} // namespace syntaxnet
// Registers a subclass using its class name as a string.
#define DRAGNN_RUNTIME_REGISTER_BULK_NETWORK_UNIT(subclass) \
REGISTER_SYNTAXNET_CLASS_COMPONENT( \
::syntaxnet::dragnn::runtime::BulkNetworkUnit, #subclass, subclass)
#endif // DRAGNN_RUNTIME_BULK_NETWORK_UNIT_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/bulk_network_unit.h"
#include <memory>
#include <string>
#include "dragnn/core/compute_session.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/network_states.h"
#include "dragnn/runtime/session_state.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Verifies that |lhs| and |rhs| refer to the same address.
void ExpectSameAddress(const void *lhs, const void *rhs) {
  EXPECT_EQ(lhs, rhs);
}
// A trivial implementation for tests.  All methods trivially succeed.
class BulkFooNetwork : public BulkNetworkUnit {
 public:
  // Implements BulkNetworkUnit.
  tensorflow::Status Initialize(const ComponentSpec &component_spec,
                                VariableStore *variable_store,
                                NetworkStateManager *network_state_manager,
                                ExtensionManager *extension_manager) override {
    return tensorflow::Status::OK();
  }
  tensorflow::Status ValidateInputDimension(size_t dimension) const override {
    return tensorflow::Status::OK();
  }
  // Returns a fixed name so tests can verify registry-based dispatch.
  string GetLogitsName() const override { return "foo_logits"; }
  tensorflow::Status Evaluate(Matrix<float> inputs,
                              SessionState *session_state) const override {
    return tensorflow::Status::OK();
  }
};
DRAGNN_RUNTIME_REGISTER_BULK_NETWORK_UNIT(BulkFooNetwork);
// Tests that BulkNetworkUnit::GetClassName() resolves names properly.
TEST(BulkNetworkUnitTest, GetClassName) {
  // However deeply the Python-side name is nested in modules, the resolved
  // C++ class name should be the same.
  const char *const kRegisteredNames[] = {
      "FooNetwork", "module.FooNetwork",
      "some.long.path.to.module.FooNetwork"};
  for (const char *registered_name : kRegisteredNames) {
    ComponentSpec component_spec;
    component_spec.mutable_network_unit()->set_registered_name(registered_name);
    EXPECT_EQ(BulkNetworkUnit::GetClassName(component_spec), "BulkFooNetwork");
  }
}
// Tests that BulkNetworkUnits can be created via the registry.
TEST(BulkNetworkUnitTest, CreateOrError) {
  std::unique_ptr<BulkNetworkUnit> unit;
  TF_ASSERT_OK(BulkNetworkUnit::CreateOrError("BulkFooNetwork", &unit));
  ASSERT_TRUE(unit != nullptr);
  // The created instance should actually be a BulkFooNetwork.
  ExpectSameAddress(dynamic_cast<BulkFooNetwork *>(unit.get()), unit.get());
  EXPECT_EQ(unit->GetLogitsName(), "foo_logits");
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <string>
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component_transformation.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Transformer that removes dropout settings.
class ClearDropoutComponentTransformer : public ComponentTransformer {
 public:
  // Implements ComponentTransformer.  Strips the dropout ID and keep
  // probability from every fixed feature channel of the |component_spec|.
  tensorflow::Status Transform(const string &component_type,
                               ComponentSpec *component_spec) override {
    for (int index = 0; index < component_spec->fixed_feature_size(); ++index) {
      FixedFeatureChannel *channel =
          component_spec->mutable_fixed_feature(index);
      channel->clear_dropout_id();
      channel->clear_dropout_keep_probability();
    }
    return tensorflow::Status::OK();
  }
};
DRAGNN_RUNTIME_REGISTER_COMPONENT_TRANSFORMER(ClearDropoutComponentTransformer);
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment