Commit edea2b67 authored by Terry Koo

Remove runtime because reasons.

parent a4bb31d0
component {
name: "rnn"
transition_system {
registered_name: "shift-only"
parameters {
key: "left_to_right"
value: "false"
}
parameters {
key: "parser_skip_deterministic"
value: "false"
}
}
resource {
name: "words-embedding-input"
part {
file_pattern: "resources/component_0_rnn/resource_0_words-embedding-input/part_0"
file_format: "tf-records"
record_format: "syntaxnet.TokenEmbedding"
}
}
resource {
name: "words-vocab-input"
part {
file_pattern: "resources/component_0_rnn/resource_1_words-vocab-input/part_0"
file_format: "text"
record_format: ""
}
}
resource {
name: "char-ngram-map"
part {
file_pattern: "resources/component_0_rnn/resource_2_char-ngram-map/part_0"
file_format: "text"
record_format: ""
}
}
resource {
name: "word-map"
part {
file_pattern: "resources/component_0_rnn/resource_3_word-map/part_0"
file_format: "text"
record_format: ""
}
}
resource {
name: "label-map"
part {
file_pattern: "resources/component_0_rnn/resource_4_label-map/part_0"
file_format: "text"
record_format: ""
}
}
fixed_feature {
name: "char_ngrams"
fml: "input.token { offset(-1).char-ngram(min-length=1,max-length=3,mark-boundaries=true) offset(0).char-ngram(min-length=1,max-length=3,mark-boundaries=true) offset(1).char-ngram(min-length=1,max-length=3,mark-boundaries=true) }"
embedding_dim: 32
vocabulary_size: 25788
size: 3
}
fixed_feature {
name: "words"
fml: "input.token.word(min-freq=2)"
embedding_dim: 64
vocabulary_size: 23769
size: 1
}
network_unit {
registered_name: "LSTMNetwork"
parameters {
key: "hidden_layer_sizes"
value: "128"
}
parameters {
key: "omit_logits"
value: "true"
}
}
backend {
registered_name: "SyntaxNetComponent"
}
num_actions: 1
attention_component: ""
component_builder {
registered_name: "DynamicComponentBuilder"
}
}
component {
name: "tagger"
transition_system {
registered_name: "tagger"
parameters {
key: "parser_skip_deterministic"
value: "false"
}
}
resource {
name: "tag-map"
part {
file_pattern: "resources/component_1_tagger/resource_0_tag-map/part_0"
file_format: "text"
record_format: ""
}
}
resource {
name: "tag-to-category"
part {
file_pattern: "resources/component_1_tagger/resource_1_tag-to-category/part_0"
file_format: "text"
record_format: ""
}
}
resource {
name: "label-map"
part {
file_pattern: "resources/component_0_rnn/resource_4_label-map/part_0"
file_format: "text"
record_format: ""
}
}
linked_feature {
name: "recurrence"
fml: "bias(0)"
embedding_dim: 32
size: 1
source_component: "tagger"
source_translator: "history"
source_layer: "layer_0"
}
linked_feature {
name: "rnn"
fml: "input.focus"
embedding_dim: -1
size: 1
source_component: "rnn"
source_translator: "reverse-token"
source_layer: "layer_0"
}
network_unit {
registered_name: "FeedForwardNetwork"
parameters {
key: "hidden_layer_sizes"
value: "64,64"
}
}
backend {
registered_name: "SyntaxNetComponent"
}
num_actions: 45
attention_component: ""
component_builder {
registered_name: "DynamicComponentBuilder"
}
}
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/trained_model.h"
#include <unordered_set>
#include "tensorflow/cc/saved_model/tag_constants.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
tensorflow::Status TrainedModel::Reset(const string &saved_model_dir) {
const std::unordered_set<string> tags = {tensorflow::kSavedModelTagServe};
tensorflow::SavedModelBundle saved_model;
TF_RETURN_IF_ERROR(
tensorflow::LoadSavedModel({}, {}, saved_model_dir, tags, &saved_model));
// Success; make modifications.
saved_model_.session = std::move(saved_model.session);
saved_model_.meta_graph_def = std::move(saved_model.meta_graph_def);
nodes_.clear();
const tensorflow::GraphDef &graph = saved_model_.meta_graph_def.graph_def();
for (const tensorflow::NodeDef &node : graph.node()) {
nodes_[node.name()] = &node;
}
return tensorflow::Status::OK();
}
tensorflow::Status TrainedModel::EvaluateTensor(
const string &name, tensorflow::Tensor *tensor) const {
if (saved_model_.session == nullptr) {
return tensorflow::errors::FailedPrecondition("TF Session is not active");
}
// For some reason, runtime hook nodes cannot be evaluated without feeding an
// input batch. An empty batch currently works, but if DRAGNN starts failing
// on empty batches, a reasonable alternative is a batch of empty strings.
const string input_name = "annotation/ComputeSession/InputBatch";
const tensorflow::Tensor empty_batch(tensorflow::DT_STRING,
tensorflow::TensorShape({0}));
// Evaluate the variable in the session.
std::vector<tensorflow::Tensor> outputs;
tensorflow::Status status = saved_model_.session->Run(
{{input_name, empty_batch}}, {name}, {}, &outputs);
if (!status.ok()) {
// Attach some extra information to the session error.
return tensorflow::Status(
status.code(),
tensorflow::strings::StrCat("Failed to evaluate tensor '", name,
"': ", status.error_message()));
}
if (outputs.size() != 1) {
return tensorflow::errors::Unknown("Expected exactly one output, but got ",
outputs.size(), " outputs");
}
*tensor = outputs[0];
return tensorflow::Status::OK();
}
tensorflow::Status TrainedModel::LookupNode(
const string &name, const tensorflow::NodeDef **node) const {
if (saved_model_.session == nullptr) {
return tensorflow::errors::FailedPrecondition("TF Session is not active");
}
const auto it = nodes_.find(name);
if (it == nodes_.end()) {
return tensorflow::errors::NotFound("Unknown node: '", name, "'");
}
*node = it->second;
return tensorflow::Status::OK();
}
tensorflow::Status TrainedModel::GraphDef(
const tensorflow::GraphDef **graph) const {
if (saved_model_.session == nullptr) {
return tensorflow::errors::FailedPrecondition("TF Session is not active");
}
*graph = &saved_model_.meta_graph_def.graph_def();
return tensorflow::Status::OK();
}
tensorflow::Status TrainedModel::Close() {
if (saved_model_.session == nullptr) {
return tensorflow::errors::FailedPrecondition("TF Session is not active");
}
tensorflow::Status status = saved_model_.session->Close();
saved_model_.session.reset();
saved_model_.meta_graph_def.Clear();
nodes_.clear();
return status;
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_TRAINED_MODEL_H_
#define DRAGNN_RUNTIME_TRAINED_MODEL_H_
#include <map>
#include <string>
#include "syntaxnet/base.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// A trained DRAGNN model, which can be queried for nodes and tensors.
class TrainedModel {
public:
// Creates an uninitialized model; call Reset() before use.
TrainedModel() = default;
// Loads the TF SavedModel at the |saved_model_dir|, replacing the current
// model, if any. On error, returns non-OK and modifies nothing.
tensorflow::Status Reset(const string &saved_model_dir);
// Evaluates the tensor with the |name| in the |session_| and sets |tensor| to
// the result. On error, returns non-OK and modifies nothing.
//
// NB: Tensors that are embedded inside a tf.while_loop() cannot be evaluated.
// Such evaluations fail with errors like "Retval[0] does not have value".
tensorflow::Status EvaluateTensor(const string &name,
tensorflow::Tensor *tensor) const;
// Finds the node with the |name| in the |graph_| and points the |node| at it.
// On error, returns non-OK and modifies nothing.
tensorflow::Status LookupNode(const string &name,
const tensorflow::NodeDef **node) const;
// Points |graph| at the GraphDef for the current model. It is an error if
// there is no current model.
tensorflow::Status GraphDef(const tensorflow::GraphDef **graph) const;
// Discards the current model. It is an error if there is no current model.
// On error, returns non-OK but still discards the model.
tensorflow::Status Close();
private:
// TF SavedModel that contains the trained DRAGNN model.
tensorflow::SavedModelBundle saved_model_;
// Nodes in the TF graph, indexed by name.
std::map<string, const tensorflow::NodeDef *> nodes_;
};
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_TRAINED_MODEL_H_
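// The following is a hypothetical usage sketch of the TrainedModel API
// declared above; the saved-model directory and tensor/node names are
// placeholders, not values taken from this change.
#include "dragnn/runtime/trained_model.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
tensorflow::Status InspectTrainedModel(const string &saved_model_dir) {
  TrainedModel trained_model;
  TF_RETURN_IF_ERROR(trained_model.Reset(saved_model_dir));
  // Evaluate a tensor by name.  As noted above, tensors embedded inside a
  // tf.while_loop() cannot be evaluated this way.
  tensorflow::Tensor tensor;
  TF_RETURN_IF_ERROR(trained_model.EvaluateTensor("some/tensor/name", &tensor));
  LOG(INFO) << "Tensor rank: " << tensor.dims();
  // Look up the corresponding node in the GraphDef.
  const tensorflow::NodeDef *node = nullptr;
  TF_RETURN_IF_ERROR(trained_model.LookupNode("some/tensor/name", &node));
  LOG(INFO) << "Node op: " << node->op();
  // Discard the model; a second Close() would return an error.
  return trained_model.Close();
}
}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet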
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/trained_model.h"
#include <stddef.h>
#include <string>
#include "dragnn/core/test/generic.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Relative path to a saved model.
constexpr char kSavedModelDir[] = "dragnn/runtime/testdata/rnn_tagger";
// A valid tensor name in the test model and its dimensions.
constexpr char kTensorName[] = "tagger/weights_0/ExponentialMovingAverage";
constexpr size_t kTensorRows = 160;
constexpr size_t kTensorColumns = 64;
// Returns a valid saved model directory.
string GetSavedModelDir() {
return tensorflow::io::JoinPath(test::GetTestDataPrefix(), kSavedModelDir);
}
// Tests that TrainedModel can initialize itself from a valid saved model,
// retrieve tensors and nodes, and close itself. This is done in one test to
// avoid multiple (expensive) saved model loads.
TEST(TrainedModelTest, ResetQueryAndClose) {
TrainedModel trained_model;
TF_ASSERT_OK(trained_model.Reset(GetSavedModelDir()));
// Look up a valid tensor.
tensorflow::Tensor tensor;
TF_ASSERT_OK(trained_model.EvaluateTensor(kTensorName, &tensor));
ASSERT_EQ(tensor.dims(), 2);
EXPECT_EQ(tensor.dim_size(0), kTensorRows);
EXPECT_EQ(tensor.dim_size(1), kTensorColumns);
// Look up an invalid tensor.
EXPECT_FALSE(trained_model.EvaluateTensor("invalid", &tensor).ok());
// Still have the old tensor contents.
ASSERT_EQ(tensor.dims(), 2);
EXPECT_EQ(tensor.dim_size(0), kTensorRows);
EXPECT_EQ(tensor.dim_size(1), kTensorColumns);
// Look up a valid node. Note that the tensor name doubles as a node name.
const tensorflow::NodeDef *node = nullptr;
TF_ASSERT_OK(trained_model.LookupNode(kTensorName, &node));
ASSERT_NE(node, nullptr);
EXPECT_EQ(node->name(), kTensorName);
// Look up an invalid node.
ASSERT_THAT(trained_model.LookupNode("invalid", &node),
test::IsErrorWithSubstr("Unknown node"));
// Still have the old node.
ASSERT_NE(node, nullptr);
EXPECT_EQ(node->name(), kTensorName);
// Get the current Graph.
const tensorflow::GraphDef *graph_def = nullptr;
TF_ASSERT_OK(trained_model.GraphDef(&graph_def));
EXPECT_GT(graph_def->node_size(), 0);
// First Close() is OK, second fails because already closed.
TF_EXPECT_OK(trained_model.Close());
EXPECT_THAT(trained_model.Close(),
test::IsErrorWithSubstr("TF Session is not active"));
}
// Tests that TrainedModel::Reset() fails on an invalid path.
TEST(TrainedModelTest, InvalidPath) {
TrainedModel trained_model;
EXPECT_FALSE(trained_model.Reset("invalid/path").ok());
}
// Tests that TrainedModel::Close() fails if there is no model.
TEST(TrainedModelTest, CloseFailsBeforeReset) {
TrainedModel trained_model;
EXPECT_THAT(trained_model.Close(),
test::IsErrorWithSubstr("TF Session is not active"));
}
// Tests that TrainedModel::GraphDef() fails if there is no active session.
TEST(TrainedModelTest, GraphDefFailsBeforeReset) {
const tensorflow::GraphDef *graph_def = nullptr;
TrainedModel trained_model;
EXPECT_THAT(trained_model.GraphDef(&graph_def),
test::IsErrorWithSubstr("TF Session is not active"));
}
// Tests that TrainedModel::EvaluateTensor() fails if there is no model.
TEST(TrainedModelTest, EvaluateTensorFailsBeforeReset) {
TrainedModel trained_model;
tensorflow::Tensor tensor;
EXPECT_THAT(trained_model.EvaluateTensor("whatever", &tensor),
test::IsErrorWithSubstr("TF Session is not active"));
}
// Tests that TrainedModel::LookupNode() fails if there is no model.
TEST(TrainedModelTest, LookupNodeFailsBeforeReset) {
TrainedModel trained_model;
const tensorflow::NodeDef *node = nullptr;
EXPECT_THAT(trained_model.LookupNode("whatever", &node),
test::IsErrorWithSubstr("TF Session is not active"));
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/trained_model_variable_store.h"
#include "dragnn/runtime/math/types.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
tensorflow::Status TrainedModelVariableStore::Reset(
const string &saved_model_dir) {
TF_RETURN_IF_ERROR(trained_model_.Reset(saved_model_dir));
// Success; make modifications.
variables_.clear();
return tensorflow::Status::OK();
}
namespace {
// Copies flat data from the |tensor|, cast to T, into the |array| and points
// the |area| at it. On error, returns non-OK.
template <class T>
tensorflow::Status ExtractFlat(const tensorflow::Tensor &tensor,
std::vector<size_t> *dimensions,
UniqueAlignedArray *array,
MutableAlignedArea *area) {
const auto flat = tensor.flat<T>();
const size_t bytes = flat.size() * sizeof(T);
array->Reset(ComputeAlignedAreaSize(1, bytes));
TF_RETURN_IF_ERROR(area->Reset(array->view(), 1, bytes));
const MutableVector<T> row(area->view(0));
for (size_t i = 0; i < flat.size(); ++i) row[i] = flat(i);
dimensions->clear();
dimensions->push_back(flat.size());
return tensorflow::Status::OK();
}
// Copies the |tensor|, cast to T and reshaped as a matrix, into the |array|
// and points the |area| at it. Requires that the |tensor| is rank 2 or more.
// On error, returns non-OK.
template <class T>
tensorflow::Status ExtractMatrix(const tensorflow::Tensor &tensor,
std::vector<size_t> *dimensions,
UniqueAlignedArray *array,
MutableAlignedArea *area) {
if (tensor.dims() < 2) {
return tensorflow::errors::InvalidArgument(
"Tensor must be rank >= 2 but is rank ", tensor.dims());
}
// Flatten all dims except the inner-most, creating a matrix.
const auto reshaped = tensor.flat_inner_dims<T>();
const size_t num_rows = reshaped.dimension(0);
const size_t num_columns = reshaped.dimension(1);
*dimensions = {num_rows, num_columns};
const size_t view_size_bytes = num_columns * sizeof(T);
array->Reset(ComputeAlignedAreaSize(num_rows, view_size_bytes));
TF_RETURN_IF_ERROR(area->Reset(array->view(), num_rows, view_size_bytes));
MutableMatrix<T> matrix(*area);
for (size_t row = 0; row < num_rows; ++row) {
for (size_t column = 0; column < num_columns; ++column) {
matrix.row(row)[column] = reshaped(row, column);
}
}
return tensorflow::Status::OK();
}
// Copies a blocked matrix from the |tensor|, cast to T, into the |array| and
// points the |area| at it. Requires that the |tensor| is rank 3. On error,
// returns non-OK.
template <class T>
tensorflow::Status ExtractBlockedMatrix(const tensorflow::Tensor &tensor,
std::vector<size_t> *dimensions,
UniqueAlignedArray *array,
MutableAlignedArea *area) {
if (tensor.dims() != 3) {
return tensorflow::errors::InvalidArgument(
"Tensor must be rank 3 but is rank ", tensor.dims());
}
const size_t num_sub_matrices = tensor.dim_size(0);
const size_t num_rows = tensor.dim_size(1);
const size_t block_size = tensor.dim_size(2);
const size_t num_columns = num_sub_matrices * block_size;
*dimensions = {num_rows, num_columns, block_size};
// Given the order of dimensions in the |tensor|, flattening it into a matrix
// via flat_inner_dims() and copying it to the |area| is equivalent to copying
// the blocked matrix.
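// For example (shapes illustrative only): a [4, 160, 32] tensor flattens to a
// [640, 32] matrix whose row (s * 160 + r) holds the 32-wide block of original
// row r in column block s, which is exactly the column-blocked row-major
// layout; |dimensions| would be {160, 128, 32}.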
std::vector<size_t> unused_dimensions; // ignore non-blocked dimensions
return ExtractMatrix<T>(tensor, &unused_dimensions, array, area);
}
} // namespace
tensorflow::Status TrainedModelVariableStore::Lookup(
const string &name, VariableSpec::Format format,
std::vector<size_t> *dimensions, AlignedArea *area) {
const Key key(name, format);
const auto it = variables_.find(key);
if (it != variables_.end()) {
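// Return the cached copy: share its dimensions and area; the backing byte
// array (the first tuple element) remains owned by |variables_|.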
std::tie(std::ignore, *dimensions, *area) = it->second;
return tensorflow::Status::OK();
}
Variable variable;
TF_RETURN_IF_ERROR(GetVariableContents(name, format, &variable));
// Success; make modifications.
std::tie(std::ignore, *dimensions, *area) = variable;
variables_[key] = std::move(variable);
return tensorflow::Status::OK();
}
tensorflow::Status TrainedModelVariableStore::GetVariableContents(
const string &name, VariableSpec::Format format, Variable *variable) {
tensorflow::Tensor tensor;
TF_RETURN_IF_ERROR(trained_model_.EvaluateTensor(name, &tensor));
// Extract typed tensor data.
UniqueAlignedArray *array = &std::get<0>(*variable);
std::vector<size_t> *dimensions = &std::get<1>(*variable);
MutableAlignedArea *area = &std::get<2>(*variable);
if (tensor.dtype() == tensorflow::DT_FLOAT) {
switch (format) {
case VariableSpec::FORMAT_UNKNOWN:
return tensorflow::errors::InvalidArgument("Unknown variable format");
case VariableSpec::FORMAT_FLAT:
return ExtractFlat<float>(tensor, dimensions, array, area);
case VariableSpec::FORMAT_ROW_MAJOR_MATRIX:
return ExtractMatrix<float>(tensor, dimensions, array, area);
case VariableSpec::FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX:
return ExtractBlockedMatrix<float>(tensor, dimensions, array, area);
}
} else if (tensor.dtype() == tensorflow::DT_BFLOAT16) {
switch (format) {
case VariableSpec::FORMAT_UNKNOWN:
return tensorflow::errors::InvalidArgument("Unknown variable format");
case VariableSpec::FORMAT_FLAT:
return ExtractFlat<tensorflow::bfloat16>(tensor, dimensions, array,
area);
case VariableSpec::FORMAT_ROW_MAJOR_MATRIX:
return ExtractMatrix<tensorflow::bfloat16>(tensor, dimensions, array,
area);
case VariableSpec::FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX:
return ExtractBlockedMatrix<tensorflow::bfloat16>(tensor, dimensions,
array, area);
}
} else {
// TODO(googleuser): Add clauses for additional types as needed.
return tensorflow::errors::Unimplemented(
"Data type not supported: ", tensorflow::DataType_Name(tensor.dtype()));
}
// Unreachable for valid |format| values; returning here keeps the non-void
// function from falling off the end if an unexpected enum value appears.
return tensorflow::errors::InvalidArgument("Unknown variable format");
}
tensorflow::Status TrainedModelVariableStore::Close() {
return trained_model_.Close();
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_TRAINED_MODEL_VARIABLE_STORE_H_
#define DRAGNN_RUNTIME_TRAINED_MODEL_VARIABLE_STORE_H_
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "dragnn/protos/runtime.pb.h"
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/trained_model.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// A variable store that extracts variables from a trained DRAGNN model. This
// should not be used in production (where ArrayVariableStore and its subclasses
// should be used), though it is convenient for experimentation.
class TrainedModelVariableStore : public VariableStore {
public:
// Creates an uninitialized store.
TrainedModelVariableStore() = default;
// Resets this to represent the variables defined by the TF saved model at the
// |saved_model_dir|. On error, returns non-OK and modifies nothing.
tensorflow::Status Reset(const string &saved_model_dir);
// Implements VariableStore.
using VariableStore::Lookup; // import Lookup<T>() convenience methods
tensorflow::Status Lookup(const string &name, VariableSpec::Format format,
std::vector<size_t> *dimensions,
AlignedArea *area) override;
tensorflow::Status Close() override;
private:
// A (name,format) key associated with a variable.
using Key = std::pair<string, VariableSpec::Format>;
// Extracted and formatted variable contents: an aligned byte array, the
// logical dimensions of the variable, and an area that provides a structured
// interpretation of the array.
using Variable =
std::tuple<UniqueAlignedArray, std::vector<size_t>, MutableAlignedArea>;
// Extracts the contents of the variable named |name| in the |format| and
// stores the result in the |variable|. On error, returns non-OK.
tensorflow::Status GetVariableContents(const string &name,
VariableSpec::Format format,
Variable *variable);
// Trained DRAGNN model used to extract variables.
TrainedModel trained_model_;
// The already-extracted variables.
std::map<Key, Variable> variables_;
};
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_TRAINED_MODEL_VARIABLE_STORE_H_
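// The following is a hypothetical usage sketch of TrainedModelVariableStore;
// the saved-model directory and variable name are placeholders. Extracted
// areas remain valid after Close(), as exercised by the tests below.
#include "dragnn/runtime/trained_model_variable_store.h"
#include <stddef.h>
#include <vector>
#include "dragnn/protos/runtime.pb.h"
#include "dragnn/runtime/alignment.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
tensorflow::Status LogMatrixShape(const string &saved_model_dir,
                                  const string &variable_name) {
  TrainedModelVariableStore store;
  TF_RETURN_IF_ERROR(store.Reset(saved_model_dir));
  std::vector<size_t> dimensions;
  AlignedArea area;
  TF_RETURN_IF_ERROR(store.Lookup(variable_name,
                                  VariableSpec::FORMAT_ROW_MAJOR_MATRIX,
                                  &dimensions, &area));
  // Lookups are cached, so repeating the same (name, format) query reuses the
  // already-extracted contents.
  LOG(INFO) << dimensions[0] << " rows x " << dimensions[1] << " columns in "
            << area.num_views() << " aligned views";
  return store.Close();
}
}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet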
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/trained_model_variable_store.h"
#include <stddef.h>
#include <string>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/runtime.pb.h"
#include "dragnn/runtime/math/avx_vector_array.h"
#include "dragnn/runtime/math/float16_types.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
class TrainedModelVariableStoreTest : public ::testing::Test {
protected:
// Computes a value that accesses all bytes in the |view| or |area|. Useful
// for checking that a piece of memory is accessible.
size_t SumBytes(AlignedView view) {
size_t sum = 0;
for (size_t i = 0; i < view.size(); ++i) sum += view.data()[i];
return sum;
}
size_t SumBytes(AlignedArea area) {
size_t sum = 0;
for (size_t i = 0; i < area.num_views(); ++i) sum += SumBytes(area.view(i));
return sum;
}
// Returns the name of a tensor containing the blocked version of
// |kVariableName|, with the given |block_size|.
string GetBlockedVariableName(int block_size) const {
return tensorflow::strings::StrCat(kVariableNamePrefix, "/matrix/blocked",
block_size, "/ExponentialMovingAverage");
}
// Same as above, but returns the name of the bfloat16 variable.
string GetBfloat16VariableName(int block_size) const {
return tensorflow::strings::StrCat(kVariableNamePrefix, "/matrix/blocked",
block_size,
"/bfloat16/ExponentialMovingAverage");
}
// Path to a saved model file for tests. Expected to contain:
// * A tf.float32 variable named |kVariableName| with shape
// [|kVariableRows|, |kVariableColumns|].
// * A variable named |kUnsupportedTypeVariableName| whose type is not
// supported by the implementation.
// * A variable named |kLowRankVariableName| whose rank is < 2.
const string kSavedModelDir = tensorflow::io::JoinPath(
test::GetTestDataPrefix(), "dragnn/runtime/testdata/rnn_tagger");
// A valid variable name in the test model and its dimensions.
const string kVariableNamePrefix = "tagger/weights_0";
const string kVariableName = tensorflow::strings::StrCat(
kVariableNamePrefix, "/ExponentialMovingAverage");
const size_t kVariableRows = 160;
const size_t kVariableColumns = 64;
// A variable with unsupported type; this variable is tf.int32.
const string kUnsupportedTypeVariableName = "tagger/step";
// A variable whose rank is < 2; this is a scalar.
const string kLowRankVariableName = "tagger/bias_1";
// Variable store for tests.
TrainedModelVariableStore store_;
};
// Tests that TrainedModelVariableStore can be initialized from a valid model.
TEST_F(TrainedModelVariableStoreTest, ResetValid) {
TF_EXPECT_OK(store_.Reset(kSavedModelDir));
}
// Tests that TrainedModelVariableStore fails on a valid directory that doesn't
// actually contain a TF saved model, but can be re-Reset() on valid files.
TEST_F(TrainedModelVariableStoreTest, ResetInvalidDirectoryThenValid) {
EXPECT_FALSE(store_.Reset("/tmp").ok());
TF_EXPECT_OK(store_.Reset(kSavedModelDir));
}
// Tests that TrainedModelVariableStore fails on a non-directory, but can be
// re-Reset() on valid files.
TEST_F(TrainedModelVariableStoreTest, ResetNotADirectoryThenValid) {
EXPECT_FALSE(store_.Reset("/dev/null").ok());
TF_EXPECT_OK(store_.Reset(kSavedModelDir));
}
// Tests that TrainedModelVariableStore fails on a missing directory, but can
// be re-Reset() on valid files.
TEST_F(TrainedModelVariableStoreTest, ResetMissingDirectoryThenValid) {
EXPECT_FALSE(store_.Reset("/missing/model/dir").ok());
TF_EXPECT_OK(store_.Reset(kSavedModelDir));
}
// Tests that TrainedModelVariableStore can only be closed once, and only after
// it has been initialized.
TEST_F(TrainedModelVariableStoreTest, Close) {
EXPECT_THAT(store_.Close(),
test::IsErrorWithSubstr("TF Session is not active"));
TF_ASSERT_OK(store_.Reset(kSavedModelDir));
TF_EXPECT_OK(store_.Close());
EXPECT_THAT(store_.Close(),
test::IsErrorWithSubstr("TF Session is not active"));
}
// Tests that TrainedModelVariableStore can look up flat variables.
TEST_F(TrainedModelVariableStoreTest, LookupFlat) {
AlignedArea area;
std::vector<size_t> dimensions;
// Fail to look up a valid name before initialization.
EXPECT_THAT(store_.Lookup(kVariableName, VariableSpec::FORMAT_FLAT,
&dimensions, &area),
test::IsErrorWithSubstr("TF Session is not active"));
EXPECT_TRUE(area.empty()); // not modified
// Repeating the failed lookup should still fail.
EXPECT_THAT(store_.Lookup(kVariableName, VariableSpec::FORMAT_FLAT,
&dimensions, &area),
test::IsErrorWithSubstr("TF Session is not active"));
EXPECT_TRUE(area.empty()); // not modified
// Fail to look up an invalid name after initialization.
TF_ASSERT_OK(store_.Reset(kSavedModelDir));
EXPECT_FALSE(
store_
.Lookup("invalid/name", VariableSpec::FORMAT_FLAT, &dimensions, &area)
.ok());
EXPECT_TRUE(area.empty()); // not modified
// Successfully look up a valid name.
TF_ASSERT_OK(store_.Lookup(kVariableName, VariableSpec::FORMAT_FLAT,
&dimensions, &area));
EXPECT_FALSE(area.empty()); // modified
EXPECT_EQ(area.num_views(), 1);
EXPECT_EQ(area.view_size(), kVariableRows * kVariableColumns * sizeof(float));
// Try looking up the same name again.
area = AlignedArea();
TF_ASSERT_OK(store_.Lookup(kVariableName, VariableSpec::FORMAT_FLAT,
&dimensions, &area));
EXPECT_EQ(area.num_views(), 1);
EXPECT_EQ(area.view_size(), kVariableRows * kVariableColumns * sizeof(float));
// Check that the area can be accessed even after the |store| is closed.
TF_EXPECT_OK(store_.Close());
LOG(INFO) << "Logging to prevent elision by optimizer: " << SumBytes(area);
}
// Tests that TrainedModelVariableStore can look up row-major matrix variables.
TEST_F(TrainedModelVariableStoreTest, LookupRowMajorMatrix) {
AlignedArea area;
std::vector<size_t> dimensions;
// Fail to look up a valid name before initialization.
EXPECT_THAT(
store_.Lookup(kVariableName, VariableSpec::FORMAT_ROW_MAJOR_MATRIX,
&dimensions, &area),
test::IsErrorWithSubstr("TF Session is not active"));
EXPECT_TRUE(area.empty()); // not modified
// Repeating the failed lookup should still fail.
EXPECT_THAT(
store_.Lookup(kVariableName, VariableSpec::FORMAT_ROW_MAJOR_MATRIX,
&dimensions, &area),
test::IsErrorWithSubstr("TF Session is not active"));
EXPECT_TRUE(area.empty()); // not modified
// Fail to look up an invalid name after initialization.
TF_ASSERT_OK(store_.Reset(kSavedModelDir));
EXPECT_FALSE(store_
.Lookup("invalid/name",
VariableSpec::FORMAT_ROW_MAJOR_MATRIX, &dimensions,
&area)
.ok());
EXPECT_TRUE(area.empty()); // not modified
// Successfully look up a valid name.
TF_ASSERT_OK(store_.Lookup(kVariableName,
VariableSpec::FORMAT_ROW_MAJOR_MATRIX, &dimensions,
&area));
ASSERT_FALSE(area.empty()); // modified
EXPECT_EQ(dimensions, std::vector<size_t>({kVariableRows, kVariableColumns}));
EXPECT_EQ(area.num_views(), kVariableRows);
EXPECT_EQ(area.view_size(), kVariableColumns * sizeof(float));
// Try looking up the same name again.
area = AlignedArea();
dimensions.clear();
TF_ASSERT_OK(store_.Lookup(kVariableName,
VariableSpec::FORMAT_ROW_MAJOR_MATRIX, &dimensions,
&area));
EXPECT_EQ(dimensions, std::vector<size_t>({kVariableRows, kVariableColumns}));
EXPECT_EQ(area.num_views(), kVariableRows);
EXPECT_EQ(area.view_size(), kVariableColumns * sizeof(float));
// Check that the area can be accessed even after the |store| is closed.
TF_EXPECT_OK(store_.Close());
LOG(INFO) << "Logging to prevent elision by optimizer: " << SumBytes(area);
}
// Tests that the same contents can be retrieved in various formats, and that
// the content is the same aside from rearrangement.
TEST_F(TrainedModelVariableStoreTest, CompareFormats) {
Vector<float> flat;
Matrix<float> row_major_matrix;
TF_ASSERT_OK(store_.Reset(kSavedModelDir));
TF_ASSERT_OK(store_.Lookup(kVariableName, &flat));
TF_ASSERT_OK(store_.Lookup(kVariableName, &row_major_matrix));
ASSERT_EQ(flat.size(),
row_major_matrix.num_rows() * row_major_matrix.num_columns());
for (size_t flat_index = 0, row = 0; row < row_major_matrix.num_rows();
++row) {
for (size_t column = 0; column < row_major_matrix.num_columns();
++column, ++flat_index) {
EXPECT_EQ(row_major_matrix.row(row)[column], flat[flat_index]);
}
}
}
// Tests that TrainedModelVariableStore fails to retrieve a variable of an
// unsupported type.
TEST_F(TrainedModelVariableStoreTest, LookupUnsupportedType) {
AlignedArea area;
std::vector<size_t> dimensions;
TF_ASSERT_OK(store_.Reset(kSavedModelDir));
EXPECT_THAT(store_.Lookup(kUnsupportedTypeVariableName,
VariableSpec::FORMAT_FLAT, &dimensions, &area),
test::IsErrorWithSubstr("Data type not supported"));
}
// Tests that TrainedModelVariableStore fails to retrieve a variable when the
// requested format is unknown.
TEST_F(TrainedModelVariableStoreTest, LookupUnknownFormat) {
AlignedArea area;
std::vector<size_t> dimensions;
TF_ASSERT_OK(store_.Reset(kSavedModelDir));
EXPECT_THAT(store_.Lookup(kVariableName, VariableSpec::FORMAT_UNKNOWN,
&dimensions, &area),
test::IsErrorWithSubstr("Unknown variable format"));
}
// Tests that TrainedModelVariableStore fails to look up a variable without
// sufficient structure as a matrix.
TEST_F(TrainedModelVariableStoreTest, LookupInsufficientRank) {
AlignedArea area;
std::vector<size_t> dimensions;
TF_ASSERT_OK(store_.Reset(kSavedModelDir));
EXPECT_THAT(
store_.Lookup(kLowRankVariableName, VariableSpec::FORMAT_ROW_MAJOR_MATRIX,
&dimensions, &area),
test::IsErrorWithSubstr("Tensor must be rank >= 2"));
}
// Tests that TrainedModelVariableStore produces column-blocked row-major
// matrices with the same content as the non-blocked version. Checks that
// bfloat16 matrices are a permuted version of blocked matrices.
TEST_F(TrainedModelVariableStoreTest, ColumnBlockedComparison) {
const int kBlockSize = 32;
const string kBlockedVariableName = GetBlockedVariableName(kBlockSize);
const string kBfloat16VariableName = GetBfloat16VariableName(kBlockSize);
Matrix<float> plain_matrix;
BlockedMatrix<float> matrix;
BlockedMatrix<TruncatedFloat16> bfloat16_matrix;
TF_ASSERT_OK(store_.Reset(kSavedModelDir));
TF_ASSERT_OK(store_.Lookup(kVariableName, &plain_matrix));
TF_ASSERT_OK(store_.Lookup(kBlockedVariableName, &matrix));
TF_ASSERT_OK(store_.Lookup(kBfloat16VariableName, &bfloat16_matrix));
ASSERT_EQ(matrix.num_rows(), kVariableRows);
ASSERT_EQ(matrix.num_columns(), kVariableColumns);
ASSERT_EQ(matrix.block_size(), kBlockSize);
// Compare the content of the plain matrix with the blocked version.
for (int column = 0; column < matrix.num_columns(); ++column) {
const int column_block_index = column / kBlockSize;
const int index_in_block = column % kBlockSize;
for (int row = 0; row < matrix.num_rows(); ++row) {
const int block_index = column_block_index * matrix.num_rows() + row;
Vector<float> block = matrix.vector(block_index);
EXPECT_EQ(block[index_in_block], plain_matrix.row(row)[column]);
}
}
// Compare bfloat16-encoded values with float32 values.
ASSERT_EQ(matrix.num_vectors(), bfloat16_matrix.num_vectors());
ASSERT_EQ(matrix.block_size(), bfloat16_matrix.block_size());
ASSERT_EQ(matrix.num_rows(), bfloat16_matrix.num_rows());
ASSERT_EQ(matrix.num_columns(), bfloat16_matrix.num_columns());
for (int vector = 0; vector < matrix.num_vectors(); ++vector) {
const auto &matrix_vector = matrix.vector(vector);
const auto &bfloat16_vector = bfloat16_matrix.vector(vector);
for (int i = 0; i < matrix.block_size(); ++i) {
int permuted = FastUnpackPermutation(i);
const float matrix_value = matrix_vector[i];
const float bfloat16_value = bfloat16_vector[permuted].DebugToFloat();
EXPECT_NEAR(matrix_value, bfloat16_value, 5e-3);
}
}
}
// Tests that TrainedModelVariableStore overwrites the dimension vector passed
// to Lookup().
TEST_F(TrainedModelVariableStoreTest, OverwritesDimensions) {
const int kBlockSize = 32;
const string kBlockedVariableName = GetBlockedVariableName(kBlockSize);
TF_ASSERT_OK(store_.Reset(kSavedModelDir));
std::vector<VariableSpec::Format> formats{
VariableSpec::FORMAT_FLAT, VariableSpec::FORMAT_ROW_MAJOR_MATRIX,
VariableSpec::FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX};
for (const auto &format : formats) {
std::vector<size_t> dimensions;
dimensions.push_back(1234);
AlignedArea area;
TF_ASSERT_OK(
store_.Lookup(kBlockedVariableName, format, &dimensions, &area));
EXPECT_NE(dimensions[0], 1234);
std::vector<size_t> expected_dimensions;
switch (format) {
case VariableSpec::FORMAT_UNKNOWN:
LOG(FATAL) << "Invalid format";
case VariableSpec::FORMAT_FLAT:
expected_dimensions = {kVariableRows * kVariableColumns};
break;
case VariableSpec::FORMAT_ROW_MAJOR_MATRIX:
// NB: We're fetching the rank-3 "/matrix/blockedNN" version and then
// reshaping into a matrix, so the dimensions are not the same as the
// plain matrix.
expected_dimensions = {kVariableRows * kVariableColumns / kBlockSize,
kBlockSize};
break;
case VariableSpec::FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX:
expected_dimensions = {kVariableRows, kVariableColumns, kBlockSize};
break;
}
EXPECT_EQ(dimensions, expected_dimensions);
}
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/transition_system_traits.h"
#include <string>
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/strings/str_util.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Note: The traits are currently simple enough to specify in one file. We can
// also use a registry-based system if this gets too complex.
// Returns true if the |component_spec| is deterministic.
bool IsDeterministic(const ComponentSpec &component_spec) {
return component_spec.num_actions() == 1;
}
// Returns true if the |component_spec| is sequential.
bool IsSequential(const ComponentSpec &component_spec) {
const string &name = component_spec.transition_system().registered_name();
return name == "char-shift-only" || //
name == "shift-only" || //
name == "tagger" || //
name == "morpher" || //
name == "heads" || //
name == "labels";
}
// Returns true if the |component_spec| specifies a left-to-right transition
// system. The default when unspecified is true.
bool IsLeftToRight(const ComponentSpec &component_spec) {
const auto &parameters = component_spec.transition_system().parameters();
const auto it = parameters.find("left_to_right");
if (it == parameters.end()) return true;
return tensorflow::str_util::Lowercase(it->second) != "false";
}
// Returns true if the |transition_system| is character-scale.
bool IsCharacterScale(const ComponentSpec &component_spec) {
const string &name = component_spec.transition_system().registered_name();
return //
name == "char-shift-only";
}
// Returns true if the |transition_system| is token-scale.
bool IsTokenScale(const ComponentSpec &component_spec) {
const string &name = component_spec.transition_system().registered_name();
return name == "shift-only" || //
name == "tagger" || //
name == "morpher" || //
name == "heads" || //
name == "labels";
}
} // namespace
TransitionSystemTraits::TransitionSystemTraits(
const ComponentSpec &component_spec)
: is_deterministic(IsDeterministic(component_spec)),
is_sequential(IsSequential(component_spec)),
is_left_to_right(IsLeftToRight(component_spec)),
is_character_scale(IsCharacterScale(component_spec)),
is_token_scale(IsTokenScale(component_spec)) {}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_TRANSITION_SYSTEM_TRAITS_H_
#define DRAGNN_RUNTIME_TRANSITION_SYSTEM_TRAITS_H_
#include "dragnn/protos/spec.pb.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Traits describing the transition system used by some component.
struct TransitionSystemTraits {
// Creates a set of traits describing the |component_spec|.
explicit TransitionSystemTraits(const ComponentSpec &component_spec);
// Whether the transition system is deterministic---i.e., it can be advanced
// without computing logits and making predictions.
const bool is_deterministic;
// Whether the transition system is sequential---i.e., compatible with
// SequenceBackend, SequenceExtractor, and so on.
const bool is_sequential;
// Whether the transition system advances from left to right in the underlying
// input sequence. This only makes sense if |sequential| is true.
const bool is_left_to_right;
// Whether the transition steps correspond to characters or tokens. This only
// makes sense if |sequential| is true.
//
// TODO(googleuser): Distinguish between full-text character transition systems
// and per-word ones?
const bool is_character_scale;
const bool is_token_scale;
};
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_TRANSITION_SYSTEM_TRAITS_H_
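// The following is a hypothetical example of querying TransitionSystemTraits;
// the spec is built inline rather than taken from this change, and the
// expected values follow the rules in transition_system_traits.cc.
#include "dragnn/runtime/transition_system_traits.h"
#include "dragnn/protos/spec.pb.h"
#include "tensorflow/core/platform/logging.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
void DescribeShiftOnlyComponent() {
  ComponentSpec component_spec;
  component_spec.set_num_actions(1);  // a single action => deterministic
  component_spec.mutable_transition_system()->set_registered_name("shift-only");
  const TransitionSystemTraits traits(component_spec);
  // "shift-only" is sequential and token-scale, and "left_to_right" defaults
  // to true when the parameter is absent.
  LOG(INFO) << "deterministic=" << traits.is_deterministic
            << " sequential=" << traits.is_sequential
            << " left_to_right=" << traits.is_left_to_right
            << " token_scale=" << traits.is_token_scale;
}
}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet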
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/transition_system_traits.h"
#include <string>
#include <utility>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Returns a ComponentSpec that uses the |transition_system|, is configured to
// run left-to-right if |left_to_right| is true, and whose transition system
// predicts |num_actions| actions.
ComponentSpec MakeTestSpec(const string &transition_system, bool left_to_right,
int num_actions) {
ComponentSpec component_spec;
component_spec.set_num_actions(num_actions);
component_spec.mutable_transition_system()->set_registered_name(
transition_system);
component_spec.mutable_transition_system()->mutable_parameters()->insert(
{"left_to_right", left_to_right ? "true" : "false"});
return component_spec;
}
// Tests that boolean values are case-insensitive.
TEST(TransitionSystemTraitsAttributeParsingTest, CaseInsensitiveBooleanValues) {
ComponentSpec component_spec = MakeTestSpec("shift-only", false, 1);
auto &parameters =
*component_spec.mutable_transition_system()->mutable_parameters();
for (const string &true_value : {"TRUE", "True"}) {
parameters["left_to_right"] = true_value;
TransitionSystemTraits traits(component_spec);
EXPECT_TRUE(traits.is_left_to_right);
}
for (const string &false_value : {"FALSE", "False"}) {
parameters["left_to_right"] = false_value;
TransitionSystemTraits traits(component_spec);
EXPECT_FALSE(traits.is_left_to_right);
}
}
// Parameterized on (left-to-right, deterministic).
class TransitionSystemTraitsTest
: public ::testing::TestWithParam<::testing::tuple<bool, bool>> {
protected:
// Returns the test parameters.
bool left_to_right() const { return ::testing::get<0>(GetParam()); }
bool deterministic() const { return ::testing::get<1>(GetParam()); }
// Returns a ComponentSpec for the |transition_system|.
ComponentSpec MakeSpec(const string &transition_system) {
return MakeTestSpec(transition_system, left_to_right(),
deterministic() ? 1 : 10);
}
};
INSTANTIATE_TEST_CASE_P(LeftToRightAndDeterministic, TransitionSystemTraitsTest,
::testing::Combine(::testing::Bool(),
::testing::Bool()));
// Tests the traits of an unknown transition system.
TEST_P(TransitionSystemTraitsTest, Unknown) {
TransitionSystemTraits traits(MakeSpec("unknown"));
EXPECT_EQ(traits.is_deterministic, deterministic());
EXPECT_FALSE(traits.is_sequential);
EXPECT_EQ(traits.is_left_to_right, left_to_right());
EXPECT_FALSE(traits.is_character_scale);
EXPECT_FALSE(traits.is_token_scale);
}
// Tests the traits of the "char-shift-only" transition system.
TEST_P(TransitionSystemTraitsTest, CharShiftOnly) {
TransitionSystemTraits traits(MakeSpec("char-shift-only"));
EXPECT_EQ(traits.is_deterministic, deterministic());
EXPECT_TRUE(traits.is_sequential);
EXPECT_EQ(traits.is_left_to_right, left_to_right());
EXPECT_TRUE(traits.is_character_scale);
EXPECT_FALSE(traits.is_token_scale);
}
// Tests the traits of the "shift-only" transition system.
TEST_P(TransitionSystemTraitsTest, ShiftOnly) {
TransitionSystemTraits traits(MakeSpec("shift-only"));
EXPECT_EQ(traits.is_deterministic, deterministic());
EXPECT_TRUE(traits.is_sequential);
EXPECT_EQ(traits.is_left_to_right, left_to_right());
EXPECT_FALSE(traits.is_character_scale);
EXPECT_TRUE(traits.is_token_scale);
}
// Tests the traits of the "tagger" transition system.
TEST_P(TransitionSystemTraitsTest, Tagger) {
TransitionSystemTraits traits(MakeSpec("tagger"));
EXPECT_EQ(traits.is_deterministic, deterministic());
EXPECT_TRUE(traits.is_sequential);
EXPECT_EQ(traits.is_left_to_right, left_to_right());
EXPECT_FALSE(traits.is_character_scale);
EXPECT_TRUE(traits.is_token_scale);
}
// Tests the traits of the "morpher" transition system.
TEST_P(TransitionSystemTraitsTest, Morpher) {
TransitionSystemTraits traits(MakeSpec("morpher"));
EXPECT_EQ(traits.is_deterministic, deterministic());
EXPECT_TRUE(traits.is_sequential);
EXPECT_EQ(traits.is_left_to_right, left_to_right());
EXPECT_FALSE(traits.is_character_scale);
EXPECT_TRUE(traits.is_token_scale);
}
// Tests the traits of the "heads" transition system.
TEST_P(TransitionSystemTraitsTest, Heads) {
TransitionSystemTraits traits(MakeSpec("heads"));
EXPECT_EQ(traits.is_deterministic, deterministic());
EXPECT_TRUE(traits.is_sequential);
EXPECT_EQ(traits.is_left_to_right, left_to_right());
EXPECT_FALSE(traits.is_character_scale);
EXPECT_TRUE(traits.is_token_scale);
}
// Tests the traits of the "labels" transition system.
TEST_P(TransitionSystemTraitsTest, Labels) {
TransitionSystemTraits traits(MakeSpec("labels"));
EXPECT_EQ(traits.is_deterministic, deterministic());
EXPECT_TRUE(traits.is_sequential);
EXPECT_EQ(traits.is_left_to_right, left_to_right());
EXPECT_FALSE(traits.is_character_scale);
EXPECT_TRUE(traits.is_token_scale);
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_TYPE_KEYED_SET_H_
#define DRAGNN_RUNTIME_TYPE_KEYED_SET_H_
#include <map>
#include <utility>
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// A heterogeneous set of type-keyed objects. Objects of any type can be added,
// but this can only hold at most one object of each type.
//
// Note that this class does not have any locking, so threads must externally
// coordinate to ensure that every instance of this set is only accessed by one
// thread at a time. When used via SessionState, these conditions are enforced
// by the runtime framework.
class TypeKeyedSet {
public:
// Creates an empty set.
TypeKeyedSet() = default;
// Moves all objects from |that| to this. Afterwards, the objects in this are
// address-equal to the objects originally in |that|.
TypeKeyedSet(TypeKeyedSet &&that);
TypeKeyedSet &operator=(TypeKeyedSet &&that);
~TypeKeyedSet() { Clear(); }
// Removes all objects from this set.
void Clear();
// Returns the T in this set, creating it first via T() if needed.
template <class T>
T &Get();
private:
// Function that can delete an untyped pointer using the proper type.
using Deleter = void (*)(void *);
// Deletes the |object| as a T. All Deleters point to this function.
template <class T>
static void DeleteAs(void *object);
// Mapping from deleter to object. This owns the objects.
std::map<Deleter, void *> objects_;
};
// Implementation details below.
inline TypeKeyedSet::TypeKeyedSet(TypeKeyedSet &&that)
: objects_(std::move(that.objects_)) {
that.objects_.clear();
}
inline TypeKeyedSet &TypeKeyedSet::operator=(TypeKeyedSet &&that) {
Clear();
objects_ = std::move(that.objects_);
that.objects_.clear();
return *this;
}
inline void TypeKeyedSet::Clear() {
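// Invoke each type's deleter (the map key) on its object (the map value).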
for (const auto &it : objects_) it.first(it.second);
objects_.clear();
}
template <class T>
T &TypeKeyedSet::Get() {
// Implementation notes:
// * DeleteAs<T>() is unique per T, so keying on its instantiation is
// equivalent to keying on type, as desired.
// * The |object| pointer below is doubly-indirect: it is a reference to a
// void* pointer that lives in the |objects_| map.
// * If there was previously no entry in |objects_|, then |object| will be
// value-initialized (i.e., nulled), and we reassign it to a new T().
void *&object = objects_[&DeleteAs<T>];
if (object == nullptr) object = new T();
return *reinterpret_cast<T *>(object);
}
template <class T>
void TypeKeyedSet::DeleteAs(void *object) {
delete reinterpret_cast<T *>(object);
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_TYPE_KEYED_SET_H_
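// A minimal hypothetical sketch of TypeKeyedSet usage: each distinct type gets
// one lazily default-constructed slot, and the set owns and deletes the
// objects when cleared or destroyed.
#include "dragnn/runtime/type_keyed_set.h"
#include <string>
#include <vector>
namespace syntaxnet {
namespace dragnn {
namespace runtime {
void TypeKeyedSetSketch() {
  TypeKeyedSet set;
  std::vector<int> &ints = set.Get<std::vector<int>>();  // default-constructed
  ints.push_back(42);
  set.Get<std::string>() = "scratch";
  // A second Get() with the same type returns the same object.
  set.Get<std::vector<int>>().push_back(43);  // |ints| now holds {42, 43}
  // |set| deletes the vector and string when it goes out of scope.
}
}  // namespace runtime
}  // namespace dragnn
}  // namespace syntaxnet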
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/type_keyed_set.h"
#include <utility>
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Dummy struct for tests.
struct Foo {
float value = -1.5;
};
// Type aliases to exercise usage of aliases as type keys.
using OtherInt = int;
using OtherFoo = Foo;
// Tests that TypeKeyedSet::Get() returns the same object once created.
TEST(TypeKeyedSetTest, Get) {
TypeKeyedSet set;
// Get a couple types, and check for default-constructed values.
int &int_object = set.Get<int>();
ASSERT_NE(&int_object, nullptr);
EXPECT_EQ(int_object, 0); // due to T()
int_object = 2718;
Foo &foo_object = set.Get<Foo>();
ASSERT_NE(&foo_object, nullptr);
EXPECT_EQ(foo_object.value, -1.5); // due to T()
foo_object.value = 3141.5;
// Get the same types again, this time using type aliases, and check for
// address and value equality.
OtherInt &other_int_object = set.Get<OtherInt>();
EXPECT_EQ(&other_int_object, &int_object);
EXPECT_EQ(other_int_object, 2718);
OtherFoo &other_foo_object = set.Get<OtherFoo>();
EXPECT_EQ(&other_foo_object, &foo_object);
EXPECT_EQ(other_foo_object.value, 3141.5);
}
// Tests that TypeKeyedSet::Clear() removes existing values.
TEST(TypeKeyedSetTest, Clear) {
// Create a set with some values.
TypeKeyedSet set;
int &int_object = set.Get<int>();
int_object = 2718;
Foo &foo_object = set.Get<Foo>();
foo_object.value = 3141.5;
// Clear the set and check that the values are now defaulted.
set.Clear();
EXPECT_EQ(set.Get<int>(), 0);
EXPECT_EQ(set.Get<Foo>().value, -1.5);
}
// Tests that TypeKeyedSet supports move construction.
TEST(TypeKeyedSetTest, MoveConstruction) {
TypeKeyedSet set1;
// Insert a couple of values.
int &int_object = set1.Get<int>();
int_object = 2718;
Foo &foo_object = set1.Get<Foo>();
foo_object.value = 3141.5;
// Move-construct another set, and check address and value equality.
TypeKeyedSet set2(std::move(set1));
OtherInt &other_int_object = set2.Get<OtherInt>();
EXPECT_EQ(&other_int_object, &int_object);
EXPECT_EQ(other_int_object, 2718);
OtherFoo &other_foo_object = set2.Get<OtherFoo>();
EXPECT_EQ(&other_foo_object, &foo_object);
EXPECT_EQ(other_foo_object.value, 3141.5);
}
// Tests that TypeKeyedSet supports move assignment.
TEST(TypeKeyedSetTest, MoveAssignment) {
// Create one set with some values.
TypeKeyedSet set1;
int &int_object = set1.Get<int>();
int_object = 2718;
Foo &foo_object = set1.Get<Foo>();
foo_object.value = 3141.5;
// Create another set with different values, to be overwritten.
TypeKeyedSet set2;
set2.Get<int>() = 123;
set2.Get<Foo>().value = 76.5;
// Move-assign to another set, and check address and value equality.
set2 = std::move(set1);
OtherInt &other_int_object = set2.Get<OtherInt>();
EXPECT_EQ(&other_int_object, &int_object);
EXPECT_EQ(other_int_object, 2718);
OtherFoo &other_foo_object = set2.Get<OtherFoo>();
EXPECT_EQ(&other_foo_object, &foo_object);
EXPECT_EQ(other_foo_object.value, 3141.5);
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/unicode_dictionary.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Returns a string representation of the byte sequence of the |character|.
string CharacterDebugString(const string &character) {
const auto *bytes = reinterpret_cast<const uint8 *>(character.data());
string debug = "[";
for (int i = 0; i < character.size(); ++i) {
tensorflow::strings::StrAppend(&debug, i == 0 ? "" : " ", bytes[i]);
}
tensorflow::strings::StrAppend(&debug, "]");
return debug;
}
} // namespace
UnicodeDictionary::UnicodeDictionary() { Clear(); }
UnicodeDictionary::UnicodeDictionary(const string &character_map_path,
int min_frequency, int max_num_terms) {
TF_CHECK_OK(Reset(
TermFrequencyMap(character_map_path, min_frequency, max_num_terms)));
}
void UnicodeDictionary::Clear() {
size_ = 0;
for (int32 &index : single_byte_indices_) index = -1;
multi_byte_indices_.clear();
}
tensorflow::Status UnicodeDictionary::Reset(
const TermFrequencyMap &character_map) {
Clear();
size_ = character_map.Size();
for (int32 index = 0; index < character_map.Size(); ++index) {
const string &character = character_map.GetTerm(index);
if (character.empty()) {
return tensorflow::errors::InvalidArgument("Term ", index, " is empty");
}
const size_t correct_size = UniLib::OneCharLen(character.data());
if (character.size() != correct_size) {
return tensorflow::errors::InvalidArgument(
"Term ", index, " should have size ", correct_size, ": ",
CharacterDebugString(character));
}
if (!UniLib::IsUTF8ValidCodepoint(character)) {
return tensorflow::errors::InvalidArgument(
"Term ", index,
" is not valid UTF-8: ", CharacterDebugString(character));
}
const auto *bytes = reinterpret_cast<const uint8 *>(character.data());
if (character.size() == 1) {
DCHECK_EQ(single_byte_indices_[*bytes], -1);
single_byte_indices_[*bytes] = index;
} else {
const uint32 key = MultiByteKey(bytes, character.size());
DCHECK(multi_byte_indices_.find(key) == multi_byte_indices_.end());
multi_byte_indices_[key] = index;
}
}
return tensorflow::Status::OK();
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_UNICODE_DICTIONARY_H_
#define DRAGNN_RUNTIME_UNICODE_DICTIONARY_H_
#include <stddef.h>
#include <string>
#include <unordered_map>
#include "syntaxnet/base.h"
#include "syntaxnet/term_frequency_map.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "util/utf8/unilib.h"
#include "util/utf8/unilib_utf8_utils.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// A mapping from Unicode characters to indices.
//
// TODO(googleuser): Try integrating break chars into this mapping, maybe just for
// the ASCII break chars. They could be mapped directly to the break ID, so all
// one-byte characters are handled directly.
class UnicodeDictionary {
public:
// Creates an empty mapping.
UnicodeDictionary();
// Loads a TermFrequencyMap from the |character_map_path| while applying the
// |min_frequency| and |max_num_terms| thresholds, and Reset()s this from it.
// Dies on error. This is for use in SharedStore; otherwise, prefer the
// default constructor followed by Reset().
UnicodeDictionary(const string &character_map_path, int min_frequency,
int max_num_terms);
// Resets this to the |character_map|. On error, returns non-OK.
tensorflow::Status Reset(const TermFrequencyMap &character_map);
// Returns the index of the UTF-8 character spanning [|data|,|data|+|size|),
// or the |unknown_index| if the character is not present in this mapping.
int32 Lookup(const char *data, size_t size, int32 unknown_index) const;
// Accessors.
size_t size() const { return size_; }
private:
// Removes all entries from this mapping.
void Clear();
// Returns an integer that uniquely identifies the multi-byte UTF-8 character
// spanning [|bytes|,|bytes|+|size|). Note that the returned value is not a
// Unicode codepoint.
static uint32 MultiByteKey(const uint8 *bytes, size_t size);
// Number of entries in this mapping.
size_t size_ = 0;
// Dense mapping from single-byte UTF-8 (i.e., ASCII) character to index, or
// -1 if unmapped.
int32 single_byte_indices_[128];
// Sparse mapping from multi-byte UTF-8 character to index.
std::unordered_map<uint32, int32> multi_byte_indices_;
};
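// Usage sketch (illustrative only; the terms and |unknown_index| value below
// are arbitrary examples, not requirements of this API):
//
//   TermFrequencyMap character_map;
//   character_map.Increment("a");
//   character_map.Increment("好");
//
//   UnicodeDictionary dictionary;
//   TF_CHECK_OK(dictionary.Reset(character_map));
//
//   // Returns the index assigned to "a" if present, or -1 otherwise.
//   const int32 index = dictionary.Lookup("a", 1, -1);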
// Implementation details below.
inline int32 UnicodeDictionary::Lookup(const char *data, size_t size,
int32 unknown_index) const {
DCHECK_GE(size, 1);
DCHECK_EQ(size, UniLib::OneCharLen(data));
DCHECK(UniLib::IsUTF8ValidCodepoint(string(data, size)));
const auto *bytes = reinterpret_cast<const uint8 *>(data);
if (size == 1) {
// Look up single-byte characters in the dense mapping.
DCHECK_LT(*bytes, 128);
const int32 index = single_byte_indices_[*bytes];
return index >= 0 ? index : unknown_index;
} else {
// Look up multi-byte characters in the sparse mapping.
const auto it = multi_byte_indices_.find(MultiByteKey(bytes, size));
return it != multi_byte_indices_.end() ? it->second : unknown_index;
}
}
inline uint32 UnicodeDictionary::MultiByteKey(const uint8 *bytes, size_t size) {
DCHECK_GE(size, 2);
DCHECK_LE(size, 4);
uint32 value = static_cast<uint32>(bytes[0]) | //
static_cast<uint32>(bytes[1]) << 8;
switch (size) {
case 4:
value |= static_cast<uint32>(bytes[3]) << 24;
TF_FALLTHROUGH_INTENDED;
case 3:
value |= static_cast<uint32>(bytes[2]) << 16;
}
return value;
}
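// For example, the 2-byte character "¼" is encoded in UTF-8 as 0xC2 0xBC, so
// MultiByteKey() packs it as 0xC2 | (0xBC << 8) = 0xBCC2. Keys for sequences
// of different lengths cannot collide: UTF-8 continuation bytes are never
// zero, so a longer sequence always sets a high-order byte that a shorter
// sequence leaves as zero.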
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_UNICODE_DICTIONARY_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/unicode_dictionary.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/runtime/test/term_map_helpers.h"
#include "syntaxnet/base.h"
#include "syntaxnet/term_frequency_map.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "third_party/utf/utf.h"
#include "util/utf8/unilib.h"
#include "util/utf8/unilib_utf8_utils.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
constexpr char kInvalidUtf8[] = "\xff\xff\xff\xff";
constexpr char k1ByteCharacter[] = "a";
constexpr char k2ByteCharacter[] = "¼";
constexpr char k3ByteCharacter[] = "好";
constexpr char k4ByteCharacter[] = "𠜎";
// NB: Each string size is one more than the character's byte length because
// of the trailing NUL.
static_assert(sizeof(k1ByteCharacter) / sizeof(char) == 2,
"1-byte character has the wrong size");
static_assert(sizeof(k2ByteCharacter) / sizeof(char) == 3,
"2-byte character has the wrong size");
static_assert(sizeof(k3ByteCharacter) / sizeof(char) == 4,
"3-byte character has the wrong size");
static_assert(sizeof(k4ByteCharacter) / sizeof(char) == 5,
"4-byte character has the wrong size");
// Tests that the dictionary is empty by default.
TEST(UnicodeDictionaryTest, EmptyByDefault) {
UnicodeDictionary dictionary;
EXPECT_EQ(dictionary.size(), 0);
EXPECT_EQ(dictionary.Lookup(k1ByteCharacter, 1, -123), -123);
EXPECT_EQ(dictionary.Lookup(k2ByteCharacter, 2, -123), -123);
EXPECT_EQ(dictionary.Lookup(k3ByteCharacter, 3, -123), -123);
EXPECT_EQ(dictionary.Lookup(k4ByteCharacter, 4, -123), -123);
}
// Tests that the dictionary can be reset to a copy of a term map.
TEST(UnicodeDictionaryTest, Reset) {
TermFrequencyMap character_map;
ASSERT_EQ(character_map.Increment(k1ByteCharacter), 0);
ASSERT_EQ(character_map.Increment(k2ByteCharacter), 1);
ASSERT_EQ(character_map.Increment(k3ByteCharacter), 2);
ASSERT_EQ(character_map.Increment(k4ByteCharacter), 3);
UnicodeDictionary dictionary;
TF_ASSERT_OK(dictionary.Reset(character_map));
EXPECT_EQ(dictionary.size(), 4);
EXPECT_EQ(dictionary.Lookup(k1ByteCharacter, 1, -123), 0);
EXPECT_EQ(dictionary.Lookup(k2ByteCharacter, 2, -123), 1);
EXPECT_EQ(dictionary.Lookup(k3ByteCharacter, 3, -123), 2);
EXPECT_EQ(dictionary.Lookup(k4ByteCharacter, 4, -123), 3);
}
// Tests that the dictionary fails if a character is empty.
TEST(UnicodeDictionaryTest, EmptyCharacter) {
TermFrequencyMap character_map;
ASSERT_EQ(character_map.Increment(""), 0);
UnicodeDictionary dictionary;
EXPECT_THAT(dictionary.Reset(character_map),
test::IsErrorWithSubstr("Term 0 is empty"));
}
// Tests that the dictionary fails if a term contains more than one character.
TEST(UnicodeDictionaryTest, MultipleCharacters) {
TermFrequencyMap character_map;
ASSERT_EQ(character_map.Increment("1234"), 0);
UnicodeDictionary dictionary;
EXPECT_THAT(dictionary.Reset(character_map),
test::IsErrorWithSubstr("Term 0 should have size 1"));
}
// Tests that the dictionary fails if a character is invalid.
TEST(UnicodeDictionaryTest, InvalidUtf8) {
TermFrequencyMap character_map;
ASSERT_EQ(character_map.Increment(kInvalidUtf8), 0);
UnicodeDictionary dictionary;
EXPECT_THAT(dictionary.Reset(character_map),
test::IsErrorWithSubstr("Term 0 is not valid UTF-8"));
}
// Tests that the dictionary can be constructed from a file.
TEST(UnicodeDictionaryTest, ConstructFromFile) {
// Recall that terms are loaded in order of descending frequency.
const string character_map_path = WriteTermMap({{"too-infrequent", 1},
{k1ByteCharacter, 2},
{k2ByteCharacter, 3},
{k3ByteCharacter, 4},
{k4ByteCharacter, 5}});
const UnicodeDictionary dictionary(character_map_path, 2, 0);
EXPECT_EQ(dictionary.size(), 4);
EXPECT_EQ(dictionary.Lookup(k1ByteCharacter, 1, -123), 3);
EXPECT_EQ(dictionary.Lookup(k2ByteCharacter, 2, -123), 2);
EXPECT_EQ(dictionary.Lookup(k3ByteCharacter, 3, -123), 1);
EXPECT_EQ(dictionary.Lookup(k4ByteCharacter, 4, -123), 0);
}
// Tests that the dictionary constructor dies on error.
TEST(UnicodeDictionaryTest, ConstructorDiesOnError) {
const string bad_path = WriteTermMap({{"1234", 1}});
EXPECT_DEATH(UnicodeDictionary dictionary(bad_path, 0, 0),
"Term 0 should have size 1");
}
// Tests that the dictionary can map all valid codepoints.
TEST(UnicodeDictionaryTest, AllValidCodepoints) {
TermFrequencyMap character_map;
for (Rune rune = 0; rune < Runemax; ++rune) {
// Some codepoints are considered invalid, and UnicodeDictionary::Reset()
// will fail if it encounters them. Skip those here; that failure mode is
// already covered by the InvalidUtf8 test.
if (!UniLib::IsValidCodepoint(rune)) continue;
char data[UTFmax];
const int size = runetochar(data, &rune);
const string character(data, size);
const int index = character_map.Size();
ASSERT_EQ(character_map.Increment(character), index);
}
UnicodeDictionary dictionary;
TF_ASSERT_OK(dictionary.Reset(character_map));
for (int index = 0; index < character_map.Size(); ++index) {
const string &character = character_map.GetTerm(index);
EXPECT_EQ(dictionary.Lookup(character.data(), character.size(), -1), index);
}
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_VARIABLE_STORE_H_
#define DRAGNN_RUNTIME_VARIABLE_STORE_H_
#include <string>
#include <vector>
#include "dragnn/protos/runtime.pb.h"
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/math/types.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Interface for a store holding named, precomputed variables. Implementations
// must be thread-compatible.
class VariableStore {
public:
VariableStore(const VariableStore &that) = delete;
VariableStore &operator=(const VariableStore &that) = delete;
virtual ~VariableStore() = default;
// Looks for the variable with the |name|, formats its content according to
// the requested |format| (see details below), and points the |area| at the
// result. Before formatting, the variable's content matches that of the
// corresponding variable in the Python codebase. The |area| remains valid
// while this lives, even after Close(). On error, returns non-OK and
// modifies nothing.
//
// On success, the output |dimensions| is cleared and set to (num_elements)
// for vectors, (num_rows, num_columns) for regular matrices, and (num_rows,
// num_columns, block_size) for blocked matrices.
//
// FORMAT_FLAT:
// Flattens the variable as if by tf.reshape(var, [-1]), and sets the |area|
// to a single sub-view that points at the flat array.
//
// FORMAT_ROW_MAJOR_MATRIX:
// Reshapes the variable into a matrix as if by tf.reshape(var, [-1, D]),
// where D is the variable's innermost dimension. Points each sub-view of
// the |area| at the corresponding row of the formatted matrix. Requires
// that the variable has rank at least 2.
//
// FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX:
// The variable must have shape [num_sub_matrices, num_rows, block_size],
// and is imported as a column-blocked row-major matrix, as documented in
// BlockedMatrixFormat (in math/types.h). The matrix may also be padded.
virtual tensorflow::Status Lookup(const string &name,
VariableSpec::Format format,
std::vector<size_t> *dimensions,
AlignedArea *area) = 0;
// Looks up a FORMAT_FLAT variable as a Vector<T>.
template <class T>
tensorflow::Status Lookup(const string &name, Vector<T> *vector);
// Looks up a FORMAT_ROW_MAJOR_MATRIX as a Matrix<T>.
template <class T>
tensorflow::Status Lookup(const string &name, Matrix<T> *matrix);
// Looks up a FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX as a BlockedMatrix<T>.
template <class T>
tensorflow::Status Lookup(const string &name, BlockedMatrix<T> *matrix);
// Releases intermediate resources, if any. Does not invalidate the contents
// of variables returned by previous calls to Lookup*(), but future calls to
// Lookup*() are unsupported. On error, returns non-OK.
virtual tensorflow::Status Close() = 0;
protected:
VariableStore() = default;
};
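// Usage sketch (illustrative; the variable names below are hypothetical and
// |store| is a pointer to any concrete VariableStore implementation):
//
//   Vector<float> biases;
//   TF_CHECK_OK(store->Lookup("component/bias", &biases));
//
//   Matrix<float> weights;
//   TF_CHECK_OK(store->Lookup("component/weights", &weights));
//
//   BlockedMatrix<float> blocked_weights;
//   TF_CHECK_OK(store->Lookup("component/weights/blocked", &blocked_weights));
//
//   TF_CHECK_OK(store->Close());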
// Implementation details below.
template <class T>
tensorflow::Status VariableStore::Lookup(const string &name,
Vector<T> *vector) {
AlignedArea area;
std::vector<size_t> dimensions;
TF_RETURN_IF_ERROR(
Lookup(name, VariableSpec::FORMAT_FLAT, &dimensions, &area));
if (area.num_views() != 1) {
return tensorflow::errors::FailedPrecondition(
"Vector variable '", name, "' should have 1 sub-view but has ",
area.num_views());
}
if (area.view_size() % sizeof(T) != 0) {
return tensorflow::errors::FailedPrecondition(
"Vector variable '", name, "' does not divide into elements of size ",
sizeof(T));
}
*vector = Vector<T>(area.view(0));
if (dimensions.size() != 1) {
return tensorflow::errors::FailedPrecondition("Expected 1 dimensions, got ",
dimensions.size());
}
if (dimensions[0] != vector->size()) {
return tensorflow::errors::FailedPrecondition(
"Vector size (", vector->size(), ") disagrees with dimensions[0] (",
dimensions[0], ")");
}
return tensorflow::Status::OK();
}
template <class T>
tensorflow::Status VariableStore::Lookup(const string &name,
Matrix<T> *matrix) {
AlignedArea area;
std::vector<size_t> dimensions;
TF_RETURN_IF_ERROR(
Lookup(name, VariableSpec::FORMAT_ROW_MAJOR_MATRIX, &dimensions, &area));
if (dimensions.size() != 2) {
return tensorflow::errors::FailedPrecondition("Expected 2 dimensions, got ",
dimensions.size());
}
if (area.view_size() % sizeof(T) != 0) {
return tensorflow::errors::FailedPrecondition(
"Matrix variable '", name, "' does not divide into elements of size ",
sizeof(T));
}
*matrix = Matrix<T>(area);
if (dimensions[0] != matrix->num_rows()) {
return tensorflow::errors::FailedPrecondition(
"Matrix row count (", matrix->num_rows(),
") disagrees with dimensions[0] (", dimensions[0], ")");
}
if (dimensions[1] != matrix->num_columns()) {
return tensorflow::errors::FailedPrecondition(
"Matrix column count (", matrix->num_columns(),
") disagrees with dimensions[1] (", dimensions[1], ")");
}
return tensorflow::Status::OK();
}
template <class T>
tensorflow::Status VariableStore::Lookup(const string &name,
BlockedMatrix<T> *matrix) {
AlignedArea area;
std::vector<size_t> dimensions;
TF_RETURN_IF_ERROR(
Lookup(name, VariableSpec::FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX,
&dimensions, &area));
if (dimensions.size() != 3) {
return tensorflow::errors::FailedPrecondition("Expected 3 dimensions, got ",
dimensions.size());
}
const size_t num_rows = dimensions[0];
const size_t num_columns = dimensions[1];
const size_t block_size = dimensions[2];
if (area.view_size() != block_size * sizeof(T)) {
return tensorflow::errors::FailedPrecondition(
"Area view size (", area.view_size(),
") doesn't correspond to block size (", block_size,
") times data type size (", sizeof(T), ")");
}
if (num_rows * num_columns != area.num_views() * block_size) {
return tensorflow::errors::FailedPrecondition(
"Rows * columns (", num_rows * num_columns, ") != views * block size (",
area.num_views() * block_size, ")");
}
// Avoid modification on error.
BlockedMatrix<T> local_matrix;
TF_RETURN_IF_ERROR(local_matrix.Reset(area, num_rows, num_columns));
*matrix = local_matrix;
return tensorflow::Status::OK();
}
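// For example, a blocked variable of float with dimensions (num_rows=64,
// num_columns=32, block_size=8) must be backed by an area whose views are
// 8 * sizeof(float) = 32 bytes each and whose view count is 64 * 32 / 8 = 256,
// so that rows * columns equals views * block_size.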
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_VARIABLE_STORE_H_