Commit edea2b67 authored by Terry Koo's avatar Terry Koo
Browse files

Remove runtime because reasons.

parent a4bb31d0
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <algorithm>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/io/sentence_input_batch.h"
#include "dragnn/io/syntaxnet_sentence.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/sequence_predictor.h"
#include "dragnn/runtime/term_map_sequence_predictor.h"
#include "dragnn/runtime/transition_system_traits.h"
#include "syntaxnet/base.h"
#include "syntaxnet/sentence.pb.h"
#include "syntaxnet/term_frequency_map.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Predicts sequences of POS tags in SyntaxNetComponent batches.
class SyntaxNetTagSequencePredictor : public TermMapSequencePredictor {
public:
SyntaxNetTagSequencePredictor();
// Implements SequencePredictor.
bool Supports(const ComponentSpec &component_spec) const override;
tensorflow::Status Initialize(const ComponentSpec &component_spec) override;
tensorflow::Status Predict(Matrix<float> logits,
InputBatchCache *input) const override;
private:
// Whether to process sequences from left to right.
bool left_to_right_ = true;
};
SyntaxNetTagSequencePredictor::SyntaxNetTagSequencePredictor()
: TermMapSequencePredictor("tag-map") {}
bool SyntaxNetTagSequencePredictor::Supports(
const ComponentSpec &component_spec) const {
return TermMapSequencePredictor::SupportsTermMap(component_spec) &&
component_spec.backend().registered_name() == "SyntaxNetComponent" &&
component_spec.transition_system().registered_name() == "tagger";
}
tensorflow::Status SyntaxNetTagSequencePredictor::Initialize(
const ComponentSpec &component_spec) {
// Load all tags.
constexpr int kMinFrequency = 0;
constexpr int kMaxNumTerms = 0;
TF_RETURN_IF_ERROR(TermMapSequencePredictor::InitializeTermMap(
component_spec, kMinFrequency, kMaxNumTerms));
if (term_map().Size() == 0) {
return tensorflow::errors::InvalidArgument("Empty tag map");
}
const int map_num_tags = term_map().Size();
const int spec_num_tags = component_spec.num_actions();
if (map_num_tags != spec_num_tags) {
return tensorflow::errors::InvalidArgument(
"Tag count mismatch between term map (", map_num_tags,
") and ComponentSpec (", spec_num_tags, ")");
}
left_to_right_ = TransitionSystemTraits(component_spec).is_left_to_right;
return tensorflow::Status::OK();
}
tensorflow::Status SyntaxNetTagSequencePredictor::Predict(
Matrix<float> logits, InputBatchCache *input) const {
if (logits.num_columns() != term_map().Size()) {
return tensorflow::errors::InvalidArgument(
"Logits shape mismatch: expected ", term_map().Size(),
" columns but got ", logits.num_columns());
}
const std::vector<SyntaxNetSentence> &data =
*input->GetAs<SentenceInputBatch>()->data();
if (data.size() != 1) {
return tensorflow::errors::InvalidArgument("Non-singleton batch: got ",
data.size(), " elements");
}
Sentence *sentence = data[0].sentence();
const int num_tokens = sentence->token_size();
if (logits.num_rows() != num_tokens) {
return tensorflow::errors::InvalidArgument(
"Logits shape mismatch: expected ", num_tokens, " rows but got ",
logits.num_rows());
}
int token_index = left_to_right_ ? 0 : num_tokens - 1;
const int token_increment = left_to_right_ ? 1 : -1;
for (int i = 0; i < num_tokens; ++i, token_index += token_increment) {
const Vector<float> row = logits.row(i);
Token *token = sentence->mutable_token(token_index);
const float *const begin = row.begin();
const float *const end = row.end();
token->set_tag(term_map().GetTerm(std::max_element(begin, end) - begin));
}
return tensorflow::Status::OK();
}
DRAGNN_RUNTIME_REGISTER_SEQUENCE_PREDICTOR(SyntaxNetTagSequencePredictor);
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/sequence_predictor.h"
#include "dragnn/runtime/test/helpers.h"
#include "dragnn/runtime/test/term_map_helpers.h"
#include "syntaxnet/base.h"
#include "syntaxnet/sentence.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
constexpr char kResourceName[] = "tag-map";
// Writes a default tag map and returns a path to it.
string GetTagMapPath() {
static string *const kPath =
new string(WriteTermMap({{"NOUN", 3}, {"VERB", 2}, {"DET", 1}}));
return *kPath;
}
// Returns a ComponentSpec parsed from the |text| that contains a term map
// resource pointing at the |path|.
ComponentSpec MakeSpec(const string &text, const string &path) {
ComponentSpec component_spec;
CHECK(TextFormat::ParseFromString(text, &component_spec));
AddTermMapResource(kResourceName, path, &component_spec);
return component_spec;
}
// Returns a ComponentSpec that the predictor will support.
ComponentSpec MakeSupportedSpec() {
return MakeSpec(R"(transition_system { registered_name: 'tagger' }
backend { registered_name: 'SyntaxNetComponent' }
num_actions: 3)",
GetTagMapPath());
}
// Returns per-token tag logits.
UniqueMatrix<float> MakeLogits() {
return UniqueMatrix<float>({{0.0, 0.0, 1.0}, // predict 2 = DET
{1.0, 0.0, 0.0}, // predict 0 = NOUN
{0.0, 1.0, 0.0}, // predict 1 = VERB
{0.0, 0.0, 1.0}, // predict 2 = DET
{1.0, 0.0, 0.0}}); // predict 0 = NOUN
}
// Returns a default sentence.
Sentence MakeSentence() {
Sentence sentence;
for (const string &word : {"the", "cat", "chased", "a", "mouse"}) {
Token *token = sentence.add_token();
token->set_start(0); // never used; set because required field
token->set_end(0); // never used; set because required field
token->set_word(word);
}
return sentence;
}
// Tests that the predictor supports an appropriate spec.
TEST(SyntaxNetTagSequencePredictorTest, Supported) {
const ComponentSpec component_spec = MakeSupportedSpec();
string name;
TF_ASSERT_OK(SequencePredictor::Select(component_spec, &name));
EXPECT_EQ(name, "SyntaxNetTagSequencePredictor");
}
// Tests that the predictor requires the proper backend.
TEST(SyntaxNetTagSequencePredictorTest, WrongBackend) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.mutable_backend()->set_registered_name("bad");
string name;
EXPECT_THAT(
SequencePredictor::Select(component_spec, &name),
test::IsErrorWithSubstr("No SequencePredictor supports ComponentSpec"));
}
// Tests that the predictor requires the proper transition system.
TEST(SyntaxNetTagSequencePredictorTest, WrongTransitionSystem) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.mutable_transition_system()->set_registered_name("bad");
string name;
EXPECT_THAT(
SequencePredictor::Select(component_spec, &name),
test::IsErrorWithSubstr("No SequencePredictor supports ComponentSpec"));
}
// Tests that the predictor can be initialized and used to add POS tags to a
// sentence.
TEST(SyntaxNetTagSequencePredictorTest, InitializeAndPredict) {
const ComponentSpec component_spec = MakeSupportedSpec();
std::unique_ptr<SequencePredictor> predictor;
TF_ASSERT_OK(SequencePredictor::New("SyntaxNetTagSequencePredictor",
component_spec, &predictor));
UniqueMatrix<float> logits = MakeLogits();
const Sentence sentence = MakeSentence();
InputBatchCache input(sentence.SerializeAsString());
TF_ASSERT_OK(predictor->Predict(Matrix<float>(*logits), &input));
const std::vector<string> predictions = input.SerializedData();
ASSERT_EQ(predictions.size(), 1);
Sentence tagged;
ASSERT_TRUE(tagged.ParseFromString(predictions[0]));
ASSERT_EQ(tagged.token_size(), 5);
EXPECT_EQ(tagged.token(0).tag(), "DET"); // the
EXPECT_EQ(tagged.token(1).tag(), "NOUN"); // cat
EXPECT_EQ(tagged.token(2).tag(), "VERB"); // chased
EXPECT_EQ(tagged.token(3).tag(), "DET"); // a
EXPECT_EQ(tagged.token(4).tag(), "NOUN"); // mouse
}
// Tests that the predictor works on an empty sentence.
TEST(SyntaxNetTagSequencePredictorTest, EmptySentence) {
const ComponentSpec component_spec = MakeSupportedSpec();
std::unique_ptr<SequencePredictor> predictor;
TF_ASSERT_OK(SequencePredictor::New("SyntaxNetTagSequencePredictor",
component_spec, &predictor));
AlignedView view;
AlignedArea area;
TF_ASSERT_OK(area.Reset(view, 0, 3 * sizeof(float)));
Matrix<float> logits(area);
const Sentence sentence;
InputBatchCache input(sentence.SerializeAsString());
TF_ASSERT_OK(predictor->Predict(logits, &input));
const std::vector<string> predictions = input.SerializedData();
ASSERT_EQ(predictions.size(), 1);
Sentence tagged;
ASSERT_TRUE(tagged.ParseFromString(predictions[0]));
ASSERT_EQ(tagged.token_size(), 0);
}
// Tests that the predictor fails on an empty term map.
TEST(SyntaxNetTagSequencePredictorTest, EmptyTermMap) {
const string path = WriteTermMap({});
const ComponentSpec component_spec = MakeSpec("", path);
std::unique_ptr<SequencePredictor> predictor;
EXPECT_THAT(SequencePredictor::New("SyntaxNetTagSequencePredictor",
component_spec, &predictor),
test::IsErrorWithSubstr("Empty tag map"));
}
// Tests that Predict() fails if the batch is the wrong size.
TEST(SyntaxNetTagSequencePredictorTest, WrongBatchSize) {
const ComponentSpec component_spec = MakeSupportedSpec();
std::unique_ptr<SequencePredictor> predictor;
TF_ASSERT_OK(SequencePredictor::New("SyntaxNetTagSequencePredictor",
component_spec, &predictor));
UniqueMatrix<float> logits = MakeLogits();
const Sentence sentence = MakeSentence();
const std::vector<string> data = {sentence.SerializeAsString(),
sentence.SerializeAsString()};
InputBatchCache input(data);
EXPECT_THAT(predictor->Predict(Matrix<float>(*logits), &input),
test::IsErrorWithSubstr("Non-singleton batch: got 2 elements"));
}
// Tests that Initialize() fails if the term map doesn't match the specified
// number of actions.
TEST(SyntaxNetTagSequencePredictorTest, WrongNumActions) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.set_num_actions(1000);
std::unique_ptr<SequencePredictor> predictor;
EXPECT_THAT(
SequencePredictor::New("SyntaxNetTagSequencePredictor", component_spec,
&predictor),
test::IsErrorWithSubstr(
"Tag count mismatch between term map (3) and ComponentSpec (1000)"));
}
// Tests that Predict() fails if the logits don't match the term map.
TEST(SyntaxNetTagSequencePredictorTest, WrongLogitsColumns) {
const string path = WriteTermMap({{"a", 1}, {"b", 1}});
const ComponentSpec component_spec = MakeSpec("num_actions: 2", path);
std::unique_ptr<SequencePredictor> predictor;
TF_ASSERT_OK(SequencePredictor::New("SyntaxNetTagSequencePredictor",
component_spec, &predictor));
UniqueMatrix<float> logits = MakeLogits();
Sentence sentence = MakeSentence();
InputBatchCache input(sentence.SerializeAsString());
EXPECT_THAT(predictor->Predict(Matrix<float>(*logits), &input),
test::IsErrorWithSubstr(
"Logits shape mismatch: expected 2 columns but got 3"));
}
// Tests that Predict() fails if the logits don't match the number of tokens.
TEST(SyntaxNetTagSequencePredictorTest, WrongLogitsRows) {
const ComponentSpec component_spec = MakeSupportedSpec();
std::unique_ptr<SequencePredictor> predictor;
TF_ASSERT_OK(SequencePredictor::New("SyntaxNetTagSequencePredictor",
component_spec, &predictor));
UniqueMatrix<float> logits = MakeLogits();
Sentence sentence = MakeSentence();
sentence.mutable_token()->RemoveLast(); // bad
InputBatchCache input(sentence.SerializeAsString());
EXPECT_THAT(predictor->Predict(Matrix<float>(*logits), &input),
test::IsErrorWithSubstr(
"Logits shape mismatch: expected 4 rows but got 5"));
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/io/sentence_input_batch.h"
#include "dragnn/io/syntaxnet_sentence.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "dragnn/runtime/term_map_sequence_extractor.h"
#include "dragnn/runtime/term_map_utils.h"
#include "dragnn/runtime/transition_system_traits.h"
#include "syntaxnet/base.h"
#include "syntaxnet/term_frequency_map.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Sequence extractor that extracts words from a SyntaxNetComponent batch.
class SyntaxNetWordSequenceExtractor
: public TermMapSequenceExtractor<TermFrequencyMap> {
public:
SyntaxNetWordSequenceExtractor();
// Implements SequenceExtractor.
bool Supports(const FixedFeatureChannel &channel,
const ComponentSpec &component_spec) const override;
tensorflow::Status Initialize(const FixedFeatureChannel &channel,
const ComponentSpec &component_spec) override;
tensorflow::Status GetIds(InputBatchCache *input,
std::vector<int32> *ids) const override;
private:
// Parses |fml| and sets |min_frequency| and |max_num_terms| to the specified
// values. If the |fml| does not specify a supported feature, returns non-OK
// and modifies nothing.
static tensorflow::Status ParseFml(const string &fml, int *min_frequency,
int *max_num_terms);
// Feature ID for unknown words.
int32 unknown_id_ = -1;
};
SyntaxNetWordSequenceExtractor::SyntaxNetWordSequenceExtractor()
: TermMapSequenceExtractor("word-map") {}
tensorflow::Status SyntaxNetWordSequenceExtractor::ParseFml(
const string &fml, int *min_frequency, int *max_num_terms) {
return ParseTermMapFml(fml, {"input", "token", "word"}, min_frequency,
max_num_terms);
}
bool SyntaxNetWordSequenceExtractor::Supports(
const FixedFeatureChannel &channel,
const ComponentSpec &component_spec) const {
TransitionSystemTraits traits(component_spec);
int unused_min_frequency = 0;
int unused_max_num_terms = 0;
const tensorflow::Status parse_fml_status =
ParseFml(channel.fml(), &unused_min_frequency, &unused_max_num_terms);
return TermMapSequenceExtractor::SupportsTermMap(channel, component_spec) &&
parse_fml_status.ok() &&
component_spec.backend().registered_name() == "SyntaxNetComponent" &&
traits.is_sequential && traits.is_token_scale;
}
tensorflow::Status SyntaxNetWordSequenceExtractor::Initialize(
const FixedFeatureChannel &channel, const ComponentSpec &component_spec) {
int min_frequency = 0;
int max_num_terms = 0;
TF_RETURN_IF_ERROR(ParseFml(channel.fml(), &min_frequency, &max_num_terms));
TF_RETURN_IF_ERROR(TermMapSequenceExtractor::InitializeTermMap(
channel, component_spec, min_frequency, max_num_terms));
unknown_id_ = term_map().Size();
const int outside_id = unknown_id_ + 1;
const int map_vocab_size = outside_id + 1;
const int spec_vocab_size = channel.vocabulary_size();
if (map_vocab_size != spec_vocab_size) {
return tensorflow::errors::InvalidArgument(
"Word vocabulary size mismatch between term map (", map_vocab_size,
") and ComponentSpec (", spec_vocab_size, ")");
}
return tensorflow::Status::OK();
}
tensorflow::Status SyntaxNetWordSequenceExtractor::GetIds(
InputBatchCache *input, std::vector<int32> *ids) const {
ids->clear();
const std::vector<SyntaxNetSentence> &data =
*input->GetAs<SentenceInputBatch>()->data();
if (data.size() != 1) {
return tensorflow::errors::InvalidArgument("Non-singleton batch: got ",
data.size(), " elements");
}
const Sentence &sentence = *data[0].sentence();
for (const Token &token : sentence.token()) {
ids->push_back(term_map().LookupIndex(token.word(), unknown_id_));
}
return tensorflow::Status::OK();
}
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR(SyntaxNetWordSequenceExtractor);
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "dragnn/runtime/test/term_map_helpers.h"
#include "syntaxnet/base.h"
#include "syntaxnet/sentence.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
constexpr char kResourceName[] = "word-map";
// Returns a ComponentSpec parsed from the |text| that contains a term map
// resource pointing at the |path|.
ComponentSpec MakeSpec(const string &text, const string &path) {
ComponentSpec component_spec;
CHECK(TextFormat::ParseFromString(text, &component_spec));
AddTermMapResource(kResourceName, path, &component_spec);
return component_spec;
}
// Returns a ComponentSpec that the extractor will support.
ComponentSpec MakeSupportedSpec() {
return MakeSpec(
R"(transition_system { registered_name: 'shift-only' }
backend { registered_name: 'SyntaxNetComponent' }
fixed_feature {} # breaks hard-coded refs to channel 0
fixed_feature { size: 1 fml: 'input.token.word(min-freq=2)' })",
"/dev/null");
}
// Returns a default sentence.
Sentence MakeSentence() {
Sentence sentence;
for (const string &word : {"a", "bc", "def"}) {
Token *token = sentence.add_token();
token->set_start(0); // never used; set because required field
token->set_end(0); // never used; set because required field
token->set_word(word);
}
return sentence;
}
// Tests that the extractor supports an appropriate spec.
TEST(SyntaxNetWordSequenceExtractorTest, Supported) {
const ComponentSpec component_spec = MakeSupportedSpec();
const FixedFeatureChannel &channel = component_spec.fixed_feature(1);
string name;
TF_ASSERT_OK(SequenceExtractor::Select(channel, component_spec, &name));
EXPECT_EQ(name, "SyntaxNetWordSequenceExtractor");
}
// Tests that the extractor requires the proper backend.
TEST(SyntaxNetWordSequenceExtractorTest, WrongBackend) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.mutable_backend()->set_registered_name("bad");
const FixedFeatureChannel &channel = component_spec.fixed_feature(1);
string name;
EXPECT_THAT(
SequenceExtractor::Select(channel, component_spec, &name),
test::IsErrorWithSubstr("No SequenceExtractor supports channel"));
}
// Tests that the extractor requires the proper transition system.
TEST(SyntaxNetWordSequenceExtractorTest, WrongTransitionSystem) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.mutable_transition_system()->set_registered_name("bad");
const FixedFeatureChannel &channel = component_spec.fixed_feature(1);
string name;
EXPECT_THAT(
SequenceExtractor::Select(channel, component_spec, &name),
test::IsErrorWithSubstr("No SequenceExtractor supports channel"));
}
// Expects that the |fml| is rejected by the extractor.
void ExpectRejectedFml(const string &fml) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.mutable_fixed_feature(1)->set_fml(fml);
const FixedFeatureChannel &channel = component_spec.fixed_feature(1);
string name;
EXPECT_THAT(
SequenceExtractor::Select(channel, component_spec, &name),
test::IsErrorWithSubstr("No SequenceExtractor supports channel"));
}
// Tests that the extractor requires the proper FML.
TEST(SyntaxNetWordSequenceExtractorTest, WrongFml) {
ExpectRejectedFml("bad");
EXPECT_DEATH(ExpectRejectedFml("input.token.word("),
"Error in feature model");
EXPECT_DEATH(ExpectRejectedFml("input.token.word()"),
"Error in feature model");
ExpectRejectedFml("input.token.word(10)");
EXPECT_DEATH(ExpectRejectedFml("input.token.word(min-freq=)"),
"Error in feature model");
EXPECT_DEATH(ExpectRejectedFml("input.token.word(min-freq=10"),
"Error in feature model");
ExpectRejectedFml("input.token.word(min-freq=ten)");
ExpectRejectedFml("input.token.word(min_freq=10)"); // underscore
}
// Tests that the extractor can be initialized and used to extract feature IDs.
TEST(SyntaxNetWordSequenceExtractorTest, InitializeAndGetIds) {
// Terms are sorted by descending frequency, so this ensures a=0, bc=1, etc.
// Note that "e" is too infrequent, so vocabulary_size=5 from 3 terms plus 2
// special values.
const string path = WriteTermMap({{"a", 5}, {"bc", 3}, {"d", 2}, {"e", 1}});
const ComponentSpec component_spec = MakeSpec(
"fixed_feature {} "
"fixed_feature { vocabulary_size:5 fml:'input.token.word(min-freq=2)' }",
path);
const FixedFeatureChannel &channel = component_spec.fixed_feature(1);
std::unique_ptr<SequenceExtractor> extractor;
TF_ASSERT_OK(SequenceExtractor::New("SyntaxNetWordSequenceExtractor", channel,
component_spec, &extractor));
const Sentence sentence = MakeSentence();
InputBatchCache input(sentence.SerializeAsString());
std::vector<int32> ids;
TF_ASSERT_OK(extractor->GetIds(&input, &ids));
const std::vector<int32> expected_ids = {0, 1, 3};
EXPECT_EQ(ids, expected_ids);
}
// Tests that an empty term map works.
TEST(SyntaxNetWordSequenceExtractorTest, EmptyTermMap) {
const string path = WriteTermMap({});
const ComponentSpec component_spec = MakeSpec(
"fixed_feature {} "
"fixed_feature { fml:'input.token.word' vocabulary_size:2 }",
path);
const FixedFeatureChannel &channel = component_spec.fixed_feature(1);
std::unique_ptr<SequenceExtractor> extractor;
TF_ASSERT_OK(SequenceExtractor::New("SyntaxNetWordSequenceExtractor", channel,
component_spec, &extractor));
const Sentence sentence = MakeSentence();
InputBatchCache input(sentence.SerializeAsString());
std::vector<int32> ids = {1, 2, 3, 4}; // should be overwritten
TF_ASSERT_OK(extractor->GetIds(&input, &ids));
const std::vector<int32> expected_ids = {0, 0, 0};
EXPECT_EQ(ids, expected_ids);
}
// Tests that GetIds() fails if the batch is the wrong size.
TEST(SyntaxNetWordSequenceExtractorTest, WrongBatchSize) {
const string path = WriteTermMap({});
const ComponentSpec component_spec = MakeSpec(
"fixed_feature {} "
"fixed_feature { fml:'input.token.word' vocabulary_size:2 }",
path);
const FixedFeatureChannel &channel = component_spec.fixed_feature(1);
std::unique_ptr<SequenceExtractor> extractor;
TF_ASSERT_OK(SequenceExtractor::New("SyntaxNetWordSequenceExtractor", channel,
component_spec, &extractor));
const Sentence sentence = MakeSentence();
const std::vector<string> data = {sentence.SerializeAsString(),
sentence.SerializeAsString()};
InputBatchCache input(data);
std::vector<int32> ids;
EXPECT_THAT(extractor->GetIds(&input, &ids),
test::IsErrorWithSubstr("Non-singleton batch: got 2 elements"));
}
// Tests that initialization fails if the vocabulary size does not match.
TEST(SyntaxNetWordSequenceExtractorTest, WrongVocabularySize) {
const string path = WriteTermMap({});
const ComponentSpec component_spec = MakeSpec(
"fixed_feature {} "
"fixed_feature { fml:'input.token.word' vocabulary_size:1000 }",
path);
const FixedFeatureChannel &channel = component_spec.fixed_feature(1);
std::unique_ptr<SequenceExtractor> extractor;
EXPECT_THAT(
SequenceExtractor::New("SyntaxNetWordSequenceExtractor", channel,
component_spec, &extractor),
test::IsErrorWithSubstr("Word vocabulary size mismatch between term "
"map (2) and ComponentSpec (1000)"));
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_TERM_MAP_SEQUENCE_EXTRACTOR_H_
#define DRAGNN_RUNTIME_TERM_MAP_SEQUENCE_EXTRACTOR_H_
#include <string>
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "dragnn/runtime/term_map_utils.h"
#include "syntaxnet/base.h"
#include "syntaxnet/shared_store.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Base class for TermFrequencyMap-based sequence feature extractors. Requires
// the component to have a single fixed feature and a TermFrequencyMap resource.
// Templated on a |TermMap| type, which should have a 3-arg constructor similar
// to TermFrequencyMap's.
template <class TermMap>
class TermMapSequenceExtractor : public SequenceExtractor {
public:
// Creates a sequence extractor that will load a term map from the resource
// named |resource_name|.
explicit TermMapSequenceExtractor(const string &resource_name);
~TermMapSequenceExtractor() override;
// Returns true if the |channel| of the |component_spec| is compatible with
// this. Subclasses should call this from their Supports().
bool SupportsTermMap(const FixedFeatureChannel &channel,
const ComponentSpec &component_spec) const;
// Loads a term map from the |channel| of the |component_spec|, applying the
// |min_frequency| and |max_num_terms| when loading the term map. On error,
// returns non-OK. Subclasses should call this from their Initialize().
tensorflow::Status InitializeTermMap(const FixedFeatureChannel &channel,
const ComponentSpec &component_spec,
int min_frequency, int max_num_terms);
protected:
// Returns the current term map. Only valid after InitializeTermMap().
const TermMap &term_map() const { return *term_map_; }
private:
// Name of the resouce from which to load a term map.
const string resource_name_;
// Mapping from terms to feature IDs. Owned by SharedStore.
const TermMap *term_map_ = nullptr;
};
// Implementation details below.
template <class TermMap>
TermMapSequenceExtractor<TermMap>::TermMapSequenceExtractor(
const string &resource_name)
: resource_name_(resource_name) {}
template <class TermMap>
TermMapSequenceExtractor<TermMap>::~TermMapSequenceExtractor() {
if (!SharedStore::Release(term_map_)) {
LOG(ERROR) << "Failed to release term map for resource " << resource_name_;
}
}
template <class TermMap>
bool TermMapSequenceExtractor<TermMap>::SupportsTermMap(
const FixedFeatureChannel &channel,
const ComponentSpec &component_spec) const {
return LookupTermMapResourcePath(resource_name_, component_spec) != nullptr &&
channel.size() == 1;
}
template <class TermMap>
tensorflow::Status TermMapSequenceExtractor<TermMap>::InitializeTermMap(
const FixedFeatureChannel &channel, const ComponentSpec &component_spec,
int min_frequency, int max_num_terms) {
const string *path =
LookupTermMapResourcePath(resource_name_, component_spec);
if (path == nullptr) {
return tensorflow::errors::InvalidArgument(
"No compatible resource named '", resource_name_,
"' in ComponentSpec: ", component_spec.ShortDebugString());
}
term_map_ = SharedStoreUtils::GetWithDefaultName<TermMap>(
*path, min_frequency, max_num_terms);
return tensorflow::Status::OK();
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_TERM_MAP_SEQUENCE_EXTRACTOR_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/term_map_sequence_extractor.h"
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/test/term_map_helpers.h"
#include "syntaxnet/base.h"
#include "syntaxnet/term_frequency_map.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
constexpr char kResourceName[] = "term-map";
constexpr int kMinFrequency = 2;
constexpr int kMaxNumTerms = 0; // no limit
// A subclass for tests.
class BasicTermMapSequenceExtractor
: public TermMapSequenceExtractor<TermFrequencyMap> {
public:
BasicTermMapSequenceExtractor() : TermMapSequenceExtractor(kResourceName) {}
// Implements SequenceExtractor. These methods are never called, but must be
// defined so we can instantiate the class.
bool Supports(const FixedFeatureChannel &,
const ComponentSpec &) const override {
return true;
}
tensorflow::Status Initialize(const FixedFeatureChannel &,
const ComponentSpec &) override {
return tensorflow::Status::OK();
}
tensorflow::Status GetIds(InputBatchCache *,
std::vector<int32> *) const override {
return tensorflow::Status::OK();
}
// Publicizes the TermFrequencyMap accessor.
using TermMapSequenceExtractor::term_map;
};
// Returns a FixedFeatureChannel parsed from the |text|.
FixedFeatureChannel MakeChannel(const string &text) {
FixedFeatureChannel channel;
CHECK(TextFormat::ParseFromString(text, &channel));
return channel;
}
// Returns a ComponentSpec that contains a term map resource pointing at the
// |path|.
ComponentSpec MakeSpec(const string &path) {
ComponentSpec component_spec;
AddTermMapResource(kResourceName, path, &component_spec);
return component_spec;
}
// Tests that a term map can be successfully read.
TEST(TermMapSequenceExtractorTest, NormalOperation) {
const string path = WriteTermMap({{"too-infrequent", kMinFrequency - 1},
{"hello", kMinFrequency},
{"world", kMinFrequency + 1}});
const FixedFeatureChannel channel = MakeChannel("size:1");
const ComponentSpec spec = MakeSpec(path);
BasicTermMapSequenceExtractor extractor;
ASSERT_TRUE(extractor.SupportsTermMap(channel, spec));
TF_ASSERT_OK(
extractor.InitializeTermMap(channel, spec, kMinFrequency, kMaxNumTerms));
// NB: Terms are sorted by frequency.
EXPECT_EQ(extractor.term_map().Size(), 2);
EXPECT_EQ(extractor.term_map().LookupIndex("hello", -1), 1);
EXPECT_EQ(extractor.term_map().LookupIndex("world", -1), 0);
EXPECT_EQ(extractor.term_map().LookupIndex("unknown", -1), -1);
}
// Tests that SupportsTermMap() requires the fixed feature channel to have
// size 1.
TEST(TermMapSequenceExtractorTest, FixedFeatureSize) {
const BasicTermMapSequenceExtractor extractor;
ASSERT_TRUE(
extractor.SupportsTermMap(MakeChannel("size:1"), MakeSpec("/dev/null")));
EXPECT_FALSE(
extractor.SupportsTermMap(MakeChannel("size:0"), MakeSpec("/dev/null")));
EXPECT_FALSE(
extractor.SupportsTermMap(MakeChannel("size:2"), MakeSpec("/dev/null")));
}
// Tests that SupportsTermMap() requires a resource with the proper name.
TEST(TermMapSequenceExtractorTest, ResourceName) {
const BasicTermMapSequenceExtractor extractor;
const FixedFeatureChannel channel = MakeChannel("size:1");
ComponentSpec spec = MakeSpec("/dev/null");
ASSERT_TRUE(extractor.SupportsTermMap(channel, spec));
spec.mutable_resource(0)->set_name("whatever");
EXPECT_FALSE(extractor.SupportsTermMap(channel, spec));
}
// Tests that InitializeTermMap() fails if the term map cannot be found.
TEST(TermMapSequenceExtractorTest, InitializeWithNoTermMap) {
BasicTermMapSequenceExtractor extractor;
const FixedFeatureChannel channel;
const ComponentSpec spec;
EXPECT_THAT(
extractor.InitializeTermMap(channel, spec, kMinFrequency, kMaxNumTerms),
test::IsErrorWithSubstr("No compatible resource"));
}
// Tests that InitializeTermMap() requires a proper term map file.
TEST(TermMapSequenceExtractorTest, InvalidPath) {
BasicTermMapSequenceExtractor extractor;
const FixedFeatureChannel channel = MakeChannel("size:1");
const ComponentSpec spec = MakeSpec("/some/bad/path");
ASSERT_TRUE(extractor.SupportsTermMap(channel, spec));
EXPECT_DEATH(
extractor.InitializeTermMap(channel, spec, kMinFrequency, kMaxNumTerms)
.IgnoreError(),
"/some/bad/path");
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/term_map_sequence_predictor.h"
#include "dragnn/runtime/term_map_utils.h"
#include "syntaxnet/shared_store.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
TermMapSequencePredictor::TermMapSequencePredictor(const string &resource_name)
: resource_name_(resource_name) {}
TermMapSequencePredictor::~TermMapSequencePredictor() {
if (!SharedStore::Release(term_map_)) {
LOG(ERROR) << "Failed to release term map for resource " << resource_name_;
}
}
bool TermMapSequencePredictor::SupportsTermMap(
const ComponentSpec &component_spec) const {
return LookupTermMapResourcePath(resource_name_, component_spec) != nullptr;
}
tensorflow::Status TermMapSequencePredictor::InitializeTermMap(
const ComponentSpec &component_spec, int min_frequency, int max_num_terms) {
const string *path =
LookupTermMapResourcePath(resource_name_, component_spec);
if (path == nullptr) {
return tensorflow::errors::InvalidArgument(
"No compatible resource named '", resource_name_,
"' in ComponentSpec: ", component_spec.ShortDebugString());
}
term_map_ = SharedStoreUtils::GetWithDefaultName<TermFrequencyMap>(
*path, min_frequency, max_num_terms);
return tensorflow::Status::OK();
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_TERM_MAP_SEQUENCE_PREDICTOR_H_
#define DRAGNN_RUNTIME_TERM_MAP_SEQUENCE_PREDICTOR_H_
#include <string>
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/sequence_predictor.h"
#include "syntaxnet/base.h"
#include "syntaxnet/term_frequency_map.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Base class for predictors whose output label set is defined by a term map.
// Requires the component to have a TermFrequencyMap resource.
class TermMapSequencePredictor : public SequencePredictor {
public:
// Creates a sequence predictor that will load a term map from the resource
// named |resource_name|.
explicit TermMapSequencePredictor(const string &resource_name);
~TermMapSequencePredictor() override;
// Returns true if the |component_spec| is compatible with this. Subclasses
// should call this from their Supports().
bool SupportsTermMap(const ComponentSpec &component_spec) const;
// Loads a term map from the |component_spec|, applying the |min_frequency|
// and |max_num_terms| when loading the term map. On error, returns non-OK.
// Subclasses should call this from their Initialize().
tensorflow::Status InitializeTermMap(const ComponentSpec &component_spec,
int min_frequency, int max_num_terms);
protected:
// Returns the current term map. Only valid after InitializeTermMap().
const TermFrequencyMap &term_map() const { return *term_map_; }
private:
// Name of the resouce from which to load a term map.
const string resource_name_;
// Mapping from strings to feature IDs. Owned by SharedStore.
const TermFrequencyMap *term_map_ = nullptr;
};
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_TERM_MAP_SEQUENCE_PREDICTOR_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/term_map_sequence_predictor.h"
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/test/term_map_helpers.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
constexpr char kResourceName[] = "term-map";
constexpr int kMinFrequency = 2;
constexpr int kMaxNumTerms = 0; // no limit
// A subclass for tests.
class BasicTermMapSequencePredictor : public TermMapSequencePredictor {
public:
BasicTermMapSequencePredictor() : TermMapSequencePredictor(kResourceName) {}
// Implements SequencePredictor. These methods are never called, but must be
// defined so we can instantiate the class.
bool Supports(const ComponentSpec &) const override { return true; }
tensorflow::Status Initialize(const ComponentSpec &) override {
return tensorflow::Status::OK();
}
tensorflow::Status Predict(Matrix<float>, InputBatchCache *) const override {
return tensorflow::Status::OK();
}
// Publicizes the TermFrequencyMap accessor.
using TermMapSequencePredictor::term_map;
};
// Returns a ComponentSpec that contains a term map resource pointing at the
// |path|.
ComponentSpec MakeSpec(const string &path) {
ComponentSpec component_spec;
AddTermMapResource(kResourceName, path, &component_spec);
return component_spec;
}
// Tests that a term map can be successfully read.
TEST(TermMapSequencePredictorTest, NormalOperation) {
const string path = WriteTermMap({{"too-infrequent", kMinFrequency - 1},
{"hello", kMinFrequency},
{"world", kMinFrequency + 1}});
const ComponentSpec spec = MakeSpec(path);
BasicTermMapSequencePredictor predictor;
ASSERT_TRUE(predictor.SupportsTermMap(spec));
TF_ASSERT_OK(predictor.InitializeTermMap(spec, kMinFrequency, kMaxNumTerms));
// NB: Terms are sorted by frequency.
EXPECT_EQ(predictor.term_map().Size(), 2);
EXPECT_EQ(predictor.term_map().LookupIndex("hello", -1), 1);
EXPECT_EQ(predictor.term_map().LookupIndex("world", -1), 0);
EXPECT_EQ(predictor.term_map().LookupIndex("unknown", -1), -1);
}
// Tests that SupportsTermMap() requires a resource with the proper name.
TEST(TermMapSequencePredictorTest, ResourceName) {
const BasicTermMapSequencePredictor predictor;
ComponentSpec spec = MakeSpec("/dev/null");
ASSERT_TRUE(predictor.SupportsTermMap(spec));
spec.mutable_resource(0)->set_name("whatever");
EXPECT_FALSE(predictor.SupportsTermMap(spec));
}
// Tests that InitializeTermMap() fails if the term map cannot be found.
TEST(TermMapSequencePredictorTest, InitializeWithNoTermMap) {
BasicTermMapSequencePredictor predictor;
const ComponentSpec spec;
EXPECT_THAT(predictor.InitializeTermMap(spec, kMinFrequency, kMaxNumTerms),
test::IsErrorWithSubstr("No compatible resource"));
}
// Tests that InitializeTermMap() requires a proper term map file.
TEST(TermMapSequencePredictorTest, InvalidPath) {
BasicTermMapSequencePredictor predictor;
const ComponentSpec spec = MakeSpec("/some/bad/path");
ASSERT_TRUE(predictor.SupportsTermMap(spec));
EXPECT_DEATH(predictor.InitializeTermMap(spec, kMinFrequency, kMaxNumTerms)
.IgnoreError(),
"/some/bad/path");
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/term_map_utils.h"
#include "dragnn/runtime/fml_parsing.h"
#include "syntaxnet/feature_extractor.pb.h"
#include "tensorflow/core/lib/core/errors.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Attributes for extracting term map feature options
struct TermMapAttributes : public FeatureFunctionAttributes {
// Minimum frequency for included terms.
Optional<int32> min_frequency{"min-freq", 0, this};
// Maximum number of terms to include.
Optional<int32> max_num_terms{"max-num-terms", 0, this};
};
// Returns true if the |record_format| is compatible with a TermFrequencyMap.
bool CompatibleRecordFormat(const string &record_format) {
return record_format.empty() || record_format == "TermFrequencyMap";
}
} // namespace
const string *LookupTermMapResourcePath(const string &resource_name,
const ComponentSpec &component_spec) {
for (const Resource &resource : component_spec.resource()) {
if (resource.name() != resource_name) continue;
if (resource.part_size() != 1) continue;
const Part &part = resource.part(0);
if (part.file_format() != "text") continue;
if (!CompatibleRecordFormat(part.record_format())) continue;
return &part.file_pattern();
}
return nullptr;
}
tensorflow::Status ParseTermMapFml(const string &fml,
const std::vector<string> &types,
int *min_frequency, int *max_num_terms) {
FeatureFunctionDescriptor function;
TF_RETURN_IF_ERROR(ParseFeatureChainFml(fml, types, &function));
if (function.argument() != 0) {
return tensorflow::errors::InvalidArgument(
"TermFrequencyMap-based feature should have no argument: ", fml);
}
TermMapAttributes attributes;
TF_RETURN_IF_ERROR(attributes.Reset(function));
// Success; make modifications.
*min_frequency = attributes.min_frequency();
*max_num_terms = attributes.max_num_terms();
return tensorflow::Status::OK();
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_TERM_MAP_UTILS_H_
#define DRAGNN_RUNTIME_TERM_MAP_UTILS_H_
#include <string>
#include <vector>
#include "dragnn/protos/spec.pb.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Returns the path to the TermFrequencyMap resource named |resource_name| in
// the |component_spec|, or null if not found.
const string *LookupTermMapResourcePath(const string &resource_name,
const ComponentSpec &component_spec);
// Parses the |fml| as a chain of |types| ending in a TermFrequencyMap-based
// feature with "min-freq" and "max-num-terms" options. Sets |min_frequency|
// and |max_num_terms| to the option values. On error, returns non-OK and
// modifies nothing.
tensorflow::Status ParseTermMapFml(const string &fml,
const std::vector<string> &types,
int *min_frequency, int *max_num_terms);
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_TERM_MAP_UTILS_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/term_map_utils.h"
#include <string>
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/test/term_map_helpers.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
constexpr char kResourceName[] = "term-map";
constexpr char kResourcePath[] = "/path/to/term-map";
// Returns a ComponentSpec with a term map resource named |kResourceName| that
// points at |kResourcePath|.
ComponentSpec MakeSpec() {
ComponentSpec spec;
AddTermMapResource(kResourceName, kResourcePath, &spec);
return spec;
}
// Tests that a term map resource can be successfully read.
TEST(LookupTermMapResourcePathTest, Success) {
const ComponentSpec spec = MakeSpec();
const string *path = LookupTermMapResourcePath(kResourceName, spec);
ASSERT_NE(path, nullptr);
EXPECT_EQ(*path, kResourcePath);
}
// Tests that the returned path is null for an empty spec.
TEST(LookupTermMapResourcePathTest, EmptySpec) {
const ComponentSpec spec;
EXPECT_EQ(LookupTermMapResourcePath(kResourceName, spec), nullptr);
}
// Tests that the returned path is null for the wrong resource name.
TEST(LookupTermMapResourcePathTest, WrongName) {
ComponentSpec spec = MakeSpec();
spec.mutable_resource(0)->set_name("bad");
EXPECT_EQ(LookupTermMapResourcePath(kResourceName, spec), nullptr);
}
// Tests that the returned path is null for the wrong number of parts.
TEST(LookupTermMapResourcePathTest, WrongNumberOfParts) {
ComponentSpec spec = MakeSpec();
spec.mutable_resource(0)->clear_part();
EXPECT_EQ(LookupTermMapResourcePath(kResourceName, spec), nullptr);
spec.mutable_resource(0)->add_part();
spec.mutable_resource(0)->add_part();
EXPECT_EQ(LookupTermMapResourcePath(kResourceName, spec), nullptr);
}
// Tests that the returned path is null for the wrong file format.
TEST(LookupTermMapResourcePathTest, WrongFileFormat) {
ComponentSpec spec = MakeSpec();
spec.mutable_resource(0)->mutable_part(0)->set_file_format("bad");
EXPECT_EQ(LookupTermMapResourcePath(kResourceName, spec), nullptr);
}
// Tests that the returned path is null for the wrong record format.
TEST(LookupTermMapResourcePathTest, WrongRecordFormat) {
ComponentSpec spec = MakeSpec();
spec.mutable_resource(0)->mutable_part(0)->set_record_format("bad");
EXPECT_EQ(LookupTermMapResourcePath(kResourceName, spec), nullptr);
}
// Tests that alternate record formats are accepted.
TEST(LookupTermMapResourcePathTest, SuccessWithAlternateRecordFormat) {
ComponentSpec spec = MakeSpec();
spec.mutable_resource(0)->mutable_part(0)->set_record_format(
"TermFrequencyMap");
const string *path = LookupTermMapResourcePath(kResourceName, spec);
ASSERT_NE(path, nullptr);
EXPECT_EQ(*path, kResourcePath);
}
// Tests that ParseTermMapFml() correctly parses term map feature options.
TEST(ParseTermMapFmlTest, Success) {
int min_frequency = -1;
int max_num_terms = -1;
TF_ASSERT_OK(ParseTermMapFml("path.to.foo", {"path", "to", "foo"},
&min_frequency, &max_num_terms));
EXPECT_EQ(min_frequency, 0);
EXPECT_EQ(max_num_terms, 0);
TF_ASSERT_OK(ParseTermMapFml("path.to.foo(min-freq=5)", {"path", "to", "foo"},
&min_frequency, &max_num_terms));
EXPECT_EQ(min_frequency, 5);
EXPECT_EQ(max_num_terms, 0);
TF_ASSERT_OK(ParseTermMapFml("path.to.foo(max-num-terms=1000)",
{"path", "to", "foo"}, &min_frequency,
&max_num_terms));
EXPECT_EQ(min_frequency, 0);
EXPECT_EQ(max_num_terms, 1000);
TF_ASSERT_OK(ParseTermMapFml("path.to.foo(min-freq=12,max-num-terms=3456)",
{"path", "to", "foo"}, &min_frequency,
&max_num_terms));
EXPECT_EQ(min_frequency, 12);
EXPECT_EQ(max_num_terms, 3456);
}
// Tests that ParseTermMapFml() tolerates a zero argument.
TEST(ParseTermMapFmlTest, SuccessWithZeroArgument) {
int min_frequency = -1;
int max_num_terms = -1;
TF_ASSERT_OK(ParseTermMapFml("path.to.foo(0)", {"path", "to", "foo"},
&min_frequency, &max_num_terms));
EXPECT_EQ(min_frequency, 0);
EXPECT_EQ(max_num_terms, 0);
TF_ASSERT_OK(ParseTermMapFml("path.to.foo(0,min-freq=5)",
{"path", "to", "foo"}, &min_frequency,
&max_num_terms));
EXPECT_EQ(min_frequency, 5);
EXPECT_EQ(max_num_terms, 0);
TF_ASSERT_OK(ParseTermMapFml("path.to.foo(0,max-num-terms=1000)",
{"path", "to", "foo"}, &min_frequency,
&max_num_terms));
EXPECT_EQ(min_frequency, 0);
EXPECT_EQ(max_num_terms, 1000);
TF_ASSERT_OK(ParseTermMapFml("path.to.foo(0,min-freq=12,max-num-terms=3456)",
{"path", "to", "foo"}, &min_frequency,
&max_num_terms));
EXPECT_EQ(min_frequency, 12);
EXPECT_EQ(max_num_terms, 3456);
}
// Tests that ParseTermMapFml() fails on a non-zero argument.
TEST(ParseTermMapFmlTest, NonZeroArgument) {
int min_frequency = -1;
int max_num_terms = -1;
EXPECT_THAT(ParseTermMapFml("path.to.foo(1)", {"path", "to", "foo"},
&min_frequency, &max_num_terms),
test::IsErrorWithSubstr(
"TermFrequencyMap-based feature should have no argument"));
EXPECT_EQ(min_frequency, -1);
EXPECT_EQ(max_num_terms, -1);
}
// Tests that ParseTermMapFml() fails on an unknown feature option.
TEST(ParseTermMapFmlTest, UnknownOption) {
int min_frequency = -1;
int max_num_terms = -1;
EXPECT_THAT(ParseTermMapFml("path.to.foo(unknown=1)", {"path", "to", "foo"},
&min_frequency, &max_num_terms),
test::IsErrorWithSubstr("Unknown attribute"));
EXPECT_EQ(min_frequency, -1);
EXPECT_EQ(max_num_terms, -1);
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
package(
default_visibility = ["//visibility:public"],
)
cc_library(
name = "helpers",
testonly = 1,
srcs = ["helpers.cc"],
hdrs = ["helpers.h"],
deps = [
"//dragnn/runtime:alignment",
"//dragnn/runtime/math:avx_vector_array",
"//dragnn/runtime/math:sgemvv",
"//dragnn/runtime/math:transformations",
"//dragnn/runtime/math:types",
"//syntaxnet:test_main",
"@org_tensorflow//tensorflow/core:lib",
],
)
cc_test(
name = "helpers_test",
size = "small",
srcs = ["helpers_test.cc"],
deps = [
":helpers",
"//dragnn/runtime:alignment",
"//dragnn/runtime/math:types",
"//syntaxnet:test_main",
"@org_tensorflow//tensorflow/core:lib",
"@org_tensorflow//tensorflow/core:test",
],
)
cc_library(
name = "fake_variable_store",
testonly = 1,
srcs = ["fake_variable_store.cc"],
hdrs = ["fake_variable_store.h"],
deps = [
":helpers",
"//dragnn/protos:runtime_proto_cc",
"//dragnn/runtime:alignment",
"//dragnn/runtime:variable_store",
"//syntaxnet:base",
"@org_tensorflow//tensorflow/core:lib",
],
)
cc_test(
name = "fake_variable_store_test",
size = "small",
srcs = ["fake_variable_store_test.cc"],
deps = [
":fake_variable_store",
"//dragnn/core/test:generic",
"//dragnn/runtime:alignment",
"//dragnn/runtime/math:types",
"//syntaxnet:test_main",
"@org_tensorflow//tensorflow/core:lib",
"@org_tensorflow//tensorflow/core:test",
],
)
cc_library(
name = "network_test_base",
testonly = 1,
srcs = ["network_test_base.cc"],
hdrs = ["network_test_base.h"],
deps = [
":fake_variable_store",
"//dragnn/core/test:mock_compute_session",
"//dragnn/protos:data_proto_cc",
"//dragnn/runtime:extensions",
"//dragnn/runtime:flexible_matrix_kernel",
"//dragnn/runtime:network_states",
"//dragnn/runtime:session_state",
"//dragnn/runtime/math:types",
"//syntaxnet:base",
"@org_tensorflow//tensorflow/core:lib",
"@org_tensorflow//tensorflow/core:test",
],
)
cc_library(
name = "term_map_helpers",
testonly = 1,
srcs = ["term_map_helpers.cc"],
hdrs = ["term_map_helpers.h"],
deps = [
"//dragnn/protos:spec_proto_cc",
"//syntaxnet:base",
"@org_tensorflow//tensorflow/core:lib",
"@org_tensorflow//tensorflow/core:test",
],
)
cc_test(
name = "term_map_helpers_test",
size = "small",
srcs = ["term_map_helpers_test.cc"],
deps = [
":term_map_helpers",
"//dragnn/core/test:generic",
"//dragnn/protos:spec_proto_cc",
"//syntaxnet:base",
"//syntaxnet:term_frequency_map",
"@org_tensorflow//tensorflow/core:test",
],
)
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/test/fake_variable_store.h"
#include <string.h>
#include <utility>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
void FakeVariableStore::AddOrDie(const string &name,
const std::vector<std::vector<float>> &data,
VariableSpec::Format format) {
CHECK(variables_[name].empty()) << "Adding duplicate variable: " << name;
FormatMap formats;
// Add a flattened version.
std::vector<std::vector<float>> flat(1);
for (const auto &row : data) {
for (const float value : row) flat[0].push_back(value);
}
formats[VariableSpec::FORMAT_FLAT] = Variable(flat);
// Add the |data| in its natural row-major format.
formats[VariableSpec::FORMAT_ROW_MAJOR_MATRIX] = Variable(data);
// Add the |data| as a trivial blocked matrix with one block---i.e., block
// size equal to the number of columns. Conveniently, this matrix has the
// same underlying data layout as a plain matrix.
formats[VariableSpec::FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX] =
Variable(data);
// If |format| is FORMAT_UNKNOWN, keep all formats. Otherwise, only keep the
// specified format.
if (format == VariableSpec::FORMAT_UNKNOWN) {
variables_[name] = std::move(formats);
} else {
variables_[name][format] = std::move(formats[format]);
}
}
void FakeVariableStore::SetBlockedDimensionOverride(
const string &name, const std::vector<size_t> &dimensions) {
override_blocked_dimensions_[name] = dimensions;
}
tensorflow::Status FakeVariableStore::Lookup(const string &name,
VariableSpec::Format format,
std::vector<size_t> *dimensions,
AlignedArea *area) {
const auto it = variables_.find(name);
if (it == variables_.end()) {
return tensorflow::errors::InvalidArgument("Unknown variable: ", name);
}
FormatMap &formats = it->second;
if (formats.find(format) == formats.end()) {
return tensorflow::errors::InvalidArgument("Unknown variable: ", name);
}
Variable &variable = formats.at(format);
dimensions->clear();
switch (format) {
case VariableSpec::FORMAT_UNKNOWN:
// This case should not happen because the |formats| mapping never has
// FORMAT_UNKNOWN as a key.
LOG(FATAL) << "Tried to get a variable with FORMAT_UNKNOWN";
case VariableSpec::FORMAT_FLAT:
*dimensions = {variable->num_columns()};
break;
case VariableSpec::FORMAT_ROW_MAJOR_MATRIX:
*dimensions = {variable->num_rows(), variable->num_columns()};
break;
case VariableSpec::FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX:
if (override_blocked_dimensions_.find(name) !=
override_blocked_dimensions_.end()) {
*dimensions = override_blocked_dimensions_[name];
} else {
*dimensions = {variable->num_rows(), variable->num_columns(),
variable->num_columns()}; // = block_size
}
break;
}
*area = variable.area();
return tensorflow::Status::OK();
}
// Executes cleanup functions (see `cleanup_` comment).
SimpleFakeVariableStore::~SimpleFakeVariableStore() {
for (const auto &fcn : cleanup_) {
fcn();
}
}
tensorflow::Status SimpleFakeVariableStore::Lookup(
const string &name, VariableSpec::Format format,
std::vector<size_t> *dimensions, AlignedArea *area) {
// Test should call MockLookup() first.
CHECK(dimensions_to_return_ != nullptr);
CHECK(area_to_return_ != nullptr);
*dimensions = *dimensions_to_return_;
*area = *area_to_return_;
dimensions_to_return_ = nullptr;
area_to_return_ = nullptr;
return tensorflow::Status::OK();
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_TEST_FAKE_VARIABLE_STORE_H_
#define DRAGNN_RUNTIME_TEST_FAKE_VARIABLE_STORE_H_
#include <map>
#include <string>
#include <vector>
#include "dragnn/protos/runtime.pb.h"
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/test/helpers.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// A fake variable store with user-specified contents.
class FakeVariableStore : public VariableStore {
public:
// Creates an empty store.
FakeVariableStore() = default;
// Adds the |data| to this as a variable with the |name| and |format|. If the
// |format| is FORMAT_UNKNOWN, adds the data in all formats. On error, aborts
// the program.
void AddOrDie(const string &name, const std::vector<std::vector<float>> &data,
VariableSpec::Format format = VariableSpec::FORMAT_UNKNOWN);
// Overrides the default behavior of assuming that there is one block along
// the major axis of the matrix.
void SetBlockedDimensionOverride(const string &name,
const std::vector<size_t> &dimensions);
// Implements VariableStore.
using VariableStore::Lookup; // import Lookup<T>() convenience methods
tensorflow::Status Lookup(const string &name, VariableSpec::Format format,
std::vector<size_t> *dimensions,
AlignedArea *area) override;
tensorflow::Status Close() override { return tensorflow::Status::OK(); }
private:
using Variable = UniqueMatrix<float>;
using FormatMap = std::map<VariableSpec::Format, Variable>;
// Mappings from variable name to format to contents.
std::map<string, FormatMap> variables_;
// Overrides blocked dimensions.
std::map<string, std::vector<size_t>> override_blocked_dimensions_;
};
// Syntactic sugar for replicating data to SimpleFakeVariableStore::MockLookup.
template <typename T>
std::vector<std::vector<T>> ReplicateRows(std::vector<T> values, int times) {
return std::vector<std::vector<T>>(times, values);
}
// Simpler fake variable store, where the test just sets up the next value to be
// returned.
class SimpleFakeVariableStore : public VariableStore {
public:
// Executes cleanup functions (see `cleanup_` comment).
~SimpleFakeVariableStore() override;
// Sets values which store().Lookup() will return.
template <typename T>
void MockLookup(const std::vector<size_t> &dimensions,
const std::vector<std::vector<T>> &area_values) {
UniqueMatrix<T> *matrix = new UniqueMatrix<T>(area_values);
cleanup_.push_back([matrix]() { delete matrix; });
dimensions_to_return_.reset(new std::vector<size_t>(dimensions));
area_to_return_.reset(new AlignedArea(matrix->area()));
}
using VariableStore::Lookup; // import Lookup<T>() convenience methods
tensorflow::Status Lookup(const string &name, VariableSpec::Format format,
std::vector<size_t> *dimensions,
AlignedArea *area) override;
tensorflow::Status Close() override { return tensorflow::Status::OK(); }
private:
std::unique_ptr<std::vector<size_t>> dimensions_to_return_ = nullptr;
std::unique_ptr<AlignedArea> area_to_return_ = nullptr;
// Functions which will delete memory storing mocked arrays. We want to keep
// the memory accessible until the end of the test. We also can't keep an
// array of objects to delete, since they are of different types.
std::vector<std::function<void()>> cleanup_;
};
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_TEST_FAKE_VARIABLE_STORE_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/test/fake_variable_store.h"
#include <stddef.h>
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/math/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Returns a data matrix that has no alignment padding. This is required for
// BlockedMatrix, which does not tolerate alignment padding. The contents of
// the returned matrix are [0.0, 1.0, 2.0, ...] in the natural order.
std::vector<std::vector<float>> MakeBlockedData() {
const size_t kNumRows = 18;
const size_t kNumColumns = internal::kAlignmentBytes / sizeof(float);
std::vector<std::vector<float>> data(kNumRows);
float counter = 0.0;
for (std::vector<float> &row : data) {
row.resize(kNumColumns);
for (float &value : row) value = counter++;
}
return data;
}
// Tests that Lookup*() behaves properly w.r.t. AddOrDie().
TEST(FakeVariableStoreTest, Lookup) {
FakeVariableStore store;
AlignedView view;
Vector<float> vector;
Matrix<float> matrix;
BlockedMatrix<float> blocked_matrix;
// Fail to look up an unknown name.
EXPECT_THAT(store.Lookup("foo", &vector),
test::IsErrorWithSubstr("Unknown variable"));
EXPECT_TRUE(view.empty()); // not modified
// Add some data and try looking it up.
store.AddOrDie("foo", {{1.0, 2.0, 3.0}});
TF_EXPECT_OK(store.Lookup("foo", &vector));
ASSERT_EQ(vector.size(), 3);
EXPECT_EQ(vector[0], 1.0);
EXPECT_EQ(vector[1], 2.0);
EXPECT_EQ(vector[2], 3.0);
TF_EXPECT_OK(store.Lookup("foo", &matrix));
ASSERT_EQ(matrix.num_rows(), 1);
ASSERT_EQ(matrix.num_columns(), 3);
EXPECT_EQ(matrix.row(0)[0], 1.0);
EXPECT_EQ(matrix.row(0)[1], 2.0);
EXPECT_EQ(matrix.row(0)[2], 3.0);
// Try a funny name.
store.AddOrDie("", {{5.0, 7.0}, {11.0, 13.0}});
TF_EXPECT_OK(store.Lookup("", &vector));
ASSERT_EQ(vector.size(), 4);
EXPECT_EQ(vector[0], 5.0);
EXPECT_EQ(vector[1], 7.0);
EXPECT_EQ(vector[2], 11.0);
EXPECT_EQ(vector[3], 13.0);
TF_EXPECT_OK(store.Lookup("", &matrix));
ASSERT_EQ(matrix.num_rows(), 2);
ASSERT_EQ(matrix.num_columns(), 2);
EXPECT_EQ(matrix.row(0)[0], 5.0);
EXPECT_EQ(matrix.row(0)[1], 7.0);
EXPECT_EQ(matrix.row(1)[0], 11.0);
EXPECT_EQ(matrix.row(1)[1], 13.0);
// Try blocked matrices. These must not have alignment padding.
const auto blocked_data = MakeBlockedData();
store.AddOrDie("blocked", blocked_data);
TF_ASSERT_OK(store.Lookup("blocked", &blocked_matrix));
ASSERT_EQ(blocked_matrix.num_rows(), blocked_data.size());
ASSERT_EQ(blocked_matrix.num_columns(), blocked_data[0].size());
ASSERT_EQ(blocked_matrix.block_size(), blocked_data[0].size());
for (size_t vector = 0; vector < blocked_matrix.num_vectors(); ++vector) {
for (size_t i = 0; i < blocked_matrix.block_size(); ++i) {
EXPECT_EQ(blocked_matrix.vector(vector)[i],
vector * blocked_matrix.block_size() + i);
}
}
// Check that overriding dimensions is OK. Instead of a matrix that has every
// row as a block, every row is now has two blocks, so there are half as many
// rows and each row (number of columns) is twice as long.
const size_t kNumColumns = internal::kAlignmentBytes / sizeof(float);
store.SetBlockedDimensionOverride("blocked",
{9, 2 * kNumColumns, kNumColumns});
TF_ASSERT_OK(store.Lookup("blocked", &blocked_matrix));
ASSERT_EQ(blocked_matrix.num_rows(), blocked_data.size() / 2);
ASSERT_EQ(blocked_matrix.num_columns(), 2 * blocked_data[0].size());
ASSERT_EQ(blocked_matrix.block_size(), blocked_data[0].size());
}
// Tests that the fake variable never contains variables with unknown format.
TEST(FakeVariableStoreTest, NeverContainsUnknownFormat) {
FakeVariableStore store;
store.AddOrDie("foo", {{0.0}});
std::vector<size_t> dimensions;
AlignedArea area;
EXPECT_THAT(
store.Lookup("foo", VariableSpec::FORMAT_UNKNOWN, &dimensions, &area),
test::IsErrorWithSubstr("Unknown variable"));
}
// Tests that the fake variable store can create a variable that only appears in
// one format.
TEST(FakeVariableStoreTest, AddWithSpecificFormat) {
const auto data = MakeBlockedData();
FakeVariableStore store;
store.AddOrDie("flat", data, VariableSpec::FORMAT_FLAT);
store.AddOrDie("matrix", data, VariableSpec::FORMAT_ROW_MAJOR_MATRIX);
store.AddOrDie("blocked", data,
VariableSpec::FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX);
// Vector lookups should only work for "flat".
Vector<float> vector;
TF_ASSERT_OK(store.Lookup("flat", &vector));
EXPECT_THAT(store.Lookup("matrix", &vector),
test::IsErrorWithSubstr("Unknown variable"));
EXPECT_THAT(store.Lookup("blocked", &vector),
test::IsErrorWithSubstr("Unknown variable"));
// Matrix lookups should only work for "matrix".
Matrix<float> matrix;
EXPECT_THAT(store.Lookup("flat", &matrix),
test::IsErrorWithSubstr("Unknown variable"));
TF_ASSERT_OK(store.Lookup("matrix", &matrix));
EXPECT_THAT(store.Lookup("blocked", &matrix),
test::IsErrorWithSubstr("Unknown variable"));
// Blocked matrix lookups should only work for "blocked".
BlockedMatrix<float> blocked_matrix;
EXPECT_THAT(store.Lookup("flat", &blocked_matrix),
test::IsErrorWithSubstr("Unknown variable"));
EXPECT_THAT(store.Lookup("matrix", &blocked_matrix),
test::IsErrorWithSubstr("Unknown variable"));
TF_ASSERT_OK(store.Lookup("blocked", &blocked_matrix));
}
// Tests that Close() always succeeds.
TEST(FakeVariableStoreTest, Close) {
FakeVariableStore store;
TF_EXPECT_OK(store.Close());
store.AddOrDie("foo", {{1.0, 2.0, 3.0}});
TF_EXPECT_OK(store.Close());
store.AddOrDie("bar", {{1.0, 2.0}, {3.0, 4.0}});
TF_EXPECT_OK(store.Close());
}
// Tests that SimpleFakeVariableStore returns the user-specified mock values.
TEST(SimpleFakeVariableStoreTest, ReturnsMockedValues) {
SimpleFakeVariableStore store;
store.MockLookup<float>({1, 2}, {{1.0, 2.0}});
Matrix<float> matrix;
TF_ASSERT_OK(store.Lookup("name_doesnt_matter", &matrix));
ASSERT_EQ(matrix.num_rows(), 1);
ASSERT_EQ(matrix.num_columns(), 2);
EXPECT_EQ(matrix.row(0)[0], 1.0);
EXPECT_EQ(matrix.row(0)[1], 2.0);
TF_ASSERT_OK(store.Close());
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/test/helpers.h"
#include <time.h>
#include <random>
#include "dragnn/runtime/math/transformations.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
UniqueView::UniqueView(size_t size) {
array_.Reset(size);
view_ = array_.view();
}
UniqueArea::UniqueArea(size_t num_views, size_t view_size) {
array_.Reset(ComputeAlignedAreaSize(num_views, view_size));
TF_CHECK_OK(area_.Reset(array_.view(), num_views, view_size));
}
void InitRandomVector(MutableVector<float> vector) {
// clock() is updated less frequently than a cycle counter, so keep around the
// RNG just in case we initialize some vectors in less than a clock tick.
thread_local std::mt19937 *rng = new std::mt19937(clock());
std::normal_distribution<float> distribution(0.0, 1.0);
for (int i = 0; i < vector.size(); i++) {
vector[i] = distribution(*rng);
}
}
void InitRandomMatrix(MutableMatrix<float> matrix) {
// See InitRandomVector comment.
thread_local std::mt19937 *rng = new std::mt19937(clock());
std::normal_distribution<float> distribution(0.0, 1.0);
GenerateMatrix(
matrix.num_rows(), matrix.num_columns(),
[&distribution](int row, int col) { return distribution(*rng); },
&matrix);
}
void AvxVectorFuzzTest(
const std::function<void(AvxFloatVec *vec)> &run,
const std::function<void(float input_value, float output_value)> &check) {
for (int iter = 0; iter < 100; ++iter) {
UniqueVector<float> input(kAvxWidth);
UniqueVector<float> output(kAvxWidth);
InitRandomVector(*input);
InitRandomVector(*output);
AvxFloatVec vec;
vec.Load(input->data());
run(&vec);
vec.Store(output->data());
for (int i = 0; i < kAvxWidth; ++i) {
check((*input)[i], (*output)[i]);
}
}
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Helpers to make it less painful to create instances of aligned values.
// Intended for testing or benchmarking; production code should use managed
// memory allocation, for example Operands.
#ifndef DRAGNN_RUNTIME_TEST_HELPERS_H_
#define DRAGNN_RUNTIME_TEST_HELPERS_H_
#include <stddef.h>
#include <algorithm>
#include <functional>
#include <vector>
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/math/avx_vector_array.h"
#include "dragnn/runtime/math/types.h"
#include <gmock/gmock.h>
#include "tensorflow/core/platform/logging.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// An aligned view and its uniquely-owned underlying storage. Can be used like
// a std::unique_ptr<MutableAlignedView>.
class UniqueView {
public:
// Creates a view of |size| uninitialized bytes.
explicit UniqueView(size_t size);
// Provides std::unique_ptr-like access.
MutableAlignedView *get() { return &view_; }
MutableAlignedView &operator*() { return view_; }
MutableAlignedView *operator->() { return &view_; }
private:
// View and its underlying storage.
UniqueAlignedArray array_;
MutableAlignedView view_;
};
// An aligned area and its uniquely-owned underlying storage. Can be used like
// a std::unique_ptr<MutableAlignedArea>.
class UniqueArea {
public:
// Creates an area with |num_views| sub-views, each of which has |view_size|
// uninitialized bytes. Check-fails on error.
UniqueArea(size_t num_views, size_t view_size);
// Provides std::unique_ptr-like access.
MutableAlignedArea *get() { return &area_; }
MutableAlignedArea &operator*() { return area_; }
MutableAlignedArea *operator->() { return &area_; }
private:
// Area and its underlying storage.
UniqueAlignedArray array_;
MutableAlignedArea area_;
};
// A vector and its uniquely-owned underlying storage. Can be used like a
// std::unique_ptr<MutableVector<T>>.
template <class T>
class UniqueVector {
public:
// Creates an empty vector.
UniqueVector() : UniqueVector(0) {}
// Creates a vector with |dimension| uninitialized Ts.
explicit UniqueVector(size_t dimension)
: view_(dimension * sizeof(T)), vector_(*view_) {}
// Creates a vector initialized to hold the |values|.
explicit UniqueVector(const std::vector<T> &values);
// Provides std::unique_ptr-like access.
MutableVector<T> *get() { return &vector_; }
MutableVector<T> &operator*() { return vector_; }
MutableVector<T> *operator->() { return &vector_; }
// Returns a view pointing to the same memory.
MutableAlignedView view() { return *view_; }
private:
// Vector and its underlying view.
UniqueView view_;
MutableVector<T> vector_;
};
// A matrix and its uniquely-owned underlying storage. Can be used like a
// std::unique_ptr<MutableMatrix<T>>>.
template <class T>
class UniqueMatrix {
public:
// Creates an empty matrix.
UniqueMatrix() : UniqueMatrix(0, 0) {}
// Creates a matrix with |num_rows| x |num_columns| uninitialized Ts.
UniqueMatrix(size_t num_rows, size_t num_columns)
: area_(num_rows, num_columns * sizeof(T)), matrix_(*area_) {}
// Creates a matrix initialized to hold the |values|.
explicit UniqueMatrix(const std::vector<std::vector<T>> &values);
// Provides std::unique_ptr-like access.
MutableMatrix<T> *get() { return &matrix_; }
MutableMatrix<T> &operator*() { return matrix_; }
MutableMatrix<T> *operator->() { return &matrix_; }
// Returns an area pointing to the same memory.
MutableAlignedArea area() { return *area_; }
private:
// Matrix and its underlying area.
UniqueArea area_;
MutableMatrix<T> matrix_;
};
// Implementation details below.
template <class T>
UniqueVector<T>::UniqueVector(const std::vector<T> &values)
: UniqueVector(values.size()) {
std::copy(values.begin(), values.end(), vector_.begin());
}
template <class T>
UniqueMatrix<T>::UniqueMatrix(const std::vector<std::vector<T>> &values)
: UniqueMatrix(values.size(), values.empty() ? 0 : values[0].size()) {
for (size_t i = 0; i < values.size(); ++i) {
CHECK_EQ(values[0].size(), values[i].size());
std::copy(values[i].begin(), values[i].end(), matrix_.row(i).begin());
}
}
// Expects that the |matrix| contains the |data|.
template <class T>
void ExpectMatrix(Matrix<T> matrix, const std::vector<std::vector<T>> &data) {
ASSERT_EQ(matrix.num_rows(), data.size());
if (data.empty()) return;
ASSERT_EQ(matrix.num_columns(), data[0].size());
for (size_t row = 0; row < data.size(); ++row) {
for (size_t column = 0; column < data[row].size(); ++column) {
EXPECT_EQ(matrix.row(row)[column], data[row][column]);
}
}
}
// Initializes a floating-point vector with random values, using a normal
// distribution centered at 0 with standard deviation 1.
void InitRandomVector(MutableVector<float> vector);
void InitRandomMatrix(MutableMatrix<float> matrix);
// Fuzz test using AVX vectors.
// If this file gets too big, move into something like math/test_helpers.h.
void AvxVectorFuzzTest(
const std::function<void(AvxFloatVec *vec)> &run,
const std::function<void(float input_value, float output_value)> &check);
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_TEST_HELPERS_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/test/helpers.h"
#include <string>
#include "dragnn/runtime/alignment.h"
#include "dragnn/runtime/math/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Fills the |slice| with the |value|. Slice must have .data() and .size().
template <class Slice, class T>
void Fill(Slice slice, T value) {
for (size_t i = 0; i < slice.size(); ++i) slice.data()[i] = value;
}
// Returns the sum of all elements in the |slice|, casted to double. Slice must
// have .data() and .size().
template <class Slice>
double Sum(Slice slice) {
double sum = 0.0;
for (size_t i = 0; i < slice.size(); ++i) {
sum += static_cast<double>(slice.data()[i]);
}
return sum;
}
// Expects that the two pointers have the same address.
void ExpectSameAddress(const void *pointer1, const void *pointer2) {
EXPECT_EQ(pointer1, pointer2);
}
// Tests that each byte of a UniqueView is usable.
TEST(UniqueViewTest, Usable) {
UniqueView view(100);
EXPECT_EQ(view->size(), 100);
Fill(*view, 'x');
LOG(INFO) << "Prevents elision by optimizer: " << Sum(*view);
EXPECT_EQ(view->data()[0], 'x');
}
// Tests that each byte of a UniqueArea is usable.
TEST(UniqueAreaTest, Usable) {
UniqueArea area(10, 100);
EXPECT_EQ(area->num_views(), 10);
EXPECT_EQ(area->view_size(), 100);
for (size_t i = 0; i < 10; ++i) {
Fill(area->view(i), 'y');
LOG(INFO) << "Prevents elision by optimizer: " << Sum(area->view(i));
EXPECT_EQ(area->view(i).data()[0], 'y');
}
}
// Tests that UniqueVector is empty by default.
TEST(UniqueVectorTest, EmptyByDefault) {
UniqueVector<float> vector;
EXPECT_EQ(vector->size(), 0);
}
// Tests that each element of a UniqueVector is usable.
TEST(UniqueVectorTest, Usable) {
UniqueVector<float> vector(100);
EXPECT_EQ(vector->size(), 100);
Fill(*vector, 1.5);
LOG(INFO) << "Prevents elision by optimizer: " << Sum(*vector);
EXPECT_EQ((*vector)[0], 1.5);
}
// Tests that UniqueVector also exports a view.
TEST(UniqueVectorTest, View) {
UniqueVector<float> vector(123);
ExpectSameAddress(vector.view().data(), vector->data());
EXPECT_EQ(vector.view().size(), 123 * sizeof(float));
}
// Tests that a UniqueVector can be constructed with an initial value.
TEST(UniqueVectorTest, Initialization) {
UniqueVector<int> vector({2, 3, 5, 7});
EXPECT_EQ(vector->size(), 4);
EXPECT_EQ((*vector)[0], 2);
EXPECT_EQ((*vector)[1], 3);
EXPECT_EQ((*vector)[2], 5);
EXPECT_EQ((*vector)[3], 7);
}
// Tests that UniqueMatrix is empty by default.
TEST(UniqueMatrixTest, EmptyByDefault) {
UniqueMatrix<float> row_major_matrix;
EXPECT_EQ(row_major_matrix->num_rows(), 0);
EXPECT_EQ(row_major_matrix->num_columns(), 0);
}
// Tests that each element of a UniqueMatrix is usable.
TEST(UniqueMatrixTest, Usable) {
UniqueMatrix<float> row_major_matrix(10, 100);
EXPECT_EQ(row_major_matrix->num_rows(), 10);
EXPECT_EQ(row_major_matrix->num_columns(), 100);
for (size_t i = 0; i < 10; ++i) {
Fill(row_major_matrix->row(i), 1.75);
LOG(INFO) << "Prevents elision by optimizer: "
<< Sum(row_major_matrix->row(i));
EXPECT_EQ(row_major_matrix->row(i)[0], 1.75);
}
}
// Tests that UniqueMatrix also exports an area.
TEST(UniqueMatrixTest, Area) {
UniqueMatrix<float> row_major_matrix(12, 34);
ExpectSameAddress(row_major_matrix.area().view(0).data(),
row_major_matrix->row(0).data());
EXPECT_EQ(row_major_matrix.area().num_views(), 12);
EXPECT_EQ(row_major_matrix.area().view_size(), 34 * sizeof(float));
}
// Tests that a UniqueMatrix can be constructed with an initial value.
TEST(UniqueMatrixTest, Initialization) {
UniqueMatrix<int> row_major_matrix({{2, 3, 5}, {7, 11, 13}});
EXPECT_EQ(row_major_matrix->num_rows(), 2);
EXPECT_EQ(row_major_matrix->num_columns(), 3);
EXPECT_EQ(row_major_matrix->row(0)[0], 2);
EXPECT_EQ(row_major_matrix->row(0)[1], 3);
EXPECT_EQ(row_major_matrix->row(0)[2], 5);
EXPECT_EQ(row_major_matrix->row(1)[0], 7);
EXPECT_EQ(row_major_matrix->row(1)[1], 11);
EXPECT_EQ(row_major_matrix->row(1)[2], 13);
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/test/network_test_base.h"
#include "dragnn/protos/data.pb.h"
#include "dragnn/runtime/flexible_matrix_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
using ::testing::InSequence;
using ::testing::Return;
// Fills the |matrix| with the |fill_value|.
void Fill(float fill_value, MutableMatrix<float> matrix) {
for (size_t i = 0; i < matrix.num_rows(); ++i) {
for (float &value : matrix.row(i)) value = fill_value;
}
}
} // namespace
constexpr char NetworkTestBase::kTestComponentName[];
void NetworkTestBase::TearDown() {
// The state extensions may contain objects that cannot outlive the component,
// so discard the extensions early. This is not an issue in real-world usage,
// as the Master calls destructors in the right order.
session_state_.extensions = Extensions();
}
NetworkTestBase::GetInputFeaturesFunctor NetworkTestBase::ExtractFeatures(
int expected_channel_id, const std::vector<Feature> &features) {
return [=](const string &component_name,
std::function<int32 *(int)> allocate_indices,
std::function<int64 *(int)> allocate_ids,
std::function<float *(int)> allocate_weights, int channel_id) {
EXPECT_EQ(component_name, kTestComponentName);
EXPECT_EQ(channel_id, expected_channel_id);
const int num_features = features.size();
int32 *indices = allocate_indices(num_features);
int64 *ids = allocate_ids(num_features);
float *weights = allocate_weights(num_features);
for (int i = 0; i < num_features; ++i) {
indices[i] = features[i].index;
ids[i] = features[i].id;
weights[i] = features[i].weight;
}
return num_features;
};
}
NetworkTestBase::GetTranslatedLinkFeaturesFunctor NetworkTestBase::ExtractLinks(
int expected_channel_id, const std::vector<string> &features_text) {
std::vector<LinkFeatures> features;
for (const string &text : features_text) {
features.emplace_back();
CHECK(TextFormat::ParseFromString(text, &features.back()));
}
return [=](const string &component_name, int channel_id) {
EXPECT_EQ(component_name, kTestComponentName);
EXPECT_EQ(channel_id, expected_channel_id);
return features;
};
}
void NetworkTestBase::AddVectorVariable(const string &name, size_t dimension,
float fill_value) {
const std::vector<float> row(dimension, fill_value);
const std::vector<std::vector<float>> values(1, row);
variable_store_.AddOrDie(name, values);
}
void NetworkTestBase::AddMatrixVariable(const string &name, size_t num_rows,
size_t num_columns, float fill_value) {
const std::vector<float> row(num_columns, fill_value);
const std::vector<std::vector<float>> values(num_rows, row);
variable_store_.AddOrDie(name, values);
}
void NetworkTestBase::AddFixedEmbeddingMatrix(int channel_id,
size_t vocabulary_size,
size_t embedding_dim,
float fill_value) {
const string name = tensorflow::strings::StrCat(
kTestComponentName, "/fixed_embedding_matrix_", channel_id, "/trimmed");
AddMatrixVariable(name, vocabulary_size, embedding_dim, fill_value);
}
void NetworkTestBase::AddLinkedWeightMatrix(int channel_id, size_t source_dim,
size_t embedding_dim,
float fill_value) {
const string name = tensorflow::strings::StrCat(
kTestComponentName, "/linked_embedding_matrix_", channel_id, "/weights",
FlexibleMatrixKernel::kSuffix);
AddMatrixVariable(name, embedding_dim, source_dim, fill_value);
}
void NetworkTestBase::AddLinkedOutOfBoundsVector(int channel_id,
size_t embedding_dim,
float fill_value) {
const string name = tensorflow::strings::StrCat(kTestComponentName,
"/linked_embedding_matrix_",
channel_id, "/out_of_bounds");
AddVectorVariable(name, embedding_dim, fill_value);
}
void NetworkTestBase::AddComponent(const string &component_name) {
TF_ASSERT_OK(network_state_manager_.AddComponent(component_name));
}
void NetworkTestBase::AddLayer(const string &layer_name, size_t dimension) {
LayerHandle<float> unused_layer_handle;
TF_ASSERT_OK(network_state_manager_.AddLayer(layer_name, dimension,
&unused_layer_handle));
}
void NetworkTestBase::AddPairwiseLayer(const string &layer_name,
size_t dimension) {
PairwiseLayerHandle<float> unused_layer_handle;
TF_ASSERT_OK(network_state_manager_.AddLayer(layer_name, dimension,
&unused_layer_handle));
}
void NetworkTestBase::StartComponent(size_t num_steps) {
// The pre-allocation hint is arbitrary, but setting it to a small value
// exercises reallocations.
TF_ASSERT_OK(network_states_.StartNextComponent(5));
for (size_t i = 0; i < num_steps; ++i) network_states_.AddStep();
}
MutableMatrix<float> NetworkTestBase::GetLayer(const string &component_name,
const string &layer_name) const {
size_t unused_dimension = 0;
LayerHandle<float> handle;
TF_CHECK_OK(network_state_manager_.LookupLayer(component_name, layer_name,
&unused_dimension, &handle));
return network_states_.GetLayer(handle);
}
MutableMatrix<float> NetworkTestBase::GetPairwiseLayer(
const string &component_name, const string &layer_name) const {
size_t unused_dimension = 0;
PairwiseLayerHandle<float> handle;
TF_CHECK_OK(network_state_manager_.LookupLayer(component_name, layer_name,
&unused_dimension, &handle));
return network_states_.GetLayer(handle);
}
void NetworkTestBase::FillLayer(const string &component_name,
const string &layer_name,
float fill_value) const {
Fill(fill_value, GetLayer(component_name, layer_name));
}
void NetworkTestBase::SetupTransitionLoop(size_t num_steps) {
// Return not terminal |num_steps| times, then return terminal.
InSequence scoped;
EXPECT_CALL(compute_session_, IsTerminal(kTestComponentName))
.Times(num_steps)
.WillRepeatedly(Return(false))
.RetiresOnSaturation();
EXPECT_CALL(compute_session_, IsTerminal(kTestComponentName))
.WillOnce(Return(true));
}
void NetworkTestBase::ExpectVector(Vector<float> vector, size_t dimension,
float expected_value) {
ASSERT_EQ(vector.size(), dimension);
for (const float value : vector) EXPECT_EQ(value, expected_value);
}
void NetworkTestBase::ExpectMatrix(Matrix<float> matrix, size_t num_rows,
size_t num_columns, float expected_value) {
ASSERT_EQ(matrix.num_rows(), num_rows);
ASSERT_EQ(matrix.num_columns(), num_columns);
for (size_t row = 0; row < num_rows; ++row) {
ExpectVector(matrix.row(row), num_columns, expected_value);
}
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment