Commit a4bb31d0 authored by Terry Koo's avatar Terry Koo
Browse files

Export @195097388.

parent dea7ecf6
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_SESSION_STATE_H_
#define DRAGNN_RUNTIME_SESSION_STATE_H_
#include "dragnn/runtime/extensions.h"
#include "dragnn/runtime/network_states.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// State associated with a ComputeSession being evaluated by a DRAGNN network,
// reusable across multiple evaluations. Unlike the ComputeSession, which is
// both the input and output of the network, this state is strictly internal to
// the network. Production code should allocate these via a SessionStatePool.
struct SessionState {
// The network states that connect the pipeline of components.
NetworkStates network_states;
// Generic set of typed extensions.
Extensions extensions;
};
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_SESSION_STATE_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/session_state_pool.h"
#include <algorithm>
namespace syntaxnet {
namespace dragnn {
namespace runtime {
SessionStatePool::SessionStatePool(size_t max_free_states)
: max_free_states_(max_free_states) {}
std::unique_ptr<SessionState> SessionStatePool::Acquire() {
{ // Exclude the slow path from the critical region.
tensorflow::mutex_lock lock(mutex_);
if (!free_list_.empty()) {
// Fast path: reuse a free state.
std::unique_ptr<SessionState> state = std::move(free_list_.back());
free_list_.pop_back();
return state;
}
}
// Slow path: allocate a new state.
return std::unique_ptr<SessionState>(new SessionState());
}
void SessionStatePool::Release(std::unique_ptr<SessionState> state) {
{ // Exclude the slow path from the critical region.
tensorflow::mutex_lock lock(mutex_);
if (free_list_.size() < max_free_states_) {
// Fast path: reclaim in the free list.
free_list_.emplace_back(std::move(state));
return;
}
}
// Slow path: discard the excess |state| when it goes out of scope.
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_RUNTIME_SESSION_STATE_POOL_H_
#define DRAGNN_RUNTIME_SESSION_STATE_POOL_H_
#include <stddef.h>
#include <memory>
#include <utility>
#include "dragnn/runtime/session_state.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// A thread-safe pool of session states that maintains a free list. The free
// list is bounded, so a spike in usage does not permanently increase the size
// of the pool. Use ScopedSessionState to interact with the pool.
class SessionStatePool {
public:
// Creates a pool whose free list holds at most |max_free_states| states.
//
// If usage spikes are not a concern (e.g., during offline processing where
// the runtime is called from a fixed-size pool of threads), then specify a
// large value like SIZE_MAX. That eliminates unnecessary deallocations and
// reallocations, and eliminates the need to coordinate the thread pool size
// with this pool's size.
//
// If memory usage dominates CPU usage, then specify 0 to eliminate overhead
// from the free list.
//
// TODO(googleuser): An alternative is to set a target allocation
// rate (e.g., 2% of Acquire()s should create a new state), and let the pool
// adapt its free list size to achieve that rate.
explicit SessionStatePool(size_t max_free_states);
private:
friend class ScopedSessionState;
// Returns a state acquired from this pool. The caller is the exclusive user
// of the returned state until it is passed to Release().
std::unique_ptr<SessionState> Acquire();
// Releases the |state| back to this pool. The |state| must be the result of
// a previous Acquire(). The caller can no longer use the |state|.
void Release(std::unique_ptr<SessionState> state);
// Maximum number of states to keep in the |free_list_|.
const size_t max_free_states_;
// Mutex guarding the |free_list_|.
tensorflow::mutex mutex_;
// List of previously-Release()d states.
std::vector<std::unique_ptr<SessionState>> free_list_ GUARDED_BY(mutex_);
};
// RAII wrapper that manages a session state acquired from a pool. The wrapped
// state is usable during the lifetime of the wrapper.
class ScopedSessionState {
public:
// Implements RAII semantics.
explicit ScopedSessionState(SessionStatePool *pool)
: pool_(pool), state_(pool_->Acquire()) {}
~ScopedSessionState() { pool_->Release(std::move(state_)); }
// Prevents double-release.
ScopedSessionState(const ScopedSessionState &that) = delete;
ScopedSessionState &operator=(const ScopedSessionState &that) = delete;
// Provides std::unique_ptr-like access.
SessionState *get() const { return state_.get(); }
SessionState &operator*() const { return *get(); }
SessionState *operator->() const { return get(); }
private:
// Pool from which the |state_| was acquired.
SessionStatePool *const pool_;
// Wrapped session state.
std::unique_ptr<SessionState> state_;
};
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_SESSION_STATE_POOL_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/session_state_pool.h"
#include <stddef.h>
#include <set>
#include "dragnn/runtime/session_state.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Maximum number of free states.
static constexpr size_t kMaxFreeStates = 16;
class SessionStatePoolTest : public ::testing::Test {
protected:
SessionStatePool pool_{kMaxFreeStates};
};
// Tests that ScopedSessionState can be used to acquire a valid state.
TEST_F(SessionStatePoolTest, ScopedWrapper) {
const ScopedSessionState state(&pool_);
EXPECT_TRUE(state.get()); // non-null
}
// Tests that the active states claimed from the pool are unique.
TEST_F(SessionStatePoolTest, UniqueActiveStates) {
// NB: Don't use std::unique_ptr<ScopedSessionState> in real code. The test
// does this because it's otherwise difficult to acquire lots of states.
std::vector<std::unique_ptr<ScopedSessionState>> states;
for (size_t i = 0; i < 100; ++i) {
states.emplace_back(new ScopedSessionState(&pool_));
}
// Check that all of the states are unique.
std::set<const SessionState *> state_ptrs;
for (const auto &state : states) {
EXPECT_TRUE(state_ptrs.insert(state->get()).second);
}
EXPECT_TRUE(state_ptrs.find(nullptr) == state_ptrs.end());
}
// Tests that active states, when released, are reclaimed and reused.
TEST_F(SessionStatePoolTest, Reuse) {
std::set<const SessionState *> state_ptrs;
{ // Grab exactly as many states as the free list can hold.
std::vector<std::unique_ptr<ScopedSessionState>> states;
for (size_t i = 0; i < kMaxFreeStates; ++i) {
states.emplace_back(new ScopedSessionState(&pool_));
EXPECT_TRUE(state_ptrs.insert(states.back()->get()).second);
}
}
{ // Grab the same number of states again and check that they are the same
// objects we saw in the first loop.
std::vector<std::unique_ptr<ScopedSessionState>> states;
for (size_t i = 0; i < kMaxFreeStates; ++i) {
states.emplace_back(new ScopedSessionState(&pool_));
EXPECT_FALSE(state_ptrs.insert(states.back()->get()).second);
}
}
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <string>
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component_transformation.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Returns true if the |component_type| can be transformed by this.
bool ShouldTransform(const string &component_type) {
for (const char *supported_type : {
"SyntaxNetHeadSelectionComponent", //
"SyntaxNetMstSolverComponent", //
}) {
if (component_type == supported_type) return true;
}
return false;
}
// Changes the backend for some components to StatelessComponent.
class StatelessComponentTransformer : public ComponentTransformer {
public:
// Implements ComponentTransformer.
tensorflow::Status Transform(const string &component_type,
ComponentSpec *component_spec) override {
if (ShouldTransform(component_type)) {
component_spec->mutable_backend()->set_registered_name(
"StatelessComponent");
}
return tensorflow::Status::OK();
}
};
DRAGNN_RUNTIME_REGISTER_COMPONENT_TRANSFORMER(StatelessComponentTransformer);
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component_transformation.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Arbitrary supported component type.
constexpr char kSupportedComponentType[] = "SyntaxNetHeadSelectionComponent";
// Returns a ComponentSpec that is supported by the transformer.
ComponentSpec MakeSupportedSpec() {
ComponentSpec component_spec;
component_spec.mutable_component_builder()->set_registered_name(
kSupportedComponentType);
return component_spec;
}
// Tests that a compatible spec is modified to use StatelessComponent.
TEST(StatelessComponentTransformerTest, Compatible) {
ComponentSpec component_spec = MakeSupportedSpec();
ComponentSpec expected_spec = component_spec;
expected_spec.mutable_backend()->set_registered_name("StatelessComponent");
TF_ASSERT_OK(ComponentTransformer::ApplyAll(&component_spec));
EXPECT_THAT(component_spec, test::EqualsProto(expected_spec));
}
// Tests that other component specs are not modified.
TEST(StatelessComponentTransformerTest, Incompatible) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.mutable_component_builder()->set_registered_name("other");
const ComponentSpec expected_spec = component_spec;
TF_ASSERT_OK(ComponentTransformer::ApplyAll(&component_spec));
EXPECT_THAT(component_spec, test::EqualsProto(expected_spec));
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/io/sentence_input_batch.h"
#include "dragnn/io/syntaxnet_sentence.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "dragnn/runtime/term_map_sequence_extractor.h"
#include "dragnn/runtime/term_map_utils.h"
#include "dragnn/runtime/transition_system_traits.h"
#include "dragnn/runtime/unicode_dictionary.h"
#include "syntaxnet/base.h"
#include "syntaxnet/segmenter_utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "util/utf8/unicodetext.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Sequence extractor that extracts characters from a SyntaxNetComponent batch.
class SyntaxNetCharacterSequenceExtractor
: public TermMapSequenceExtractor<UnicodeDictionary> {
public:
SyntaxNetCharacterSequenceExtractor();
// Implements SequenceExtractor.
bool Supports(const FixedFeatureChannel &channel,
const ComponentSpec &component_spec) const override;
tensorflow::Status Initialize(const FixedFeatureChannel &channel,
const ComponentSpec &component_spec) override;
tensorflow::Status GetIds(InputBatchCache *input,
std::vector<int32> *ids) const override;
private:
// Parses |fml| and sets |min_frequency| and |max_num_terms| to the specified
// values. If the |fml| does not specify a supported feature, returns non-OK
// and modifies nothing.
static tensorflow::Status ParseFml(const string &fml, int *min_frequency,
int *max_num_terms);
// Feature IDs for break characters and unknown characters.
int32 break_id_ = -1;
int32 unknown_id_ = -1;
};
SyntaxNetCharacterSequenceExtractor::SyntaxNetCharacterSequenceExtractor()
: TermMapSequenceExtractor("char-map") {}
tensorflow::Status SyntaxNetCharacterSequenceExtractor::ParseFml(
const string &fml, int *min_frequency, int *max_num_terms) {
return ParseTermMapFml(fml, {"char-input", "text-char"}, min_frequency,
max_num_terms);
}
bool SyntaxNetCharacterSequenceExtractor::Supports(
const FixedFeatureChannel &channel,
const ComponentSpec &component_spec) const {
TransitionSystemTraits traits(component_spec);
int unused_min_frequency = 0;
int unused_max_num_terms = 0;
const tensorflow::Status parse_fml_status =
ParseFml(channel.fml(), &unused_min_frequency, &unused_max_num_terms);
return TermMapSequenceExtractor::SupportsTermMap(channel, component_spec) &&
parse_fml_status.ok() &&
component_spec.backend().registered_name() == "SyntaxNetComponent" &&
traits.is_sequential && traits.is_character_scale;
}
tensorflow::Status SyntaxNetCharacterSequenceExtractor::Initialize(
const FixedFeatureChannel &channel, const ComponentSpec &component_spec) {
int min_frequency = 0;
int max_num_terms = 0;
TF_RETURN_IF_ERROR(ParseFml(channel.fml(), &min_frequency, &max_num_terms));
TF_RETURN_IF_ERROR(TermMapSequenceExtractor::InitializeTermMap(
channel, component_spec, min_frequency, max_num_terms));
const int num_known = term_map().size();
break_id_ = num_known;
unknown_id_ = break_id_ + 1;
const int map_vocab_size = unknown_id_ + 1;
const int spec_vocab_size = channel.vocabulary_size();
if (map_vocab_size != spec_vocab_size) {
return tensorflow::errors::InvalidArgument(
"Character vocabulary size mismatch between term map (", map_vocab_size,
") and ComponentSpec (", spec_vocab_size, ")");
}
return tensorflow::Status::OK();
}
tensorflow::Status SyntaxNetCharacterSequenceExtractor::GetIds(
InputBatchCache *input, std::vector<int32> *ids) const {
ids->clear();
const std::vector<SyntaxNetSentence> &data =
*input->GetAs<SentenceInputBatch>()->data();
if (data.size() != 1) {
return tensorflow::errors::InvalidArgument("Non-singleton batch: got ",
data.size(), " elements");
}
const Sentence &sentence = *data[0].sentence();
if (sentence.token_size() == 0) return tensorflow::Status::OK();
const string &text = sentence.text();
const int start_byte = sentence.token(0).start();
const int end_byte = sentence.token(sentence.token_size() - 1).end();
const int num_bytes = end_byte - start_byte + 1;
string character;
UnicodeText unicode_text;
unicode_text.PointToUTF8(text.data() + start_byte, num_bytes);
const auto end = unicode_text.end();
for (auto it = unicode_text.begin(); it != end; ++it) {
character.assign(it.utf8_data(), it.utf8_length());
if (SegmenterUtils::IsBreakChar(character)) {
ids->push_back(break_id_);
} else {
ids->push_back(
term_map().Lookup(character.data(), character.size(), unknown_id_));
}
}
return tensorflow::Status::OK();
}
DRAGNN_RUNTIME_REGISTER_SEQUENCE_EXTRACTOR(SyntaxNetCharacterSequenceExtractor);
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/sequence_extractor.h"
#include "dragnn/runtime/test/term_map_helpers.h"
#include "syntaxnet/base.h"
#include "syntaxnet/sentence.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
constexpr char kResourceName[] = "char-map";
// Returns a ComponentSpec parsed from the |text| that contains a term map
// resource pointing at the |path|.
ComponentSpec MakeSpec(const string &text, const string &path) {
ComponentSpec component_spec;
CHECK(TextFormat::ParseFromString(text, &component_spec));
AddTermMapResource(kResourceName, path, &component_spec);
return component_spec;
}
// Returns a supported ComponentSpec that points at the term map in the |path|.
ComponentSpec MakeSupportedSpec(const string &path = "/dev/null") {
return MakeSpec(R"(transition_system { registered_name: 'char-shift-only' }
backend { registered_name: 'SyntaxNetComponent' }
fixed_feature {} # breaks hard-coded refs to channel 0
fixed_feature { size: 1 fml: 'char-input.text-char' })",
path);
}
// Returns a default sentence.
Sentence MakeSentence() {
Sentence sentence;
sentence.set_text("a bc def");
Token *token = sentence.add_token();
token->set_start(0);
token->set_end(sentence.text().size() - 1);
token->set_word(sentence.text());
return sentence;
}
// Tests that the extractor supports an appropriate spec.
TEST(SyntaxNetCharacterSequenceExtractorTest, Supported) {
const ComponentSpec component_spec = MakeSupportedSpec();
const FixedFeatureChannel &channel = component_spec.fixed_feature(1);
string name;
TF_ASSERT_OK(SequenceExtractor::Select(channel, component_spec, &name));
EXPECT_EQ(name, "SyntaxNetCharacterSequenceExtractor");
}
// Tests that the extractor requires the proper backend.
TEST(SyntaxNetCharacterSequenceExtractorTest, WrongBackend) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.mutable_backend()->set_registered_name("bad");
const FixedFeatureChannel &channel = component_spec.fixed_feature(1);
string name;
EXPECT_THAT(
SequenceExtractor::Select(channel, component_spec, &name),
test::IsErrorWithSubstr("No SequenceExtractor supports channel"));
}
// Tests that the extractor requires the proper transition system.
TEST(SyntaxNetCharacterSequenceExtractorTest, WrongTransitionSystem) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.mutable_transition_system()->set_registered_name("bad");
const FixedFeatureChannel &channel = component_spec.fixed_feature(1);
string name;
EXPECT_THAT(
SequenceExtractor::Select(channel, component_spec, &name),
test::IsErrorWithSubstr("No SequenceExtractor supports channel"));
}
// Tests that the extractor requires the proper FML.
TEST(SyntaxNetCharacterSequenceExtractorTest, WrongFml) {
ComponentSpec component_spec = MakeSupportedSpec();
component_spec.mutable_fixed_feature(1)->set_fml("bad");
const FixedFeatureChannel &channel = component_spec.fixed_feature(1);
string name;
EXPECT_THAT(
SequenceExtractor::Select(channel, component_spec, &name),
test::IsErrorWithSubstr("No SequenceExtractor supports channel"));
}
// Tests that the extractor can be initialized and used to extract feature IDs.
TEST(SyntaxNetCharacterSequenceExtractorTest, InitializeAndGetIds) {
// Terms are sorted by descending frequency, so this ensures a=0, b=1, etc.
const string path =
WriteTermMap({{"a", 5}, {"b", 4}, {"c", 3}, {"d", 2}, {"e", 1}});
ComponentSpec component_spec = MakeSupportedSpec(path);
FixedFeatureChannel &channel = *component_spec.mutable_fixed_feature(1);
channel.set_vocabulary_size(7);
std::unique_ptr<SequenceExtractor> extractor;
TF_ASSERT_OK(SequenceExtractor::New("SyntaxNetCharacterSequenceExtractor",
channel, component_spec, &extractor));
const Sentence sentence = MakeSentence();
InputBatchCache input(sentence.SerializeAsString());
std::vector<int32> ids;
TF_ASSERT_OK(extractor->GetIds(&input, &ids));
// 0-4 = 'a' to 'e'
// 5 = break chars (whitespace)
// 6 = unknown chars (e.g., 'f')
const std::vector<int32> expected_ids = {0, 5, 1, 2, 5, 3, 4, 6};
EXPECT_EQ(ids, expected_ids);
}
// Tests that an empty term map works.
TEST(SyntaxNetCharacterSequenceExtractorTest, EmptyTermMap) {
const string path = WriteTermMap({});
ComponentSpec component_spec = MakeSupportedSpec(path);
FixedFeatureChannel &channel = *component_spec.mutable_fixed_feature(1);
channel.set_vocabulary_size(2);
std::unique_ptr<SequenceExtractor> extractor;
TF_ASSERT_OK(SequenceExtractor::New("SyntaxNetCharacterSequenceExtractor",
channel, component_spec, &extractor));
const Sentence sentence = MakeSentence();
InputBatchCache input(sentence.SerializeAsString());
std::vector<int32> ids = {1, 2, 3, 4}; // should be overwritten
TF_ASSERT_OK(extractor->GetIds(&input, &ids));
const std::vector<int32> expected_ids = {1, 0, 1, 1, 0, 1, 1, 1};
EXPECT_EQ(ids, expected_ids);
}
// Tests that GetIds() fails if the batch is the wrong size.
TEST(SyntaxNetCharacterSequenceExtractorTest, WrongBatchSize) {
const string path = WriteTermMap({});
ComponentSpec component_spec = MakeSupportedSpec(path);
FixedFeatureChannel &channel = *component_spec.mutable_fixed_feature(1);
channel.set_vocabulary_size(2);
std::unique_ptr<SequenceExtractor> extractor;
TF_ASSERT_OK(SequenceExtractor::New("SyntaxNetCharacterSequenceExtractor",
channel, component_spec, &extractor));
const Sentence sentence = MakeSentence();
const std::vector<string> data = {sentence.SerializeAsString(),
sentence.SerializeAsString()};
InputBatchCache input(data);
std::vector<int32> ids;
EXPECT_THAT(extractor->GetIds(&input, &ids),
test::IsErrorWithSubstr("Non-singleton batch: got 2 elements"));
}
// Tests that initialization fails if the vocabulary size does not match.
TEST(SyntaxNetCharacterSequenceExtractorTest, WrongVocabularySize) {
const string path = WriteTermMap({});
ComponentSpec component_spec = MakeSupportedSpec(path);
FixedFeatureChannel &channel = *component_spec.mutable_fixed_feature(1);
channel.set_vocabulary_size(1000);
std::unique_ptr<SequenceExtractor> extractor;
EXPECT_THAT(
SequenceExtractor::New("SyntaxNetCharacterSequenceExtractor",
channel, component_spec, &extractor),
test::IsErrorWithSubstr("Character vocabulary size mismatch between term "
"map (2) and ComponentSpec (1000)"));
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stddef.h>
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/io/sentence_input_batch.h"
#include "dragnn/io/syntaxnet_sentence.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/sequence_linker.h"
#include "dragnn/runtime/transition_system_traits.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "util/utf8/unilib_utf8_utils.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Focus character to link to in each token.
enum class Focus {
kFirst, // first character in token
kLast, // last character in token
};
// Translator to apply to the linked character index.
enum class Translator {
kIdentity, // direct identity link
kReversed, // reverse-order link
};
// Returns the LinkedFeatureChannel.fml for the |focus|.
string ChannelFml(Focus focus) {
switch (focus) {
case Focus::kFirst:
return "input.first-char-focus";
case Focus::kLast:
return "input.last-char-focus";
}
}
// Returns the LinkedFeatureChannel.source_translator for the |translator|.
string ChannelTranslator(Translator translator) {
switch (translator) {
case Translator::kIdentity:
return "identity";
case Translator::kReversed:
return "reverse-char";
}
}
// Returns the |focus| byte index for the |token|. The returned index must be
// within the span of the |token|.
int32 GetFocusByte(Focus focus, const Token &token) {
switch (focus) {
case Focus::kFirst:
return token.start();
case Focus::kLast:
return token.end();
}
}
// Applies the |translator| to the character |index| w.r.t. the |last_index| and
// returns the result.
int32 Translate(Translator translator, int32 last_index, int32 index) {
switch (translator) {
case Translator::kIdentity:
return index;
case Translator::kReversed:
return last_index - index;
}
}
// Translates links from tokens in the target layer to UTF-8 characters in the
// source layer. Templated on a |focus| and |translator| (see above).
template <Focus focus, Translator translator>
class SyntaxNetCharacterSequenceLinker : public SequenceLinker {
public:
// Implements SequenceLinker.
bool Supports(const LinkedFeatureChannel &channel,
const ComponentSpec &component_spec) const override;
tensorflow::Status Initialize(const LinkedFeatureChannel &channel,
const ComponentSpec &component_spec) override;
tensorflow::Status GetLinks(size_t source_num_steps, InputBatchCache *input,
std::vector<int32> *links) const override;
};
template <Focus focus, Translator translator>
bool SyntaxNetCharacterSequenceLinker<focus, translator>::Supports(
const LinkedFeatureChannel &channel,
const ComponentSpec &component_spec) const {
TransitionSystemTraits traits(component_spec);
return channel.fml() == ChannelFml(focus) &&
channel.source_translator() == ChannelTranslator(translator) &&
component_spec.backend().registered_name() == "SyntaxNetComponent" &&
traits.is_sequential && traits.is_token_scale;
}
template <Focus focus, Translator translator>
tensorflow::Status
SyntaxNetCharacterSequenceLinker<focus, translator>::Initialize(
const LinkedFeatureChannel &channel, const ComponentSpec &component_spec) {
return tensorflow::Status::OK();
}
template <Focus focus, Translator translator>
tensorflow::Status
SyntaxNetCharacterSequenceLinker<focus, translator>::GetLinks(
size_t source_num_steps, InputBatchCache *input,
std::vector<int32> *links) const {
const std::vector<SyntaxNetSentence> &batch =
*input->GetAs<SentenceInputBatch>()->data();
if (batch.size() != 1) {
return tensorflow::errors::InvalidArgument("Non-singleton batch: got ",
batch.size(), " elements");
}
const Sentence &sentence = *batch[0].sentence();
const int32 num_tokens = sentence.token_size();
links->resize(num_tokens);
if (num_tokens == 0) return tensorflow::Status::OK();
// Given the properties selected in Supports(), the number of source steps
// must match the number of UTF-8 characters. The last character index will
// be used in Translate().
const int32 last_char_index = static_cast<int32>(source_num_steps) - 1;
// [start,end) byte range of the text spanned by the sentence tokens.
const int32 start_byte = sentence.token(0).start();
const int32 end_byte = sentence.token(num_tokens - 1).end() + 1;
const char *const data = sentence.text().data();
if (UniLib::IsTrailByte(data[start_byte])) {
return tensorflow::errors::InvalidArgument(
"First token starts in the middle of a UTF-8 character: ",
sentence.token(0).ShortDebugString());
}
// Current character index and its past-the-end byte in the sentence.
int32 char_index = 0;
int32 char_end_byte = start_byte + UniLib::OneCharLen(data + start_byte);
// Current token index and its byte index.
int32 token_index = 0;
int32 token_byte = GetFocusByte(focus, sentence.token(0));
// Scan through the characters and tokens. For each token, we assign it the
// character whose byte range contains its focus byte.
while (true) {
// If the character ends after the token, then the token must lie within the
// character, or we would have consumed the token in a previous iteration.
if (char_end_byte > token_byte) {
(*links)[token_index] =
Translate(translator, last_char_index, char_index);
if (++token_index >= num_tokens) break;
token_byte = GetFocusByte(focus, sentence.token(token_index));
} else if (char_end_byte < end_byte) {
++char_index;
char_end_byte += UniLib::OneCharLen(data + char_end_byte);
} else {
break;
}
}
if (char_end_byte > end_byte) {
return tensorflow::errors::InvalidArgument(
"Last token ends in the middle of a UTF-8 character: ",
sentence.token(num_tokens - 1).ShortDebugString());
}
// Since GetFocusByte() always returns a byte index within the span of the
// token, the loop above must consume all tokens.
DCHECK_EQ(token_index, num_tokens);
return tensorflow::Status::OK();
}
using SyntaxNetFirstCharacterIdentitySequenceLinker =
SyntaxNetCharacterSequenceLinker<Focus::kFirst, Translator::kIdentity>;
using SyntaxNetFirstCharacterReversedSequenceLinker =
SyntaxNetCharacterSequenceLinker<Focus::kFirst, Translator::kReversed>;
using SyntaxNetLastCharacterIdentitySequenceLinker =
SyntaxNetCharacterSequenceLinker<Focus::kLast, Translator::kIdentity>;
using SyntaxNetLastCharacterReversedSequenceLinker =
SyntaxNetCharacterSequenceLinker<Focus::kLast, Translator::kReversed>;
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(
SyntaxNetFirstCharacterIdentitySequenceLinker);
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(
SyntaxNetFirstCharacterReversedSequenceLinker);
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(
SyntaxNetLastCharacterIdentitySequenceLinker);
DRAGNN_RUNTIME_REGISTER_SEQUENCE_LINKER(
SyntaxNetLastCharacterReversedSequenceLinker);
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <string>
#include <vector>
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/sequence_linker.h"
#include "syntaxnet/base.h"
#include "syntaxnet/sentence.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
using ::testing::ElementsAre;
// Returns a ComponentSpec parsed from the |text|.
ComponentSpec ParseSpec(const string &text) {
ComponentSpec component_spec;
CHECK(TextFormat::ParseFromString(text, &component_spec));
return component_spec;
}
// Returns a ComponentSpec that some linker supports.
ComponentSpec MakeSupportedSpec() {
return ParseSpec(R"(
transition_system { registered_name:'shift-only' }
backend { registered_name:'SyntaxNetComponent' }
linked_feature { fml:'input.first-char-focus' source_translator:'identity' }
)");
}
// Returns a Sentence parsed from the |text|.
Sentence ParseSentence(const string &text) {
Sentence sentence;
CHECK(TextFormat::ParseFromString(text, &sentence));
return sentence;
}
// Returns a default sentence.
Sentence MakeSentence() {
return ParseSentence(R"(
text:'012345678901234567890123456789人1工神2经网¢络'
token { start:30 end:36 word:'人1工' }
token { start:37 end:43 word:'神2经' }
token { start:44 end:51 word:'网¢络' }
)");
}
// Number of UTF-8 characters in the default sentence.
constexpr int kNumChars = 9;
// Tests that the linker supports appropriate specs.
TEST(SyntaxNetCharacterSequenceLinkersTest, Supported) {
string name;
ComponentSpec component_spec = MakeSupportedSpec();
LinkedFeatureChannel &channel = *component_spec.mutable_linked_feature(0);
TF_ASSERT_OK(SequenceLinker::Select(channel, component_spec, &name));
EXPECT_EQ(name, "SyntaxNetFirstCharacterIdentitySequenceLinker");
channel.set_source_translator("reverse-char");
TF_ASSERT_OK(SequenceLinker::Select(channel, component_spec, &name));
EXPECT_EQ(name, "SyntaxNetFirstCharacterReversedSequenceLinker");
channel.set_fml("input.last-char-focus");
TF_ASSERT_OK(SequenceLinker::Select(channel, component_spec, &name));
EXPECT_EQ(name, "SyntaxNetLastCharacterReversedSequenceLinker");
channel.set_source_translator("identity");
TF_ASSERT_OK(SequenceLinker::Select(channel, component_spec, &name));
EXPECT_EQ(name, "SyntaxNetLastCharacterIdentitySequenceLinker");
}
// Tests that the linker requires the right transition system.
TEST(SyntaxNetCharacterSequenceLinkersTest, WrongTransitionSystem) {
string name;
ComponentSpec component_spec = MakeSupportedSpec();
const LinkedFeatureChannel &channel = component_spec.linked_feature(0);
component_spec.mutable_backend()->set_registered_name("bad");
EXPECT_THAT(SequenceLinker::Select(channel, component_spec, &name),
test::IsErrorWithSubstr("No SequenceLinker supports channel"));
}
// Tests that the linker requires the right FML.
TEST(SyntaxNetCharacterSequenceLinkersTest, WrongFml) {
string name;
ComponentSpec component_spec = MakeSupportedSpec();
LinkedFeatureChannel &channel = *component_spec.mutable_linked_feature(0);
channel.set_fml("bad");
EXPECT_THAT(SequenceLinker::Select(channel, component_spec, &name),
test::IsErrorWithSubstr("No SequenceLinker supports channel"));
}
// Tests that the linker requires the right translator.
TEST(SyntaxNetCharacterSequenceLinkersTest, WrongTranslator) {
string name;
ComponentSpec component_spec = MakeSupportedSpec();
LinkedFeatureChannel &channel = *component_spec.mutable_linked_feature(0);
channel.set_source_translator("bad");
EXPECT_THAT(SequenceLinker::Select(channel, component_spec, &name),
test::IsErrorWithSubstr("No SequenceLinker supports channel"));
}
// Tests that the linker requires the right backend.
TEST(SyntaxNetCharacterSequenceLinkersTest, WrongBackend) {
string name;
ComponentSpec component_spec = MakeSupportedSpec();
const LinkedFeatureChannel &channel = component_spec.linked_feature(0);
component_spec.mutable_backend()->set_registered_name("bad");
EXPECT_THAT(SequenceLinker::Select(channel, component_spec, &name),
test::IsErrorWithSubstr("No SequenceLinker supports channel"));
}
// Rig for testing GetLinks().
class SyntaxNetCharacterSequenceLinkersGetLinksTest : public ::testing::Test {
protected:
void SetUp() override {
// Initialize() doesn't look at the channel or spec, so use empty protos.
const ComponentSpec component_spec;
const LinkedFeatureChannel channel;
TF_ASSERT_OK(
SequenceLinker::New("SyntaxNetFirstCharacterIdentitySequenceLinker",
channel, component_spec, &first_identity_));
TF_ASSERT_OK(
SequenceLinker::New("SyntaxNetFirstCharacterReversedSequenceLinker",
channel, component_spec, &first_reversed_));
TF_ASSERT_OK(
SequenceLinker::New("SyntaxNetLastCharacterIdentitySequenceLinker",
channel, component_spec, &last_identity_));
TF_ASSERT_OK(
SequenceLinker::New("SyntaxNetLastCharacterReversedSequenceLinker",
channel, component_spec, &last_reversed_));
}
// Linkers in all four configurations.
std::unique_ptr<SequenceLinker> first_identity_;
std::unique_ptr<SequenceLinker> first_reversed_;
std::unique_ptr<SequenceLinker> last_identity_;
std::unique_ptr<SequenceLinker> last_reversed_;
};
// Tests that the linkers can extract links from the default sentence.
TEST_F(SyntaxNetCharacterSequenceLinkersGetLinksTest, DefaultSentence) {
const Sentence sentence = MakeSentence();
InputBatchCache input(sentence.SerializeAsString());
std::vector<int32> links = {123, 456, 789}; // gets overwritten
TF_ASSERT_OK(first_identity_->GetLinks(kNumChars, &input, &links));
EXPECT_THAT(links, ElementsAre(0, 3, 6));
TF_ASSERT_OK(first_reversed_->GetLinks(kNumChars, &input, &links));
EXPECT_THAT(links, ElementsAre(8, 5, 2));
TF_ASSERT_OK(last_identity_->GetLinks(kNumChars, &input, &links));
EXPECT_THAT(links, ElementsAre(2, 5, 8));
TF_ASSERT_OK(last_reversed_->GetLinks(kNumChars, &input, &links));
EXPECT_THAT(links, ElementsAre(6, 3, 0));
}
// Tests that the linkers can handle an empty sentence.
TEST_F(SyntaxNetCharacterSequenceLinkersGetLinksTest, EmptySentence) {
const Sentence sentence;
InputBatchCache input(sentence.SerializeAsString());
std::vector<int32> links;
TF_ASSERT_OK(first_identity_->GetLinks(kNumChars, &input, &links));
TF_ASSERT_OK(first_reversed_->GetLinks(kNumChars, &input, &links));
TF_ASSERT_OK(last_identity_->GetLinks(kNumChars, &input, &links));
TF_ASSERT_OK(last_reversed_->GetLinks(kNumChars, &input, &links));
}
// Tests that the linkers fail if the batch is not a singleton.
TEST_F(SyntaxNetCharacterSequenceLinkersGetLinksTest, NonSingleton) {
const Sentence sentence = MakeSentence();
const std::vector<string> data = {sentence.SerializeAsString(),
sentence.SerializeAsString()};
InputBatchCache input(data);
std::vector<int32> links;
EXPECT_THAT(first_identity_->GetLinks(kNumChars, &input, &links),
test::IsErrorWithSubstr("Non-singleton batch: got 2 elements"));
EXPECT_THAT(first_reversed_->GetLinks(kNumChars, &input, &links),
test::IsErrorWithSubstr("Non-singleton batch: got 2 elements"));
EXPECT_THAT(last_identity_->GetLinks(kNumChars, &input, &links),
test::IsErrorWithSubstr("Non-singleton batch: got 2 elements"));
EXPECT_THAT(last_reversed_->GetLinks(kNumChars, &input, &links),
test::IsErrorWithSubstr("Non-singleton batch: got 2 elements"));
}
// Tests that the linkers fail if the first token starts in the middle of a
// UTF-8 character.
TEST_F(SyntaxNetCharacterSequenceLinkersGetLinksTest, FirstTokenStartsWrong) {
Sentence sentence = MakeSentence();
sentence.mutable_token(0)->set_start(31);
InputBatchCache input(sentence.SerializeAsString());
std::vector<int32> links;
EXPECT_THAT(first_identity_->GetLinks(kNumChars, &input, &links),
test::IsErrorWithSubstr(
"First token starts in the middle of a UTF-8 character"));
EXPECT_THAT(first_reversed_->GetLinks(kNumChars, &input, &links),
test::IsErrorWithSubstr(
"First token starts in the middle of a UTF-8 character"));
EXPECT_THAT(last_identity_->GetLinks(kNumChars, &input, &links),
test::IsErrorWithSubstr(
"First token starts in the middle of a UTF-8 character"));
EXPECT_THAT(last_reversed_->GetLinks(kNumChars, &input, &links),
test::IsErrorWithSubstr(
"First token starts in the middle of a UTF-8 character"));
}
// Tests that the linkers fail if the last token ends in the middle of a UTF-8
// character.
TEST_F(SyntaxNetCharacterSequenceLinkersGetLinksTest, LastTokenEndsWrong) {
Sentence sentence = MakeSentence();
sentence.mutable_token(2)->set_end(45);
InputBatchCache input(sentence.SerializeAsString());
std::vector<int32> links;
EXPECT_THAT(first_identity_->GetLinks(kNumChars, &input, &links),
test::IsErrorWithSubstr(
"Last token ends in the middle of a UTF-8 character"));
EXPECT_THAT(first_reversed_->GetLinks(kNumChars, &input, &links),
test::IsErrorWithSubstr(
"Last token ends in the middle of a UTF-8 character"));
EXPECT_THAT(last_identity_->GetLinks(kNumChars, &input, &links),
test::IsErrorWithSubstr(
"Last token ends in the middle of a UTF-8 character"));
EXPECT_THAT(last_reversed_->GetLinks(kNumChars, &input, &links),
test::IsErrorWithSubstr(
"Last token ends in the middle of a UTF-8 character"));
}
// Tests that the linkers can tolerate a sentence where the interior token byte
// offsets are wrong.
TEST_F(SyntaxNetCharacterSequenceLinkersGetLinksTest,
InteriorTokenBoundariesSlightlyWrong) {
Sentence sentence = MakeSentence();
sentence.mutable_token(0)->set_end(35);
sentence.mutable_token(1)->set_start(38);
sentence.mutable_token(1)->set_end(42);
sentence.mutable_token(2)->set_start(45);
InputBatchCache input(sentence.SerializeAsString());
std::vector<int32> links;
// The results should be the same as in the default sentence.
TF_ASSERT_OK(first_identity_->GetLinks(kNumChars, &input, &links));
EXPECT_THAT(links, ElementsAre(0, 3, 6));
TF_ASSERT_OK(first_reversed_->GetLinks(kNumChars, &input, &links));
EXPECT_THAT(links, ElementsAre(8, 5, 2));
TF_ASSERT_OK(last_identity_->GetLinks(kNumChars, &input, &links));
EXPECT_THAT(links, ElementsAre(2, 5, 8));
TF_ASSERT_OK(last_reversed_->GetLinks(kNumChars, &input, &links));
EXPECT_THAT(links, ElementsAre(6, 3, 0));
}
// As above, but places the token boundaries even further off.
TEST_F(SyntaxNetCharacterSequenceLinkersGetLinksTest,
InteriorTokenBoundariesMostlyWrong) {
Sentence sentence = MakeSentence();
sentence.mutable_token(0)->set_end(34);
sentence.mutable_token(1)->set_start(39);
sentence.mutable_token(1)->set_end(41);
sentence.mutable_token(2)->set_start(46);
InputBatchCache input(sentence.SerializeAsString());
std::vector<int32> links;
// The results should be the same as in the default sentence.
TF_ASSERT_OK(first_identity_->GetLinks(kNumChars, &input, &links));
EXPECT_THAT(links, ElementsAre(0, 3, 6));
TF_ASSERT_OK(first_reversed_->GetLinks(kNumChars, &input, &links));
EXPECT_THAT(links, ElementsAre(8, 5, 2));
TF_ASSERT_OK(last_identity_->GetLinks(kNumChars, &input, &links));
EXPECT_THAT(links, ElementsAre(2, 5, 8));
TF_ASSERT_OK(last_reversed_->GetLinks(kNumChars, &input, &links));
EXPECT_THAT(links, ElementsAre(6, 3, 0));
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <string>
#include <vector>
#include "dragnn/core/compute_session.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/io/sentence_input_batch.h"
#include "dragnn/io/syntaxnet_sentence.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/head_selection_component_base.h"
#include "dragnn/runtime/session_state.h"
#include "syntaxnet/base.h"
#include "syntaxnet/sentence.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Selects heads for SyntaxNetComponent batches.
class SyntaxNetHeadSelectionComponent : public HeadSelectionComponentBase {
public:
SyntaxNetHeadSelectionComponent()
: HeadSelectionComponentBase("SyntaxNetHeadSelectionComponent",
"SyntaxNetComponent") {}
// Implements Component.
tensorflow::Status Evaluate(SessionState *session_state,
ComputeSession *compute_session,
ComponentTrace *component_trace) const override;
};
tensorflow::Status SyntaxNetHeadSelectionComponent::Evaluate(
SessionState *session_state, ComputeSession *compute_session,
ComponentTrace *component_trace) const {
InputBatchCache *input = compute_session->GetInputBatchCache();
if (input == nullptr) {
return tensorflow::errors::InvalidArgument("Null input batch");
}
const std::vector<SyntaxNetSentence> &data =
*input->GetAs<SentenceInputBatch>()->data();
if (data.size() != 1) {
return tensorflow::errors::InvalidArgument("Non-singleton batch: got ",
data.size(), " elements");
}
const std::vector<int> &heads = ComputeHeads(session_state);
Sentence *sentence = data[0].sentence();
if (heads.size() != sentence->token_size()) {
return tensorflow::errors::InvalidArgument(
"Sentence size mismatch: expected ", heads.size(), " tokens but got ",
sentence->token_size());
}
int token_index = 0;
for (const int head : heads) {
Token *token = sentence->mutable_token(token_index++);
if (head == -1) {
token->clear_head();
} else {
token->set_head(head);
}
}
return tensorflow::Status::OK();
}
DRAGNN_RUNTIME_REGISTER_COMPONENT(SyntaxNetHeadSelectionComponent);
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/compute_session.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/io/sentence_input_batch.h"
#include "dragnn/io/syntaxnet_sentence.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/sentence.pb.h"
#include <gmock/gmock.h>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
using ::testing::Return;
constexpr char kPreviousComponentName[] = "previous_component";
constexpr char kAdjacencyLayerName[] = "adjacency_layer";
// Returns a ComponentSpec that works with the head selection component.
ComponentSpec MakeGoodSpec() {
ComponentSpec component_spec;
component_spec.mutable_component_builder()->set_registered_name(
"SyntaxNetHeadSelectionComponent");
component_spec.mutable_backend()->set_registered_name("SyntaxNetComponent");
component_spec.mutable_transition_system()->set_registered_name("heads");
component_spec.mutable_network_unit()->set_registered_name("IdentityNetwork");
LinkedFeatureChannel *link = component_spec.add_linked_feature();
link->set_source_component(kPreviousComponentName);
link->set_source_layer(kAdjacencyLayerName);
return component_spec;
}
// Returns a sentence containing |num_tokens| tokens. All heads are set to
// self-loops, which are normally invalid, to ensure that the head selector
// touches all tokens.
Sentence MakeSentence(int num_tokens) {
Sentence sentence;
for (int i = 0; i < num_tokens; ++i) {
Token *token = sentence.add_token();
token->set_start(0); // never used; set because required field
token->set_end(0); // never used; set because required field
token->set_word("foo"); // never used; set because required field
token->set_head(i);
}
return sentence;
}
class SyntaxNetHeadSelectionComponentTest : public NetworkTestBase {
protected:
// Initializes a parser head selection component from the |component_spec|,
// feeds it the |adjacency| matrix, and applies the resulting heads to the
// |sentence|. Returs non-OK on error.
tensorflow::Status Run(const ComponentSpec &component_spec,
const std::vector<std::vector<float>> &adjacency,
Sentence *sentence) {
AddComponent(kPreviousComponentName);
AddPairwiseLayer(kAdjacencyLayerName, 1);
std::unique_ptr<Component> component;
TF_RETURN_IF_ERROR(Component::CreateOrError(
"SyntaxNetHeadSelectionComponent", &component));
TF_RETURN_IF_ERROR(component->Initialize(component_spec, &variable_store_,
&network_state_manager_,
&extension_manager_));
network_states_.Reset(&network_state_manager_);
const int num_steps = adjacency.size();
StartComponent(num_steps);
MutableMatrix<float> adjacency_layer =
GetPairwiseLayer(kPreviousComponentName, kAdjacencyLayerName);
for (size_t target = 0; target < num_steps; ++target) {
for (size_t source = 0; source < num_steps; ++source) {
adjacency_layer.row(target)[source] = adjacency[target][source];
}
}
string data;
CHECK(sentence->SerializeToString(&data));
InputBatchCache input(data);
EXPECT_CALL(compute_session_, GetInputBatchCache())
.WillRepeatedly(Return(&input));
session_state_.extensions.Reset(&extension_manager_);
TF_RETURN_IF_ERROR(
component->Evaluate(&session_state_, &compute_session_, nullptr));
CHECK(sentence->ParseFromString(input.SerializedData()[0]));
return tensorflow::Status::OK();
}
};
// Tests the head selector on a single-token input.
TEST_F(SyntaxNetHeadSelectionComponentTest, ParseOneToken) {
const std::vector<std::vector<float>> adjacency = {{0.0}};
Sentence sentence = MakeSentence(1);
TF_ASSERT_OK(Run(MakeGoodSpec(), adjacency, &sentence));
EXPECT_FALSE(sentence.token(0).has_head());
}
// Tests the head selector on a two-token input.
TEST_F(SyntaxNetHeadSelectionComponentTest, ParseTwoTokens) {
// This adjacency matrix forms a cycle, not a tree, but it doesn't matter
// since the head selector is unstructured.
const std::vector<std::vector<float>> adjacency = {{0.0, 1.0}, //
{1.0, 0.0}};
Sentence sentence = MakeSentence(2);
TF_ASSERT_OK(Run(MakeGoodSpec(), adjacency, &sentence));
EXPECT_EQ(sentence.token(0).head(), 1);
EXPECT_EQ(sentence.token(1).head(), 0);
}
// Tests the head selector on a three-token input.
TEST_F(SyntaxNetHeadSelectionComponentTest, ParseThreeTokens) {
// This adjacency matrix forms a left-headed chain.
const std::vector<std::vector<float>> adjacency = {{1.0, 0.0, 0.0}, //
{1.0, 0.0, 0.0}, //
{0.0, 1.0, 0.0}};
Sentence sentence = MakeSentence(3);
TF_ASSERT_OK(Run(MakeGoodSpec(), adjacency, &sentence));
EXPECT_FALSE(sentence.token(0).has_head());
EXPECT_EQ(sentence.token(1).head(), 0);
EXPECT_EQ(sentence.token(2).head(), 1);
}
// Tests the head selector on a four-token input.
TEST_F(SyntaxNetHeadSelectionComponentTest, ParseFourTokens) {
// This adjacency matrix forms a right-headed chain.
const std::vector<std::vector<float>> adjacency = {{0.0, 1.0, 0.0, 0.0}, //
{0.0, 0.0, 1.0, 0.0}, //
{0.0, 0.0, 0.0, 1.0}, //
{0.0, 0.0, 0.0, 1.0}};
Sentence sentence = MakeSentence(4);
TF_ASSERT_OK(Run(MakeGoodSpec(), adjacency, &sentence));
EXPECT_EQ(sentence.token(0).head(), 1);
EXPECT_EQ(sentence.token(1).head(), 2);
EXPECT_EQ(sentence.token(2).head(), 3);
EXPECT_FALSE(sentence.token(3).has_head());
}
// Tests that the component supports the good spec.
TEST_F(SyntaxNetHeadSelectionComponentTest, Supported) {
const ComponentSpec component_spec = MakeGoodSpec();
string name;
TF_ASSERT_OK(Component::Select(component_spec, &name));
EXPECT_EQ(name, "SyntaxNetHeadSelectionComponent");
}
// Tests that the component requires the proper backend.
TEST_F(SyntaxNetHeadSelectionComponentTest, WrongComponentBuilder) {
ComponentSpec component_spec = MakeGoodSpec();
component_spec.mutable_component_builder()->set_registered_name("bad");
string name;
EXPECT_THAT(
Component::Select(component_spec, &name),
test::IsErrorWithSubstr("Could not find a best spec for component"));
}
// Tests that the component requires the proper backend.
TEST_F(SyntaxNetHeadSelectionComponentTest, WrongBackend) {
ComponentSpec component_spec = MakeGoodSpec();
component_spec.mutable_backend()->set_registered_name("bad");
string name;
EXPECT_THAT(
Component::Select(component_spec, &name),
test::IsErrorWithSubstr("Could not find a best spec for component"));
}
// Tests that Evaluate() fails if the batch is null.
TEST_F(SyntaxNetHeadSelectionComponentTest, NullBatch) {
std::unique_ptr<Component> component;
TF_ASSERT_OK(
Component::CreateOrError("SyntaxNetHeadSelectionComponent", &component));
EXPECT_CALL(compute_session_, GetInputBatchCache())
.WillRepeatedly(Return(nullptr));
EXPECT_THAT(component->Evaluate(&session_state_, &compute_session_, nullptr),
test::IsErrorWithSubstr("Null input batch"));
}
// Tests that Evaluate() fails if the batch is the wrong size.
TEST_F(SyntaxNetHeadSelectionComponentTest, WrongBatchSize) {
std::unique_ptr<Component> component;
TF_ASSERT_OK(
Component::CreateOrError("SyntaxNetHeadSelectionComponent", &component));
InputBatchCache input({MakeSentence(1).SerializeAsString(),
MakeSentence(2).SerializeAsString(),
MakeSentence(3).SerializeAsString(),
MakeSentence(4).SerializeAsString()});
EXPECT_CALL(compute_session_, GetInputBatchCache())
.WillRepeatedly(Return(&input));
EXPECT_THAT(component->Evaluate(&session_state_, &compute_session_, nullptr),
test::IsErrorWithSubstr("Non-singleton batch: got 4 elements"));
}
// Tests that Evaluate() fails if the adjacency matrix and sentence disagree on
// the number of tokens.
TEST_F(SyntaxNetHeadSelectionComponentTest, WrongNumTokens) {
const std::vector<std::vector<float>> adjacency = {{1.0, 0.0, 0.0, 0.0}, //
{0.0, 1.0, 0.0, 0.0}, //
{0.0, 0.0, 1.0, 0.0}, //
{0.0, 0.0, 0.0, 1.0}};
// 4-token adjacency matrix with 3-token sentence.
Sentence sentence = MakeSentence(3);
EXPECT_THAT(Run(MakeGoodSpec(), adjacency, &sentence),
test::IsErrorWithSubstr(
"Sentence size mismatch: expected 4 tokens but got 3"));
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <string>
#include <vector>
#include "dragnn/core/compute_session.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/io/sentence_input_batch.h"
#include "dragnn/io/syntaxnet_sentence.h"
#include "dragnn/protos/trace.pb.h"
#include "dragnn/runtime/mst_solver_component_base.h"
#include "dragnn/runtime/session_state.h"
#include "syntaxnet/base.h"
#include "syntaxnet/sentence.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Selects heads for SyntaxNetComponent batches.
class SyntaxNetMstSolverComponent : public MstSolverComponentBase {
public:
SyntaxNetMstSolverComponent()
: MstSolverComponentBase("SyntaxNetMstSolverComponent",
"SyntaxNetComponent") {}
// Implements Component.
tensorflow::Status Evaluate(SessionState *session_state,
ComputeSession *compute_session,
ComponentTrace *component_trace) const override;
};
tensorflow::Status SyntaxNetMstSolverComponent::Evaluate(
SessionState *session_state, ComputeSession *compute_session,
ComponentTrace *component_trace) const {
InputBatchCache *input = compute_session->GetInputBatchCache();
if (input == nullptr) {
return tensorflow::errors::InvalidArgument("Null input batch");
}
const std::vector<SyntaxNetSentence> &data =
*input->GetAs<SentenceInputBatch>()->data();
if (data.size() != 1) {
return tensorflow::errors::InvalidArgument("Non-singleton batch: got ",
data.size(), " elements");
}
tensorflow::gtl::ArraySlice<Index> heads;
TF_RETURN_IF_ERROR(ComputeHeads(session_state, &heads));
Sentence *sentence = data[0].sentence();
if (heads.size() != sentence->token_size()) {
return tensorflow::errors::InvalidArgument(
"Sentence size mismatch: expected ", heads.size(), " tokens but got ",
sentence->token_size());
}
const int num_tokens = heads.size();
for (int modifier = 0; modifier < num_tokens; ++modifier) {
Token *token = sentence->mutable_token(modifier);
const int head = heads[modifier];
if (head == modifier) {
token->clear_head();
} else {
token->set_head(head);
}
}
return tensorflow::Status::OK();
}
DRAGNN_RUNTIME_REGISTER_COMPONENT(SyntaxNetMstSolverComponent);
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <memory>
#include <string>
#include <vector>
#include "dragnn/core/compute_session.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/io/sentence_input_batch.h"
#include "dragnn/io/syntaxnet_sentence.h"
#include "dragnn/protos/spec.pb.h"
#include "dragnn/runtime/component.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/test/network_test_base.h"
#include "dragnn/runtime/variable_store.h"
#include "syntaxnet/sentence.pb.h"
#include <gmock/gmock.h>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
using ::testing::Return;
constexpr char kPreviousComponentName[] = "previous_component";
constexpr char kAdjacencyLayerName[] = "adjacency_layer";
// Returns a ComponentSpec that works with the head selection component.
ComponentSpec MakeGoodSpec() {
ComponentSpec component_spec;
component_spec.mutable_component_builder()->set_registered_name(
"SyntaxNetMstSolverComponent");
component_spec.mutable_backend()->set_registered_name("SyntaxNetComponent");
component_spec.mutable_transition_system()->set_registered_name("heads");
component_spec.mutable_network_unit()->set_registered_name(
"some.path.to.MstSolverNetwork");
LinkedFeatureChannel *link = component_spec.add_linked_feature();
link->set_source_component(kPreviousComponentName);
link->set_source_layer(kAdjacencyLayerName);
return component_spec;
}
// Returns a sentence containing |num_tokens| tokens. All heads are set to
// self-loops, which are normally invalid, to ensure that the head selector
// touches all tokens.
Sentence MakeSentence(int num_tokens) {
Sentence sentence;
for (int i = 0; i < num_tokens; ++i) {
Token *token = sentence.add_token();
token->set_start(0); // never used; set because required field
token->set_end(0); // never used; set because required field
token->set_word("foo"); // never used; set because required field
token->set_head(i);
}
return sentence;
}
class SyntaxNetMstSolverComponentTest : public NetworkTestBase {
protected:
// Initializes a parser head selection component from the |component_spec|,
// feeds it the |adjacency| matrix, and applies the resulting heads to the
// |sentence|. Returs non-OK on error.
tensorflow::Status Run(const ComponentSpec &component_spec,
const std::vector<std::vector<float>> &adjacency,
Sentence *sentence) {
AddComponent(kPreviousComponentName);
AddPairwiseLayer(kAdjacencyLayerName, 1);
std::unique_ptr<Component> component;
TF_RETURN_IF_ERROR(Component::CreateOrError(
"SyntaxNetMstSolverComponent", &component));
TF_RETURN_IF_ERROR(component->Initialize(component_spec, &variable_store_,
&network_state_manager_,
&extension_manager_));
network_states_.Reset(&network_state_manager_);
const int num_steps = adjacency.size();
StartComponent(num_steps);
MutableMatrix<float> adjacency_layer =
GetPairwiseLayer(kPreviousComponentName, kAdjacencyLayerName);
for (size_t target = 0; target < num_steps; ++target) {
for (size_t source = 0; source < num_steps; ++source) {
adjacency_layer.row(target)[source] = adjacency[target][source];
}
}
string data;
CHECK(sentence->SerializeToString(&data));
InputBatchCache input(data);
EXPECT_CALL(compute_session_, GetInputBatchCache())
.WillRepeatedly(Return(&input));
session_state_.extensions.Reset(&extension_manager_);
TF_RETURN_IF_ERROR(
component->Evaluate(&session_state_, &compute_session_, nullptr));
CHECK(sentence->ParseFromString(input.SerializedData()[0]));
return tensorflow::Status::OK();
}
};
// Tests the head selector on a single-token input.
TEST_F(SyntaxNetMstSolverComponentTest, ParseOneToken) {
const std::vector<std::vector<float>> adjacency = {{0.0}};
Sentence sentence = MakeSentence(1);
TF_ASSERT_OK(Run(MakeGoodSpec(), adjacency, &sentence));
EXPECT_FALSE(sentence.token(0).has_head());
}
// Tests the head selector on a two-token input.
TEST_F(SyntaxNetMstSolverComponentTest, ParseTwoTokens) {
const std::vector<std::vector<float>> adjacency = {{0.0, 1.0}, //
{0.9, 1.0}};
Sentence sentence = MakeSentence(2);
TF_ASSERT_OK(Run(MakeGoodSpec(), adjacency, &sentence));
EXPECT_EQ(sentence.token(0).head(), 1);
EXPECT_EQ(sentence.token(1).head(), -1);
}
// Tests the head selector on a three-token input.
TEST_F(SyntaxNetMstSolverComponentTest, ParseThreeTokens) {
// This adjacency matrix forms a left-headed chain.
const std::vector<std::vector<float>> adjacency = {{1.0, 0.0, 0.0}, //
{1.0, 0.0, 0.0}, //
{0.0, 1.0, 0.0}};
Sentence sentence = MakeSentence(3);
TF_ASSERT_OK(Run(MakeGoodSpec(), adjacency, &sentence));
EXPECT_FALSE(sentence.token(0).has_head());
EXPECT_EQ(sentence.token(1).head(), 0);
EXPECT_EQ(sentence.token(2).head(), 1);
}
// Tests the head selector on a four-token input.
TEST_F(SyntaxNetMstSolverComponentTest, ParseFourTokens) {
// This adjacency matrix forms a right-headed chain.
const std::vector<std::vector<float>> adjacency = {{0.0, 1.0, 0.0, 0.0}, //
{0.0, 0.0, 1.0, 0.0}, //
{0.0, 0.0, 0.0, 1.0}, //
{0.0, 0.0, 0.0, 1.0}};
Sentence sentence = MakeSentence(4);
TF_ASSERT_OK(Run(MakeGoodSpec(), adjacency, &sentence));
EXPECT_EQ(sentence.token(0).head(), 1);
EXPECT_EQ(sentence.token(1).head(), 2);
EXPECT_EQ(sentence.token(2).head(), 3);
EXPECT_FALSE(sentence.token(3).has_head());
}
// Tests that the component supports the good spec.
TEST_F(SyntaxNetMstSolverComponentTest, Supported) {
const ComponentSpec component_spec = MakeGoodSpec();
string name;
TF_ASSERT_OK(Component::Select(component_spec, &name));
EXPECT_EQ(name, "SyntaxNetMstSolverComponent");
}
// Tests that the component requires the proper backend.
TEST_F(SyntaxNetMstSolverComponentTest, WrongComponentBuilder) {
ComponentSpec component_spec = MakeGoodSpec();
component_spec.mutable_component_builder()->set_registered_name("bad");
string name;
EXPECT_THAT(
Component::Select(component_spec, &name),
test::IsErrorWithSubstr("Could not find a best spec for component"));
}
// Tests that the component requires the proper backend.
TEST_F(SyntaxNetMstSolverComponentTest, WrongBackend) {
ComponentSpec component_spec = MakeGoodSpec();
component_spec.mutable_backend()->set_registered_name("bad");
string name;
EXPECT_THAT(
Component::Select(component_spec, &name),
test::IsErrorWithSubstr("Could not find a best spec for component"));
}
// Tests that Evaluate() fails if the batch is null.
TEST_F(SyntaxNetMstSolverComponentTest, NullBatch) {
std::unique_ptr<Component> component;
TF_ASSERT_OK(
Component::CreateOrError("SyntaxNetMstSolverComponent", &component));
EXPECT_CALL(compute_session_, GetInputBatchCache())
.WillRepeatedly(Return(nullptr));
EXPECT_THAT(component->Evaluate(&session_state_, &compute_session_, nullptr),
test::IsErrorWithSubstr("Null input batch"));
}
// Tests that Evaluate() fails if the batch is the wrong size.
TEST_F(SyntaxNetMstSolverComponentTest, WrongBatchSize) {
std::unique_ptr<Component> component;
TF_ASSERT_OK(
Component::CreateOrError("SyntaxNetMstSolverComponent", &component));
InputBatchCache input({MakeSentence(1).SerializeAsString(),
MakeSentence(2).SerializeAsString(),
MakeSentence(3).SerializeAsString(),
MakeSentence(4).SerializeAsString()});
EXPECT_CALL(compute_session_, GetInputBatchCache())
.WillRepeatedly(Return(&input));
EXPECT_THAT(component->Evaluate(&session_state_, &compute_session_, nullptr),
test::IsErrorWithSubstr("Non-singleton batch: got 4 elements"));
}
// Tests that Evaluate() fails if the adjacency matrix and sentence disagree on
// the number of tokens.
TEST_F(SyntaxNetMstSolverComponentTest, WrongNumTokens) {
const std::vector<std::vector<float>> adjacency = {{1.0, 0.0, 0.0, 0.0}, //
{0.0, 1.0, 0.0, 0.0}, //
{0.0, 0.0, 1.0, 0.0}, //
{0.0, 0.0, 0.0, 1.0}};
// 4-token adjacency matrix with 3-token sentence.
Sentence sentence = MakeSentence(3);
EXPECT_THAT(Run(MakeGoodSpec(), adjacency, &sentence),
test::IsErrorWithSubstr(
"Sentence size mismatch: expected 4 tokens but got 3"));
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment