Unverified Commit 80178fc6 authored by Mark Omernick's avatar Mark Omernick Committed by GitHub
Browse files

Merge pull request #4153 from terryykoo/master

Export @195097388.
parents a84e1ef9 edea2b67
......@@ -62,13 +62,14 @@ class MockComputeSession : public ComputeSession {
MOCK_METHOD2(GetTranslatedLinkFeatures,
std::vector<LinkFeatures>(const string &component_name,
int channel_id));
MOCK_METHOD1(EmitOracleLabels,
std::vector<std::vector<int>>(const string &component_name));
MOCK_METHOD1(EmitOracleLabels, std::vector<std::vector<std::vector<Label>>>(
const string &component_name));
MOCK_METHOD1(IsTerminal, bool(const string &component_name));
MOCK_METHOD1(FinalizeData, void(const string &component_name));
MOCK_METHOD0(GetSerializedPredictions, std::vector<string>());
MOCK_METHOD0(GetTraceProtos, std::vector<MasterTrace>());
MOCK_METHOD1(SetInputData, void(const std::vector<string> &data));
MOCK_METHOD0(GetInputBatchCache, InputBatchCache *());
MOCK_METHOD0(ResetSession, void());
MOCK_METHOD1(SetTracing, void(bool tracing_on));
MOCK_CONST_METHOD0(Id, int());
......
# Package-wide defaults: targets are publicly visible, and the
# "layering_check" feature is disabled (leading "-").
package(
default_visibility = ["//visibility:public"],
features = ["-layering_check"],
)
# Header-only library exposing label.h (no srcs, no deps).
cc_library(
name = "label",
hdrs = ["label.h"],
)
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_CORE_UTIL_LABEL_H_
#define DRAGNN_CORE_UTIL_LABEL_H_
#include <cmath>
namespace syntaxnet {
namespace dragnn {
// A label id together with the probability attached to it.
struct Label {
  // Builds a label whose probability keeps its default value of 1.0.
  explicit Label(int label_id) : id(label_id) {}

  // Builds a label with an explicit probability.
  Label(int label_id, float label_probability)
      : id(label_id), probability(label_probability) {}

  // Equality is approximate: the ids must be identical and the probabilities
  // must differ by less than a small epsilon.
  bool operator==(const Label &label) const {
    if (id != label.id) return false;
    const float difference = std::fabs(probability - label.probability);
    return difference < 0.00001;
  }

  // Identifier of the label.
  int id;

  // Probability of the label; 1.0 unless specified at construction.
  float probability = 1.0;
};
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_CORE_UTIL_LABEL_H_
......@@ -8,7 +8,7 @@ cc_library(
":syntaxnet_sentence",
"//dragnn/core/interfaces:input_batch",
"//syntaxnet:base",
"//syntaxnet:sentence_proto",
"//syntaxnet:sentence_proto_cc",
],
)
......@@ -16,7 +16,7 @@ cc_library(
name = "syntaxnet_sentence",
hdrs = ["syntaxnet_sentence.h"],
deps = [
"//syntaxnet:sentence_proto",
"//syntaxnet:sentence_proto_cc",
"//syntaxnet:workspace",
],
)
......@@ -27,7 +27,7 @@ cc_test(
deps = [
":sentence_input_batch",
"//dragnn/core/test:generic",
"//syntaxnet:sentence_proto",
"//syntaxnet:sentence_proto_cc",
"//syntaxnet:test_main",
"@org_tensorflow//tensorflow/core:test",
],
......
# Targets in this package default to public visibility.
package(default_visibility = ["//visibility:public"])
# Header-only disjoint-set forest library.
cc_library(
name = "disjoint_set_forest",
hdrs = ["disjoint_set_forest.h"],
deps = [
"@org_tensorflow//tensorflow/core:lib",
],
)
# Unit test for :disjoint_set_forest.
cc_test(
name = "disjoint_set_forest_test",
size = "small",
srcs = ["disjoint_set_forest_test.cc"],
deps = [
":disjoint_set_forest",
"//syntaxnet:base",
"//syntaxnet:test_main",
"@org_tensorflow//tensorflow/core:test",
],
)
# Spanning tree iterator; marked testonly, so only test targets may depend on
# it.
cc_library(
name = "spanning_tree_iterator",
testonly = 1,
srcs = ["spanning_tree_iterator.cc"],
hdrs = ["spanning_tree_iterator.h"],
deps = [
"//syntaxnet:base",
],
)
# Unit test for :spanning_tree_iterator.
cc_test(
name = "spanning_tree_iterator_test",
size = "small",
srcs = ["spanning_tree_iterator_test.cc"],
deps = [
":spanning_tree_iterator",
"//syntaxnet:test_main",
"@org_tensorflow//tensorflow/core:lib",
"@org_tensorflow//tensorflow/core:test",
],
)
# Header-only MST solver, built on top of :disjoint_set_forest.
cc_library(
name = "mst_solver",
hdrs = ["mst_solver.h"],
deps = [
":disjoint_set_forest",
"//syntaxnet:base",
"@org_tensorflow//tensorflow/core:lib",
],
)
# Unit test for :mst_solver.
cc_test(
name = "mst_solver_test",
size = "small",
srcs = ["mst_solver_test.cc"],
deps = [
":mst_solver",
"//dragnn/core/test:generic",
"//syntaxnet:base",
"//syntaxnet:test_main",
"@org_tensorflow//tensorflow/core:test",
],
)
# Compares :mst_solver against the brute-force enumeration provided by
# :spanning_tree_iterator. Declared "small" but given a long timeout, and
# tagged "manual" so wildcard target patterns do not pick it up.
cc_test(
name = "mst_solver_random_comparison_test",
size = "small",
timeout = "long",
srcs = ["mst_solver_random_comparison_test.cc"],
tags = [
"manual", # exclude from :all, since this is expensive
],
deps = [
":mst_solver",
":spanning_tree_iterator",
"//syntaxnet:base",
"//syntaxnet:test_main",
"@org_tensorflow//tensorflow/core:lib",
"@org_tensorflow//tensorflow/core:test",
],
)
# TensorFlow macros used below to build the op library and its Python wrapper.
load(
"@org_tensorflow//tensorflow:tensorflow.bzl",
"tf_gen_op_libs",
"tf_gen_op_wrapper_py",
)
# Op library for "mst_ops"; presumably this is what defines the
# ":mst_ops_op_lib" target referenced below -- confirm against tensorflow.bzl.
tf_gen_op_libs(
op_lib_names = ["mst_ops"],
)
# Don't use this library directly; instead use "dragnn/python:mst_ops".
# Visibility is restricted to the //dragnn/python package accordingly.
tf_gen_op_wrapper_py(
name = "mst_ops",
visibility = ["//dragnn/python:__pkg__"],
deps = [":mst_ops_op_lib"],
)
# C++ op definitions and kernels. alwayslink = 1 forces the object files to be
# linked into dependents even when no symbol is referenced directly --
# presumably so that static op/kernel registration still runs.
cc_library(
name = "mst_ops_cc",
srcs = [
"ops/mst_op_kernels.cc",
"ops/mst_ops.cc",
],
deps = [
":mst_solver",
"//syntaxnet:base",
"@org_tensorflow//tensorflow/core:framework_headers_lib",
"@org_tensorflow//tensorflow/core:lib",
],
alwayslink = 1,
)
Package for solving max-spanning-tree (MST) problems. The code here is intended
for NLP applications, but attempts to remain agnostic to particular NLP tasks
(such as dependency parsing).
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_MST_DISJOINT_SET_FOREST_H_
#define DRAGNN_MST_DISJOINT_SET_FOREST_H_
#include <stddef.h>
#include <type_traits>
#include <vector>
#include "tensorflow/core/platform/logging.h"
namespace syntaxnet {
namespace dragnn {
// A disjoint-set forest (union-find) whose universe is the dense index range
// [0,n).  Thread-compatible.
//
// Path compression and union by rank are both applied by default, giving
// near-constant runtime for every operation.  Union by rank may be switched
// off, which lets the caller decide which root survives a merge; in that mode
// the runtime of all operations rises to O(log n) amortized.
//
// Template args:
//   Index: An unsigned integral type wide enough to hold n.
//   kUseUnionByRank: Whether to use the union by rank optimization.
template <typename Index, bool kUseUnionByRank = true>
class DisjointSetForest {
  static_assert(std::is_integral<Index>::value, "Index must be integral");
  static_assert(!std::is_signed<Index>::value, "Index must be unsigned");

 public:
  // The index type this forest was instantiated with.
  using IndexType = Index;

  // Constructs a forest that holds no elements.
  DisjointSetForest() = default;

  // Resets this forest to hold the elements [0,|size|), with every element in
  // a singleton set of its own.  Any previous state is discarded.
  void Init(Index size);

  // Returns the root of the set that contains |element|; the root uniquely
  // identifies the set.  The root can change whenever the set is merged with
  // another, so do not cache FindRoot(e) across calls to Union() or
  // UnionOfRoots() that could merge the set containing e.
  Index FindRoot(Index element);

  // Convenience predicate: true if |element1| and |element2| belong to the
  // same set.  For large batches of queries it may be cheaper to cache the
  // result of FindRoot(), subject to the caching caveat above.
  bool SameSet(Index element1, Index element2);

  // Merges the sets rooted at |root1| and |root2|; both arguments must be
  // roots of their sets.  One of them becomes the root of the merged set: if
  // |kUseUnionByRank| is true it is unspecified which one, otherwise it is
  // always |root2|.
  void UnionOfRoots(Index root1, Index root2);

  // Same as UnionOfRoots(), but first looks up the roots of |element1| and
  // |element2|.
  void Union(Index element1, Index element2);

  // Returns the number of elements held by this forest.
  Index size() const { return size_; }

 private:
  // Number of elements in the universe underlying the sets.
  Index size_ = 0;

  // Parent of each element; an element that is its own parent is a root.
  std::vector<Index> parents_;

  // Rank of each element, maintained only when |kUseUnionByRank| is true.
  std::vector<Index> ranks_;
};
// Implementation details below.
// Resets the forest to |size| singleton sets.
template <class Index, bool kUseUnionByRank>
void DisjointSetForest<Index, kUseUnionByRank>::Init(Index size) {
  size_ = size;

  // Every element starts out as its own parent, i.e. as a root.
  parents_.resize(size_);
  for (Index element = 0; element < size_; ++element) {
    parents_[element] = element;
  }

  // Ranks are only tracked when union by rank is enabled; all start at zero.
  if (kUseUnionByRank) {
    ranks_.resize(size_);
    for (Index element = 0; element < size_; ++element) {
      ranks_[element] = 0;
    }
  }
}
// Returns the root of the set containing |element|, compressing the traversed
// path so that the visited nodes point directly at the root afterwards.
template <class Index, bool kUseUnionByRank>
Index DisjointSetForest<Index, kUseUnionByRank>::FindRoot(Index element) {
  DCHECK_LT(element, size());
  Index *const __restrict parents = parents_.data();

  // The first two hops are handled explicitly: thanks to path compression
  // most lookups finish here, and in both cases nothing needs compressing.
  const Index first_parent = parents[element];
  if (first_parent == element) return element;  // |element| is a root

  const Index second_parent = parents[first_parent];
  if (second_parent == first_parent) {
    return first_parent;  // |element| is the child of a root
  }

  // Slow path: keep climbing until a self-loop (the root) is reached.
  Index root = second_parent;
  for (Index above = parents[root]; above != root; above = parents[root]) {
    root = above;
  }

  // Path compression: re-point every node on the walked path at the root.
  // |element| is known not to be a root, so the first write is always valid.
  Index node = element;
  Index next = first_parent;
  do {
    parents[node] = root;
    node = next;
    next = parents[node];
  } while (next != root);

  return root;
}
// Two elements share a set exactly when their roots coincide.
template <class Index, bool kUseUnionByRank>
bool DisjointSetForest<Index, kUseUnionByRank>::SameSet(Index element1,
                                                        Index element2) {
  const Index root_of_first = FindRoot(element1);
  const Index root_of_second = FindRoot(element2);
  return root_of_first == root_of_second;
}
// Merges the sets rooted at |root1| and |root2|.  Both arguments must already
// be roots; this is checked in debug builds only.
template <class Index, bool kUseUnionByRank>
void DisjointSetForest<Index, kUseUnionByRank>::UnionOfRoots(Index root1,
                                                             Index root2) {
  DCHECK_LT(root1, size());
  DCHECK_LT(root2, size());
  DCHECK_EQ(root1, parents_[root1]);
  DCHECK_EQ(root2, parents_[root2]);
  if (root1 == root2) return;  // already merged

  Index *const __restrict parents = parents_.data();

  if (!kUseUnionByRank) {
    // Without union by rank, |root2| always becomes the merged root.
    parents[root1] = root2;
    return;
  }

  Index *const __restrict ranks = ranks_.data();
  const Index rank1 = ranks[root1];
  const Index rank2 = ranks[root2];

  if (rank1 > rank2) {
    // |root1| has the higher rank, so |root2| is attached beneath it.
    parents[root2] = root1;
    return;
  }

  // Either |root2| has the higher rank, or the ranks tie; in both cases
  // |root1| is attached beneath |root2|.
  parents[root1] = root2;

  // On a tie the surviving root gets its rank promoted by one.
  if (rank1 == rank2) ranks[root2] = rank2 + 1;
}
// Merges the sets containing |element1| and |element2| by looking up their
// roots and delegating to UnionOfRoots().
template <class Index, bool kUseUnionByRank>
void DisjointSetForest<Index, kUseUnionByRank>::Union(Index element1,
                                                      Index element2) {
  const Index root1 = FindRoot(element1);
  const Index root2 = FindRoot(element2);
  UnionOfRoots(root1, root2);
}
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_MST_DISJOINT_SET_FOREST_H_
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/mst/disjoint_set_forest.h"
#include <stddef.h>
#include <set>
#include <utility>
#include <vector>
#include "syntaxnet/base.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace {
// Typed test fixture for DisjointSetForest<>.
//
// Template args:
//   Forest: An instantiation of the DisjointSetForest<> template.
template <class Forest>
class DisjointSetForestTest : public ::testing::Test {
 protected:
  using Index = typename Forest::IndexType;

  // Checks that the partition held by |forest| is exactly |expected_sets|:
  // every pair of elements must be reported as sharing a set if and only if
  // some expected set contains both of them.
  void ExpectSets(const std::set<std::set<Index>> &expected_sets,
                  Forest *forest) {
    // Collect every ordered pair (including self-pairs) that should share a
    // set.
    std::set<std::pair<Index, Index>> expected_pairs;
    for (const std::set<Index> &members : expected_sets) {
      for (const Index first : members) {
        for (const Index second : members) {
          expected_pairs.emplace(first, second);
        }
      }
    }

    // Compare against the forest for all ordered pairs of elements.
    const Index num_elements = forest->size();
    for (Index lhs = 0; lhs < num_elements; ++lhs) {
      for (Index rhs = 0; rhs < num_elements; ++rhs) {
        const bool together = expected_pairs.count({lhs, rhs}) > 0;
        if (together) {
          EXPECT_EQ(forest->FindRoot(lhs), forest->FindRoot(rhs));
          EXPECT_TRUE(forest->SameSet(lhs, rhs));
        } else {
          EXPECT_NE(forest->FindRoot(lhs), forest->FindRoot(rhs));
          EXPECT_FALSE(forest->SameSet(lhs, rhs));
        }
      }
    }
  }
};
// Forest instantiations exercised by the typed tests: each of the uint8,
// uint16, uint32 and uint64 index types, with union by rank both disabled and
// enabled.
using Forests = ::testing::Types<
DisjointSetForest<uint8, false>, DisjointSetForest<uint8, true>,
DisjointSetForest<uint16, false>, DisjointSetForest<uint16, true>,
DisjointSetForest<uint32, false>, DisjointSetForest<uint32, true>,
DisjointSetForest<uint64, false>, DisjointSetForest<uint64, true>>;
TYPED_TEST_CASE(DisjointSetForestTest, Forests);
// A default-constructed forest reports a size of zero.
TYPED_TEST(DisjointSetForestTest, DefaultEmpty) {
TypeParam forest;
EXPECT_EQ(0, forest.size());
}
// Explicitly initializing with zero elements also yields a size of zero.
TYPED_TEST(DisjointSetForestTest, InitEmpty) {
TypeParam forest;
forest.Init(0);
EXPECT_EQ(0, forest.size());
}
// Walks a five-element forest through a sequence of merges and checks the
// resulting partition after every step.
TYPED_TEST(DisjointSetForestTest, Populated) {
TypeParam forest;
forest.Init(5);
EXPECT_EQ(5, forest.size());
// Freshly initialized: every element is in its own singleton set.
this->ExpectSets({{0}, {1}, {2}, {3}, {4}}, &forest);
// Merge two singletons directly by root.
forest.UnionOfRoots(1, 2);
this->ExpectSets({{0}, {1, 2}, {3}, {4}}, &forest);
// Merging elements that already share a set leaves the partition unchanged.
forest.Union(1, 2);
this->ExpectSets({{0}, {1, 2}, {3}, {4}}, &forest);
forest.UnionOfRoots(0, 4);
this->ExpectSets({{0, 4}, {1, 2}, {3}}, &forest);
// Union() accepts arbitrary elements, not only roots.
forest.Union(3, 4);
this->ExpectSets({{0, 3, 4}, {1, 2}}, &forest);
// Again a merge within one set: nothing changes.
forest.Union(0, 3);
this->ExpectSets({{0, 3, 4}, {1, 2}}, &forest);
// Joining the last two sets leaves a single set with all elements.
forest.Union(2, 0);
this->ExpectSets({{0, 1, 2, 3, 4}}, &forest);
forest.Union(1, 3);
this->ExpectSets({{0, 1, 2, 3, 4}}, &forest);
}
// Testing rig for checking that when union by rank is disabled, the root of a
// merged set can be controlled.
class DisjointSetForestNoUnionByRankTest : public ::testing::Test {
protected:
using Forest = DisjointSetForest<uint32, false>;
// Expects that the roots of the |forest| match |expected_roots|.
void ExpectRoots(const std::vector<uint32> &expected_roots, Forest *forest) {
ASSERT_EQ(expected_roots.size(), forest->size());
for (uint32 i = 0; i < forest->size(); ++i) {
EXPECT_EQ(expected_roots[i], forest->FindRoot(i));
}
}
};
// With union by rank disabled, the root of the second argument's set becomes
// the root of the merged set; each step below checks the root of every
// element.
TEST_F(DisjointSetForestNoUnionByRankTest, ManuallySpecifyRoot) {
Forest forest;
forest.Init(5);
// Initially every element is its own root.
ExpectRoots({0, 1, 2, 3, 4}, &forest);
forest.UnionOfRoots(0, 1); // 1 is the root
ExpectRoots({1, 1, 2, 3, 4}, &forest);
forest.Union(4, 3); // 3 is the root
ExpectRoots({1, 1, 2, 3, 3}, &forest);
forest.Union(0, 2); // 2 is the root
ExpectRoots({2, 2, 2, 3, 3}, &forest);
forest.Union(3, 3); // no effect
ExpectRoots({2, 2, 2, 3, 3}, &forest);
forest.Union(4, 0); // 2 is the root
ExpectRoots({2, 2, 2, 2, 2}, &forest);
}
} // namespace
} // namespace dragnn
} // namespace syntaxnet
This diff is collapsed.
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/mst/mst_solver.h"
#include <time.h>
#include <random>
#include <set>
#include <vector>
#include "dragnn/mst/spanning_tree_iterator.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace {
using ::testing::Contains;
// Returns the seed for the PRNG.  A return value of 0 would request a weak
// (time-based) seed; this build always uses a fixed, deterministic one.
int64 GetSeed() {
  constexpr int64 kDeterministicSeed = 1;
  return kDeterministicSeed;
}
// Returns how many random digraphs each comparison test evaluates.
int64 GetNumTrials() {
  constexpr int64 kNumTrials = 3;
  return kNumTrials;
}
// Testing rig.  Runs a comparison between a brute-force MST solver and the
// MstSolver<> on random digraphs.  When the first test parameter is true,
// solves for forests instead of trees.  The second test parameter defines the
// size of the test digraph.
class MstSolverRandomComparisonTest
    : public ::testing::TestWithParam<::testing::tuple<bool, uint32>> {
 protected:
  // Use integer scores so score comparisons are exact.
  using Solver = MstSolver<uint32, int32>;

  // An array providing a source node for each node.  Roots are self-loops.
  using SourceList = SpanningTreeIterator::SourceList;

  // A row-major n x n matrix whose i,j entry gives the score of the arc from i
  // to j, and whose i,i entry gives the score of selecting i as a root.
  using ScoreMatrix = std::vector<int32>;

  // Returns true if this should be a forest.
  bool forest() const { return ::testing::get<0>(GetParam()); }

  // Returns the number of nodes for digraphs.
  uint32 num_nodes() const { return ::testing::get<1>(GetParam()); }

  // Returns the total score of the arcs in |sources| based on the |scores|:
  // for each target node, the entry for the arc from its source is added.
  int32 ScoreArcs(const ScoreMatrix &scores, const SourceList &sources) const {
    CHECK_EQ(num_nodes() * num_nodes(), scores.size());
    int32 score = 0;
    for (uint32 target = 0; target < num_nodes(); ++target) {
      const uint32 source = sources[target];
      score += scores[target + source * num_nodes()];
    }
    return score;
  }

  // Returns the score of the maximum spanning tree (or forest, if the first
  // test parameter is true) of the dense digraph defined by the |scores|, and
  // sets |argmax_trees| to contain all maximal trees.  If the iterator yields
  // no tree at all, returns 0 and leaves |argmax_trees| empty.
  int32 RunBruteForceMstSolver(const ScoreMatrix &scores,
                               std::set<SourceList> *argmax_trees) {
    CHECK_EQ(num_nodes() * num_nodes(), scores.size());
    // Initialized so that the return value is never indeterminate.  The value
    // is overwritten by the first enumerated tree, because |argmax_trees| is
    // empty at that point.
    int32 max_score = 0;
    argmax_trees->clear();
    iterator_.ForEachTree(num_nodes(), [&](const SourceList &sources) {
      const int32 score = ScoreArcs(scores, sources);
      if (argmax_trees->empty() || max_score < score) {
        // New best score: restart the set of maximal trees.
        max_score = score;
        argmax_trees->clear();
        argmax_trees->insert(sources);
      } else if (max_score == score) {
        // Tie with the current best: remember this tree as well.
        argmax_trees->insert(sources);
      }
    });
    return max_score;
  }

  // As above, but uses the |solver_| and extracts only one |argmax_tree|.
  int32 RunMstSolver(const ScoreMatrix &scores, SourceList *argmax_tree) {
    CHECK_EQ(num_nodes() * num_nodes(), scores.size());
    TF_CHECK_OK(solver_.Init(forest(), num_nodes()));

    // Add all roots and arcs.  Diagonal entries are root-selection scores.
    for (uint32 source = 0; source < num_nodes(); ++source) {
      for (uint32 target = 0; target < num_nodes(); ++target) {
        const int32 score = scores[target + source * num_nodes()];
        if (source == target) {
          solver_.AddRoot(target, score);
        } else {
          solver_.AddArc(source, target, score);
        }
      }
    }

    // Solve for the max spanning tree.
    argmax_tree->resize(num_nodes());
    TF_CHECK_OK(solver_.Solve(argmax_tree));
    return ScoreArcs(scores, *argmax_tree);
  }

  // Returns a random ScoreMatrix spanning num_nodes() nodes, with every entry
  // drawn from [-100, 100].
  ScoreMatrix RandomScores() {
    ScoreMatrix scores(num_nodes() * num_nodes());
    for (int32 &value : scores) value = static_cast<int32>(prng_() % 201) - 100;
    return scores;
  }

  // Runs a comparison between MstSolver and BruteForceMst on random digraphs of
  // num_nodes() nodes, for the specified number of trials.
  void RunComparison() {
    // Seed the PRNG, possibly non-deterministically.  Log the seed value so the
    // test results can be reproduced, even when the seed is non-deterministic.
    // The conversions to uint32 are intentional truncations.
    uint32 seed = static_cast<uint32>(GetSeed());
    if (seed == 0) seed = static_cast<uint32>(time(nullptr));
    prng_.seed(seed);
    LOG(INFO) << "seed = " << seed;

    const int num_trials = static_cast<int>(GetNumTrials());
    for (int trial = 0; trial < num_trials; ++trial) {
      const ScoreMatrix scores = RandomScores();

      std::set<SourceList> expected_argmax_trees;
      const int32 expected_max_score =
          RunBruteForceMstSolver(scores, &expected_argmax_trees);

      SourceList actual_argmax_tree;
      const int32 actual_max_score = RunMstSolver(scores, &actual_argmax_tree);

      // In case of ties, MstSolver will find a maximal spanning tree, but we
      // don't know which one.
      EXPECT_EQ(expected_max_score, actual_max_score);
      ASSERT_THAT(expected_argmax_trees, Contains(actual_argmax_tree));
    }
  }

  // Tree iterator for brute-force solver.
  SpanningTreeIterator iterator_{forest()};

  // MstSolver<> instance used by the test.  Reused across all MST invocations
  // to exercise reuse.
  Solver solver_;

  // Pseudo-random number generator.
  std::mt19937 prng_;
};
// Instantiates the fixture for both values of the forest flag combined with
// every node count produced by Range<uint32>(1, 9) (gtest's Range() excludes
// the end value).
INSTANTIATE_TEST_CASE_P(AllowForest, MstSolverRandomComparisonTest,
::testing::Combine(::testing::Bool(),
::testing::Range<uint32>(1, 9)));
// The whole comparison lives in the fixture; the test body just triggers it.
TEST_P(MstSolverRandomComparisonTest, Comparison) { RunComparison(); }
} // namespace
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/mst/mst_solver.h"
#include <limits>
#include <utility>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "syntaxnet/base.h"
#include <gmock/gmock.h>
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace {
using ::testing::HasSubstr;
// Typed test fixture for MstSolver<>.
//
// Template args:
//   Solver: An instantiation of the MstSolver<> template.
template <class Solver>
class MstSolverTest : public ::testing::Test {
 protected:
  using Index = typename Solver::IndexType;
  using Score = typename Solver::ScoreType;

  // Registers with |solver_| a directed arc of the given |score| between every
  // ordered pair of distinct nodes in [0,|num_nodes|).
  void AddAllArcs(Index num_nodes, Score score) {
    for (Index from = 0; from < num_nodes; ++from) {
      for (Index to = 0; to < num_nodes; ++to) {
        if (from != to) solver_.AddArc(from, to, score);
      }
    }
  }

  // Registers with |solver_| a root selection of the given |score| for every
  // node in [0,|num_nodes|).
  void AddAllRoots(Index num_nodes, Score score) {
    for (Index node = 0; node < num_nodes; ++node) {
      solver_.AddRoot(node, score);
    }
  }

  // Solves into an argmax array of |argmax_array_size| entries and expects
  // failure with an error message containing |error_message_substr|.
  void SolveAndExpectError(int argmax_array_size,
                           const string &error_message_substr) {
    std::vector<Index> sources(argmax_array_size);
    EXPECT_THAT(solver_.Solve(&sources),
                test::IsErrorWithSubstr(error_message_substr));
  }

  // Solves into an argmax array of |argmax_array_size| entries and expects
  // success.  Nothing is asserted about the solution itself.
  void SolveAndExpectOk(int argmax_array_size) {
    std::vector<Index> sources(argmax_array_size);
    TF_EXPECT_OK(solver_.Solve(&sources));
  }

  // Solves and expects the solution to equal |expected_argmax|; the argmax
  // array is sized to match.  Aborts the calling test if solving fails.
  void SolveAndExpectArgmax(const std::vector<Index> &expected_argmax) {
    std::vector<Index> sources(expected_argmax.size());
    TF_ASSERT_OK(solver_.Solve(&sources));
    EXPECT_EQ(expected_argmax, sources);
  }

  // The solver under test.  A single instance serves every MST problem within
  // a test, which exercises reuse.
  Solver solver_;
};
// Solver instantiations exercised by the typed tests; the list mixes several
// index widths with both integer and floating-point score types.
using Solvers =
::testing::Types<MstSolver<uint8, int16>, MstSolver<uint16, int32>,
MstSolver<uint32, int64>, MstSolver<uint16, float>,
MstSolver<uint32, double>>;
TYPED_TEST_CASE(MstSolverTest, Solvers);
// Init() with zero nodes is expected to fail, for trees and forests alike.
TYPED_TEST(MstSolverTest, FailIfNoNodes) {
  for (const bool allow_forest : {false, true}) {
    EXPECT_THAT(this->solver_.Init(allow_forest, 0),
                test::IsErrorWithSubstr("Non-positive number of nodes"));
  }
}
// Init() is expected to reject a node count that is too large for the index
// type.
TYPED_TEST(MstSolverTest, FailIfTooManyNodes) {
  using Index = typename TypeParam::IndexType;
  // Set to a value that would overflow when doubled.
  const auto too_many_nodes = (std::numeric_limits<Index>::max() / 2) + 10;
  for (const bool allow_forest : {false, true}) {
    EXPECT_THAT(this->solver_.Init(allow_forest, too_many_nodes),
                test::IsErrorWithSubstr("Too many nodes"));
  }
}
// With neither root selections nor arcs added, solving is expected to report
// an infeasible digraph.
TYPED_TEST(MstSolverTest, InfeasibleIfNoRootsNoArcs) {
  constexpr int kNumNodes = 10;
  for (const bool allow_forest : {false, true}) {
    TF_ASSERT_OK(this->solver_.Init(allow_forest, kNumNodes));
    this->SolveAndExpectError(kNumNodes, "Infeasible digraph");
  }
}
// Arcs alone are not enough: without any root selection, solving is expected
// to report an infeasible digraph.
TYPED_TEST(MstSolverTest, InfeasibleIfNoRootsAllArcs) {
  constexpr int kNumNodes = 10;
  for (const bool allow_forest : {false, true}) {
    TF_ASSERT_OK(this->solver_.Init(allow_forest, kNumNodes));
    this->AddAllArcs(kNumNodes, 0);
    this->SolveAndExpectError(kNumNodes, "Infeasible digraph");
  }
}
// With every node selectable as a root but no arcs, only the forest variant
// is expected to succeed.
TYPED_TEST(MstSolverTest, FeasibleForForestOnlyIfAllRootsNoArcs) {
  constexpr int kNumNodes = 10;
  for (const bool allow_forest : {false, true}) {
    TF_ASSERT_OK(this->solver_.Init(allow_forest, kNumNodes));
    this->AddAllRoots(kNumNodes, 0);
    if (!allow_forest) {
      this->SolveAndExpectError(kNumNodes, "Infeasible digraph");
    } else {
      this->SolveAndExpectOk(kNumNodes);  // all roots is a valid forest
    }
  }
}
// With all root selections and all arcs available, solving is expected to
// succeed for both trees and forests.
TYPED_TEST(MstSolverTest, FeasibleIfAllRootsAllArcs) {
  constexpr int kNumNodes = 10;
  for (const bool allow_forest : {false, true}) {
    TF_ASSERT_OK(this->solver_.Init(allow_forest, kNumNodes));
    this->AddAllRoots(kNumNodes, 0);
    this->AddAllArcs(kNumNodes, 0);
    this->SolveAndExpectOk(kNumNodes);
  }
}
// An argmax array with fewer entries than nodes is expected to be rejected.
TYPED_TEST(MstSolverTest, FailIfArgmaxArrayTooSmall) {
  constexpr int kNumNodes = 10;
  for (const bool allow_forest : {false, true}) {
    TF_ASSERT_OK(this->solver_.Init(allow_forest, kNumNodes));
    this->AddAllRoots(kNumNodes, 0);
    this->AddAllArcs(kNumNodes, 0);
    this->SolveAndExpectError(kNumNodes - 1,  // too small
                              "Argmax array too small");
  }
}
// An argmax array with more entries than nodes is expected to be accepted.
TYPED_TEST(MstSolverTest, OkIfArgmaxArrayTooLarge) {
  constexpr int kNumNodes = 10;
  for (const bool allow_forest : {false, true}) {
    TF_ASSERT_OK(this->solver_.Init(allow_forest, kNumNodes));
    this->AddAllRoots(kNumNodes, 0);
    this->AddAllArcs(kNumNodes, 0);
    this->SolveAndExpectOk(kNumNodes + 1);  // too large
  }
}
// Forest mode only: when root selections outscore every arc, each node is
// expected to end up as its own root.
TYPED_TEST(MstSolverTest, SolveForAllRootsForestOnly) {
  constexpr int kNumNodes = 10;
  constexpr bool kAllowForest = true;
  TF_ASSERT_OK(this->solver_.Init(kAllowForest, kNumNodes));
  this->AddAllRoots(kNumNodes, 1);  // favor all root selections
  this->AddAllArcs(kNumNodes, 0);
  this->SolveAndExpectArgmax({0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
}
// Favoring each arc i -> i+1 is expected to yield a chain rooted at node 0.
TYPED_TEST(MstSolverTest, SolveForLeftToRightChain) {
  constexpr int kNumNodes = 10;
  for (const bool allow_forest : {false, true}) {
    TF_ASSERT_OK(this->solver_.Init(allow_forest, kNumNodes));
    this->AddAllRoots(kNumNodes, 0);
    this->AddAllArcs(kNumNodes, 0);
    for (int source = 0; source + 1 < kNumNodes; ++source) {
      // favor left-to-right chain
      this->solver_.AddArc(source, source + 1, 1);
    }
    this->SolveAndExpectArgmax({0, 0, 1, 2, 3, 4, 5, 6, 7, 8});
  }
}
// Favoring each arc i+1 -> i is expected to yield a chain rooted at the last
// node.
TYPED_TEST(MstSolverTest, SolveForRightToLeftChain) {
  constexpr int kNumNodes = 10;
  for (const bool allow_forest : {false, true}) {
    TF_ASSERT_OK(this->solver_.Init(allow_forest, kNumNodes));
    this->AddAllRoots(kNumNodes, 0);
    this->AddAllArcs(kNumNodes, 0);
    for (int target = 0; target + 1 < kNumNodes; ++target) {
      // favor right-to-left chain
      this->solver_.AddArc(target + 1, target, 1);
    }
    this->SolveAndExpectArgmax({1, 2, 3, 4, 5, 6, 7, 8, 9, 9});
  }
}
// Favoring every arc out of node 0 is expected to make node 0 the source of
// all nodes, itself included (i.e. it is the root).
TYPED_TEST(MstSolverTest, SolveForAllFromFirstTree) {
  constexpr int kNumNodes = 10;
  constexpr int kFirst = 0;
  for (const bool allow_forest : {false, true}) {
    TF_ASSERT_OK(this->solver_.Init(allow_forest, kNumNodes));
    this->AddAllRoots(kNumNodes, 0);
    this->AddAllArcs(kNumNodes, 0);
    for (int node = 1; node < kNumNodes; ++node) {
      this->solver_.AddArc(kFirst, node, 1);  // favor first -> target
    }
    this->SolveAndExpectArgmax({0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
  }
}
// Favoring every arc out of the last node is expected to make the last node
// the source of all nodes, itself included (i.e. it is the root).
TYPED_TEST(MstSolverTest, SolveForAllFromLastTree) {
  constexpr int kNumNodes = 10;
  constexpr int kLast = kNumNodes - 1;
  for (const bool allow_forest : {false, true}) {
    TF_ASSERT_OK(this->solver_.Init(allow_forest, kNumNodes));
    this->AddAllRoots(kNumNodes, 0);
    this->AddAllArcs(kNumNodes, 0);
    for (int node = 0; node < kLast; ++node) {
      this->solver_.AddArc(kLast, node, 1);  // favor last -> target
    }
    this->SolveAndExpectArgmax({9, 9, 9, 9, 9, 9, 9, 9, 9, 9});
  }
}
// Favoring the arcs of a binary-heap layout (the source of node i is
// (i - 1) / 2) is expected to reproduce exactly that tree.
TYPED_TEST(MstSolverTest, SolveForBinaryTree) {
  constexpr int kNumNodes = 15;
  for (const bool allow_forest : {false, true}) {
    TF_ASSERT_OK(this->solver_.Init(allow_forest, kNumNodes));
    this->AddAllRoots(kNumNodes, 0);
    this->AddAllArcs(kNumNodes, 0);
    for (int child = 1; child < kNumNodes; ++child) {
      const int heap_parent = (child - 1) / 2;  // like a binary heap
      this->solver_.AddArc(heap_parent, child, 1);
    }
    // One row per level of the tree.
    this->SolveAndExpectArgmax({0,                         //
                                0, 0,                      //
                                1, 1, 2, 2,                //
                                3, 3, 4, 4, 5, 5, 6, 6});
  }
}
} // namespace
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <cmath>
#include <limits>
#include <type_traits>
#include <vector>
#include "dragnn/mst/mst_solver.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
namespace syntaxnet {
namespace dragnn {
// Op kernel implementation that wraps the |MstSolver|.
//
// Template parameters:
//   Index: Integral type passed to |MstSolver| for node indices.
//   Score: Arc and root score type; also the element type of the "scores"
//          input and the "max_scores" output.
//
// Inputs are read as input(0) = num_nodes ([B] int32 vector) and input(1) =
// scores ([B,M,M] tensor of Score).  Outputs are output(0) = max scores ([B])
// and output(1) = argmax sources ([B,M] int32 matrix).
template <class Index, class Score>
class MaximumSpanningTreeOpKernel : public tensorflow::OpKernel {
 public:
  // Reads the "forest" attribute, which is forwarded to MstSolver::Init().
  explicit MaximumSpanningTreeOpKernel(
      tensorflow::OpKernelConstruction *context)
      : tensorflow::OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("forest", &forest_));
  }

  // Validates the input shapes, allocates the outputs, and solves every
  // problem in the batch.  Fails the op with the first non-OK status, in batch
  // order, if any problem could not be solved.
  void Compute(tensorflow::OpKernelContext *context) override {
    const tensorflow::Tensor &num_nodes_tensor = context->input(0);
    const tensorflow::Tensor &scores_tensor = context->input(1);

    // Check ranks.
    OP_REQUIRES(context, num_nodes_tensor.dims() == 1,
                tensorflow::errors::InvalidArgument(
                    "num_nodes must be a vector, got shape ",
                    num_nodes_tensor.shape().DebugString()));
    OP_REQUIRES(context, scores_tensor.dims() == 3,
                tensorflow::errors::InvalidArgument(
                    "scores must be rank 3, got shape ",
                    scores_tensor.shape().DebugString()));

    // Batch size and input dimension (B and M in the op docstring).
    const int64 batch_size = scores_tensor.shape().dim_size(0);
    const int64 input_dim = scores_tensor.shape().dim_size(1);

    // Check shapes.
    const tensorflow::TensorShape shape_b({batch_size});
    const tensorflow::TensorShape shape_bxm({batch_size, input_dim});
    const tensorflow::TensorShape shape_bxmxm(
        {batch_size, input_dim, input_dim});
    OP_REQUIRES(
        context, num_nodes_tensor.shape() == shape_b,
        tensorflow::errors::InvalidArgument(
            "num_nodes misshapen: got ", num_nodes_tensor.shape().DebugString(),
            " but expected ", shape_b.DebugString()));
    OP_REQUIRES(
        context, scores_tensor.shape() == shape_bxmxm,
        tensorflow::errors::InvalidArgument(
            "scores misshapen: got ", scores_tensor.shape().DebugString(),
            " but expected ", shape_bxmxm.DebugString()));

    // Create outputs.
    tensorflow::Tensor *max_scores_tensor = nullptr;
    tensorflow::Tensor *argmax_sources_tensor = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, shape_b, &max_scores_tensor));
    OP_REQUIRES_OK(context, context->allocate_output(1, shape_bxm,
                                                     &argmax_sources_tensor));

    // Acquire shaped and typed references.
    const BatchedSizes num_nodes_b = num_nodes_tensor.vec<int32>();
    const BatchedScores scores_bxmxm = scores_tensor.tensor<Score, 3>();
    BatchedMaxima max_scores_b = max_scores_tensor->vec<Score>();
    BatchedSources argmax_sources_bxm = argmax_sources_tensor->matrix<int32>();

    // Solve the batch of MST problems in parallel.  Set a high cycles per unit
    // to encourage finer sharding.  Each problem writes only its own slot of
    // |statuses|, so the shards do not share a status entry.
    constexpr int64 kCyclesPerUnit = 1000 * 1000 * 1000;
    std::vector<tensorflow::Status> statuses(batch_size);
    context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor(
        batch_size, kCyclesPerUnit, [&](int64 begin, int64 end) {
          for (int64 problem = begin; problem < end; ++problem) {
            statuses[problem] = RunSolver(problem, num_nodes_b, scores_bxmxm,
                                          max_scores_b, argmax_sources_bxm);
          }
        });
    for (const tensorflow::Status &status : statuses) {
      OP_REQUIRES_OK(context, status);
    }
  }

 private:
  using BatchedSizes = typename tensorflow::TTypes<int32>::ConstVec;
  using BatchedScores = typename tensorflow::TTypes<Score, 3>::ConstTensor;
  using BatchedMaxima = typename tensorflow::TTypes<Score>::Vec;
  using BatchedSources = typename tensorflow::TTypes<int32>::Matrix;

  // Solves for the maximum spanning tree of the digraph defined by the values
  // at index |problem| in |num_nodes_b| and |scores_bxmxm|.  On success, sets
  // the values at index |problem| in |max_scores_b| and |argmax_sources_bxm|.
  // On error, returns non-OK.
  //
  // Returns InvalidArgument if the node count is negative, exceeds the input
  // dimension, or does not fit in |Index|; otherwise propagates any error
  // returned by MstSolver::Init() or MstSolver::Solve().
  tensorflow::Status RunSolver(int problem, BatchedSizes num_nodes_b,
                               BatchedScores scores_bxmxm,
                               BatchedMaxima max_scores_b,
                               BatchedSources argmax_sources_bxm) const {
    const int32 num_nodes = num_nodes_b(problem);
    const int32 input_dim = argmax_sources_bxm.dimension(1);

    // Reject negative sizes up front.  A negative value would pass both of the
    // upper-bound checks below, and would then be converted to a large value
    // by the cast to |Index| and by the std::vector size argument.
    if (num_nodes < 0) {
      return tensorflow::errors::InvalidArgument(
          "number of nodes in digraph ", problem, " is negative: got ",
          num_nodes);
    }

    // Check digraph size overflow.
    if (num_nodes > input_dim) {
      return tensorflow::errors::InvalidArgument(
          "number of nodes in digraph ", problem,
          " overflows input dimension: got ", num_nodes,
          " but expected <= ", input_dim);
    }
    if (num_nodes >= std::numeric_limits<Index>::max()) {
      return tensorflow::errors::InvalidArgument(
          "number of nodes in digraph ", problem, " overflows index type: got ",
          num_nodes, " but expected < ", std::numeric_limits<Index>::max());
    }
    const Index num_nodes_index = static_cast<Index>(num_nodes);

    MstSolver<Index, Score> solver;
    TF_RETURN_IF_ERROR(solver.Init(forest_, num_nodes_index));

    // Populate the solver with arcs and root selections.  Note that non-finite
    // scores are treated as nonexistent arcs or roots.
    for (Index target = 0; target < num_nodes_index; ++target) {
      for (Index source = 0; source < num_nodes_index; ++source) {
        const Score score = scores_bxmxm(problem, target, source);
        if (!std::isfinite(score)) continue;
        if (source == target) {  // root
          solver.AddRoot(target, score);
        } else {  // arc
          solver.AddArc(source, target, score);
        }
      }
    }

    std::vector<Index> argmax(num_nodes);
    TF_RETURN_IF_ERROR(solver.Solve(&argmax));

    // Output the tree and accumulate its score.
    Score max_score = 0;
    for (Index target = 0; target < num_nodes_index; ++target) {
      const Index source = argmax[target];
      argmax_sources_bxm(problem, target) = source;
      max_score += scores_bxmxm(problem, target, source);
    }
    max_scores_b(problem) = max_score;

    // Pad the source list with -1.
    for (int32 i = num_nodes; i < input_dim; ++i) {
      argmax_sources_bxm(problem, i) = -1;
    }

    return tensorflow::Status::OK();
  }

  // Value of the "forest" attribute; passed to MstSolver::Init().
  bool forest_ = false;
};
// Registers the CPU kernels for the "MaximumSpanningTree" op, one per score
// type accepted by the "T" type constraint (int32, float, double).
//
// Use Index=uint16, which allows digraphs containing up to 32,767 nodes.
// NOTE(review): the kernel's own check only rejects num_nodes >=
// std::numeric_limits<uint16>::max(); the 32,767 figure presumably reflects a
// tighter limit enforced inside MstSolver -- confirm against mst_solver.h.
REGISTER_KERNEL_BUILDER(Name("MaximumSpanningTree")
.Device(tensorflow::DEVICE_CPU)
.TypeConstraint<int32>("T"),
MaximumSpanningTreeOpKernel<uint16, int32>);
REGISTER_KERNEL_BUILDER(Name("MaximumSpanningTree")
.Device(tensorflow::DEVICE_CPU)
.TypeConstraint<float>("T"),
MaximumSpanningTreeOpKernel<uint16, float>);
REGISTER_KERNEL_BUILDER(Name("MaximumSpanningTree")
.Device(tensorflow::DEVICE_CPU)
.TypeConstraint<double>("T"),
MaximumSpanningTreeOpKernel<uint16, double>);
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace syntaxnet {
namespace dragnn {
// Declares the "MaximumSpanningTree" op: its attributes, inputs, outputs,
// shape-inference function, and user-facing documentation.
REGISTER_OP("MaximumSpanningTree")
.Attr("T: {int32, float, double}")
.Attr("forest: bool = false")
.Input("num_nodes: int32")
.Input("scores: T")
.Output("max_scores: T")
.Output("argmax_sources: int32")
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
// Require num_nodes (input 0) to be rank 1 and scores (input 1) to be rank 3.
tensorflow::shape_inference::ShapeHandle num_nodes;
tensorflow::shape_inference::ShapeHandle scores;
TF_RETURN_IF_ERROR(context->WithRank(context->input(0), 1, &num_nodes));
TF_RETURN_IF_ERROR(context->WithRank(context->input(1), 3, &scores));
// Extract dimensions while asserting that they match.
tensorflow::shape_inference::DimensionHandle batch_size; // aka "B"
TF_RETURN_IF_ERROR(context->Merge(context->Dim(num_nodes, 0),
context->Dim(scores, 0), &batch_size));
tensorflow::shape_inference::DimensionHandle max_nodes; // aka "M"
TF_RETURN_IF_ERROR(context->Merge(context->Dim(scores, 1),
context->Dim(scores, 2), &max_nodes));
// Output 0 (max_scores) is [B]; output 1 (argmax_sources) is [B,M].
context->set_output(0, context->Vector(batch_size));
context->set_output(1, context->Matrix(batch_size, max_nodes));
return tensorflow::Status::OK();
})
.Doc(R"doc(
Finds the maximum directed spanning tree of a digraph.
Given a batch of digraphs with scored arcs and root selections, solves for the
maximum spanning tree of each digraph, where the score of a tree is defined as
the sum of the scores of the arcs and roots making up the tree.
Returns the score of the maximum spanning tree of each digraph, as well as the
arcs and roots in that tree. Each digraph in a batch may contain a different
number of nodes, so the sizes of the digraphs must be provided as an input.
Note that this operation is only differentiable w.r.t. its |scores| input and
its |max_scores| output.
forest: If true, solves for a maximum spanning forest instead of a maximum
spanning tree, where a spanning forest is a set of disjoint trees that
span the nodes of the digraph.
num_nodes: [B] vector where entry b is number of nodes in the b'th digraph.
scores: [B,M,M] tensor where entry b,t,s is the score of the arc from s to t in
the b'th digraph, if s!=t, or the score of selecting t as a root in the
b'th digraph, if s==t. Requires that M is >= num_nodes[b], for all b,
and ignores entries b,s,t where s or t is >= num_nodes[b]. Arcs or root
selections with non-finite score are treated as nonexistent.
max_scores: [B] vector where entry b is the score of the maximum spanning tree
of the b'th digraph.
argmax_sources: [B,M] matrix where entry b,t is the source of the arc inbound to
t in the maximum spanning tree of the b'th digraph, or t if t is
a root. Entries b,t where t is >= num_nodes[b] are set to -1.
)doc");
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/mst/spanning_tree_iterator.h"
namespace syntaxnet {
namespace dragnn {
SpanningTreeIterator::SpanningTreeIterator(bool forest) : forest_(forest) {}
bool SpanningTreeIterator::HasCycle(const SourceList &sources) {
// Flags for whether each node has already been searched.
searched_.assign(sources.size(), false);
// Flags for whether the search is currently visiting each node.
visiting_.assign(sources.size(), false);
// Search upwards from each node to find cycles.
for (uint32 initial_node = 0; initial_node < sources.size(); ++initial_node) {
// Search upwards to try to find a cycle.
uint32 current_node = initial_node;
while (true) {
if (searched_[current_node]) break; // already searched
if (visiting_[current_node]) return true; // revisiting implies cycle
visiting_[current_node] = true; // mark as being currently visited
const uint32 source_node = sources[current_node];
if (source_node == current_node) break; // self-loops are roots
current_node = source_node; // advance upwards
}
// No cycle; search upwards again to update flags.
current_node = initial_node;
while (true) {
if (searched_[current_node]) break; // already searched
searched_[current_node] = true;
visiting_[current_node] = false;
const uint32 source_node = sources[current_node];
if (source_node == current_node) break; // self-loops are roots
current_node = source_node; // advance upwards
}
}
return false;
}
// Counts the nodes in |sources| whose source is themselves (i.e., the roots).
uint32 SpanningTreeIterator::NumRoots(const SourceList &sources) {
  uint32 root_count = 0;
  uint32 position = 0;
  for (const uint32 source : sources) {
    if (source == position) ++root_count;
    ++position;
  }
  return root_count;
}
// Treats |sources| as a little-endian number whose digits are in base
// |sources->size()| and increments it by one.  Returns false, leaving every
// digit at zero, once the number wraps around.
bool SpanningTreeIterator::NextSourceList(SourceList *sources) {
  const uint32 radix = sources->size();
  for (uint32 &digit : *sources) {
    ++digit;
    if (digit < radix) return true;  // no carry needed
    digit = 0;                       // wrapped; carry into the next digit
  }
  return false;  // carried out of the most significant digit
}
// Steps |sources| forward until it describes a tree (or a forest, when
// |forest_| is set).  Returns false when the source lists are exhausted.
bool SpanningTreeIterator::NextTree(SourceList *sources) {
  while (NextSourceList(sources)) {
    // A tree needs exactly one root; a forest needs at least one.
    const uint32 root_count = NumRoots(*sources);
    const bool roots_ok = forest_ ? (root_count != 0) : (root_count == 1);

    // Rooted and acyclic, therefore a tree (or forest).
    if (roots_ok && !HasCycle(*sources)) return true;
  }
  return false;
}
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_MST_SPANNING_TREE_ITERATOR_H_
#define DRAGNN_MST_SPANNING_TREE_ITERATOR_H_
#include <vector>
#include "syntaxnet/base.h"
namespace syntaxnet {
namespace dragnn {
// Enumerates every spanning tree (or spanning forest) of a complete digraph
// by brute force.  Thread-compatible.  Intended for brute-force comparison
// tests.
//
// TODO(googleuser): Try using Prufer sequences, which are more efficient to
// enumerate as there are no non-trees to filter out.
class SpanningTreeIterator {
 public:
  // Maps each node to the source of its inbound arc.  A node that is its own
  // source is a root.
  using SourceList = std::vector<uint32>;

  // If |forest| is true, the iterator yields forests (multiple roots allowed)
  // rather than single-rooted trees.
  explicit SpanningTreeIterator(bool forest);

  // Invokes |functor| once per spanning tree (or forest, if |forest_| is true)
  // of the complete digraph on |num_nodes| nodes, passing the tree as a
  // SourceList.
  template <class Functor>
  void ForEachTree(uint32 num_nodes, Functor functor) {
    // Start from the all-zero source list, which is conveniently a valid
    // tree, and keep advancing until the enumeration is exhausted.
    SourceList sources(num_nodes, 0);
    for (bool more = true; more; more = NextTree(&sources)) {
      functor(sources);
    }
  }

 private:
  // Returns true if |sources| contains a cycle.
  bool HasCycle(const SourceList &sources);

  // Returns how many roots |sources| contains.
  static uint32 NumRoots(const SourceList &sources);

  // Moves |sources| to the next source list.  Returns false if there is none.
  static bool NextSourceList(SourceList *sources);

  // Moves |sources| to the next tree (or forest, if |forest_| is true).
  // Returns false if there is none.
  bool NextTree(SourceList *sources);

  // Whether to enumerate spanning forests instead of spanning trees.
  const bool forest_;

  // Scratch space reused across calls to HasCycle().
  std::vector<bool> searched_;
  std::vector<bool> visiting_;
};
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_MST_SPANNING_TREE_ITERATOR_H_
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/mst/spanning_tree_iterator.h"

#include <limits>
#include <set>

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace {
// Testing rig.  When the bool parameter is true, iterates over spanning forests
// instead of spanning trees.
class SpanningTreeIteratorTest : public ::testing::TestWithParam<bool> {
 protected:
  using SourceList = SpanningTreeIterator::SourceList;

  // Returns |base|^|exponent|.  Computes the value as an integer to avoid
  // rounding issues, and CHECK-fails if the result does not fit in an int.
  static int Pow(int base, int exponent) {
    // Accumulate in a wider type and range-check after every step.  Because
    // the running product is kept within int range, multiplying it by an int
    // cannot overflow long long, so the overflow check never relies on
    // undefined signed-overflow behavior.
    long long product = 1;
    for (int i = 0; i < exponent; ++i) {
      product *= base;
      CHECK(product >= std::numeric_limits<int>::min() &&
            product <= std::numeric_limits<int>::max())
          << "Overflow detected.";
    }
    return static_cast<int>(product);
  }

  // Expects that the number of possible spanning trees for a complete digraph
  // of |num_nodes| nodes is |expected_num_trees|.
  void ExpectNumTrees(int num_nodes, int expected_num_trees) {
    int actual_num_trees = 0;
    iterator_.ForEachTree(num_nodes,
                          [&](const SourceList &) { ++actual_num_trees; });
    LOG(INFO) << "num_nodes=" << num_nodes
              << " expected_num_trees=" << expected_num_trees
              << " actual_num_trees=" << actual_num_trees;
    EXPECT_EQ(expected_num_trees, actual_num_trees);
  }

  // Expects that the set of possible spanning trees for a complete digraph of
  // |num_nodes| nodes is |expected_trees|.  Also CHECK-fails if the iterator
  // yields the same tree twice.
  void ExpectTrees(int num_nodes, const std::set<SourceList> &expected_trees) {
    std::set<SourceList> actual_trees;
    iterator_.ForEachTree(num_nodes, [&](const SourceList &sources) {
      CHECK(actual_trees.insert(sources).second);
    });
    EXPECT_EQ(expected_trees, actual_trees);
  }

  // Instance for tests.  Shared across assertions in a test to exercise reuse.
  SpanningTreeIterator iterator_{GetParam()};
};
// Runs every SpanningTreeIteratorTest case twice: once with the parameter set
// to false (trees) and once with it set to true (forests).
INSTANTIATE_TEST_CASE_P(AllowForest, SpanningTreeIteratorTest,
::testing::Bool());
TEST_P(SpanningTreeIteratorTest, NumberOfTrees) {
  // Cayley's formula gives n^{n-2} undirected spanning trees on a complete
  // graph of n nodes:
  //   https://en.wikipedia.org/wiki/Cayley%27s_formula
  //
  // Directed spanning trees: every undirected spanning tree yields n directed
  // ones (pick any of the n nodes as the root and orient the arcs away from
  // it), so a complete digraph of n nodes has n^{n-1} directed spanning trees.
  //
  // Directed spanning forests: take an undirected spanning tree on n+1 nodes,
  // designate one fixed node as an artificial root, orient the arcs away from
  // it, and then delete the artificial root together with its outbound arcs.
  // What remains is a directed spanning forest on n nodes, so there are
  // (n+1)^{n-1} of them.
  const bool forest = GetParam();
  for (int num_nodes = 1; num_nodes <= 7; ++num_nodes) {
    const int base = forest ? num_nodes + 1 : num_nodes;
    ExpectNumTrees(num_nodes, Pow(base, num_nodes - 1));
  }
}
// A single node can only be its own root, for trees and forests alike.
TEST_P(SpanningTreeIteratorTest, OneNodeDigraph) {
ExpectTrees(1, {{0}});
}
// Exhaustively lists the expected structures on two nodes.
TEST_P(SpanningTreeIteratorTest, TwoNodeDigraph) {
  if (!GetParam()) {  // tree
    ExpectTrees(2, {{0, 0}, {1, 1}});
    return;
  }

  // Forest: additionally allows {0, 1}, where both nodes are roots.
  ExpectTrees(2, {{0, 0}, {0, 1}, {1, 1}});
}
// Exhaustively lists the expected structures on three nodes.
TEST_P(SpanningTreeIteratorTest, ThreeNodeDigraph) {
  if (!GetParam()) {  // tree: exactly one root
    ExpectTrees(3, {{0, 0, 0},
                    {0, 0, 1},
                    {0, 2, 0},
                    {1, 1, 0},
                    {1, 1, 1},
                    {1, 2, 2},
                    {2, 0, 2},
                    {2, 1, 1},
                    {2, 2, 2}});
    return;
  }

  // Forest: the single-root trees plus every multi-root structure.
  ExpectTrees(3, {{0, 0, 0},
                  {0, 0, 1},
                  {0, 0, 2},  // two roots
                  {0, 1, 0},  // two roots
                  {0, 1, 1},  // two roots
                  {0, 1, 2},  // three roots
                  {0, 2, 0},
                  {0, 2, 2},  // two roots
                  {1, 1, 0},
                  {1, 1, 1},
                  {1, 1, 2},  // two roots
                  {1, 2, 2},
                  {2, 0, 2},
                  {2, 1, 1},
                  {2, 1, 2},  // two roots
                  {2, 2, 2}});
}
} // namespace
} // namespace dragnn
} // namespace syntaxnet
......@@ -2,48 +2,63 @@ package(default_visibility = ["//visibility:public"])
load(
"//syntaxnet:syntaxnet.bzl",
"tf_proto_library",
"tf_proto_library_cc",
"tf_proto_library_py",
)
# Protos.
tf_proto_library(
tf_proto_library_cc(
name = "data_proto",
srcs = ["data.proto"],
)
tf_proto_library(
tf_proto_library_cc(
name = "trace_proto",
srcs = ["trace.proto"],
deps = [
":data_proto",
],
protodeps = [":data_proto"],
)
tf_proto_library(
tf_proto_library_cc(
name = "cell_trace_proto",
srcs = ["cell_trace.proto"],
protodeps = [":trace_proto"],
)
tf_proto_library_cc(
name = "spec_proto",
srcs = ["spec.proto"],
)
tf_proto_library(
tf_proto_library_cc(
name = "runtime_proto",
srcs = ["runtime.proto"],
deps = [":spec_proto"],
protodeps = [":spec_proto"],
)
tf_proto_library_cc(
name = "export_proto",
srcs = ["export.proto"],
protodeps = [":spec_proto"],
)
tf_proto_library_py(
name = "data_py_pb2",
name = "data_pb2",
srcs = ["data.proto"],
)
tf_proto_library_py(
name = "trace_py_pb2",
name = "trace_pb2",
srcs = ["trace.proto"],
deps = [":data_py_pb2"],
protodeps = [":data_pb2"],
)
tf_proto_library_py(
name = "spec_py_pb2",
name = "spec_pb2",
srcs = ["spec.proto"],
)
tf_proto_library_py(
name = "export_pb2",
srcs = ["export.proto"],
)
This diff is collapsed.
// DRAGNN data proto. See go/dragnn-design for more information.
// DRAGNN data proto.
syntax = "proto2";
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment