Unverified Commit 80178fc6 authored by Mark Omernick's avatar Mark Omernick Committed by GitHub
Browse files

Merge pull request #4153 from terryykoo/master

Export @195097388.
parents a84e1ef9 edea2b67
......@@ -62,13 +62,14 @@ class MockComputeSession : public ComputeSession {
MOCK_METHOD2(GetTranslatedLinkFeatures,
std::vector<LinkFeatures>(const string &component_name,
int channel_id));
MOCK_METHOD1(EmitOracleLabels,
std::vector<std::vector<int>>(const string &component_name));
MOCK_METHOD1(EmitOracleLabels, std::vector<std::vector<std::vector<Label>>>(
const string &component_name));
MOCK_METHOD1(IsTerminal, bool(const string &component_name));
MOCK_METHOD1(FinalizeData, void(const string &component_name));
MOCK_METHOD0(GetSerializedPredictions, std::vector<string>());
MOCK_METHOD0(GetTraceProtos, std::vector<MasterTrace>());
MOCK_METHOD1(SetInputData, void(const std::vector<string> &data));
MOCK_METHOD0(GetInputBatchCache, InputBatchCache *());
MOCK_METHOD0(ResetSession, void());
MOCK_METHOD1(SetTracing, void(bool tracing_on));
MOCK_CONST_METHOD0(Id, int());
......
package(
default_visibility = ["//visibility:public"],
features = ["-layering_check"],
)
cc_library(
name = "label",
hdrs = ["label.h"],
)
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_CORE_UTIL_LABEL_H_
#define DRAGNN_CORE_UTIL_LABEL_H_
#include <cmath>
namespace syntaxnet {
namespace dragnn {
// Stores label information.
struct Label {
Label(int label_id, float label_probability)
: id(label_id), probability(label_probability) {}
explicit Label(int label_id) : id(label_id) {}
// Two Labels are equal if the ids match and the probabilities are within an
// epsilon of one another.
bool operator==(const Label &label) const {
return (id == label.id) &&
std::fabs(probability - label.probability) < 0.00001;
}
// Label id and probability.
int id;
float probability = 1.0;
};
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_CORE_UTIL_LABEL_H_
......@@ -8,7 +8,7 @@ cc_library(
":syntaxnet_sentence",
"//dragnn/core/interfaces:input_batch",
"//syntaxnet:base",
"//syntaxnet:sentence_proto",
"//syntaxnet:sentence_proto_cc",
],
)
......@@ -16,7 +16,7 @@ cc_library(
name = "syntaxnet_sentence",
hdrs = ["syntaxnet_sentence.h"],
deps = [
"//syntaxnet:sentence_proto",
"//syntaxnet:sentence_proto_cc",
"//syntaxnet:workspace",
],
)
......@@ -27,7 +27,7 @@ cc_test(
deps = [
":sentence_input_batch",
"//dragnn/core/test:generic",
"//syntaxnet:sentence_proto",
"//syntaxnet:sentence_proto_cc",
"//syntaxnet:test_main",
"@org_tensorflow//tensorflow/core:test",
],
......
package(default_visibility = ["//visibility:public"])
cc_library(
name = "disjoint_set_forest",
hdrs = ["disjoint_set_forest.h"],
deps = [
"@org_tensorflow//tensorflow/core:lib",
],
)
cc_test(
name = "disjoint_set_forest_test",
size = "small",
srcs = ["disjoint_set_forest_test.cc"],
deps = [
":disjoint_set_forest",
"//syntaxnet:base",
"//syntaxnet:test_main",
"@org_tensorflow//tensorflow/core:test",
],
)
cc_library(
name = "spanning_tree_iterator",
testonly = 1,
srcs = ["spanning_tree_iterator.cc"],
hdrs = ["spanning_tree_iterator.h"],
deps = [
"//syntaxnet:base",
],
)
cc_test(
name = "spanning_tree_iterator_test",
size = "small",
srcs = ["spanning_tree_iterator_test.cc"],
deps = [
":spanning_tree_iterator",
"//syntaxnet:test_main",
"@org_tensorflow//tensorflow/core:lib",
"@org_tensorflow//tensorflow/core:test",
],
)
cc_library(
name = "mst_solver",
hdrs = ["mst_solver.h"],
deps = [
":disjoint_set_forest",
"//syntaxnet:base",
"@org_tensorflow//tensorflow/core:lib",
],
)
cc_test(
name = "mst_solver_test",
size = "small",
srcs = ["mst_solver_test.cc"],
deps = [
":mst_solver",
"//dragnn/core/test:generic",
"//syntaxnet:base",
"//syntaxnet:test_main",
"@org_tensorflow//tensorflow/core:test",
],
)
cc_test(
name = "mst_solver_random_comparison_test",
size = "small",
timeout = "long",
srcs = ["mst_solver_random_comparison_test.cc"],
tags = [
"manual", # exclude from :all, since this is expensive
],
deps = [
":mst_solver",
":spanning_tree_iterator",
"//syntaxnet:base",
"//syntaxnet:test_main",
"@org_tensorflow//tensorflow/core:lib",
"@org_tensorflow//tensorflow/core:test",
],
)
load(
"@org_tensorflow//tensorflow:tensorflow.bzl",
"tf_gen_op_libs",
"tf_gen_op_wrapper_py",
)
tf_gen_op_libs(
op_lib_names = ["mst_ops"],
)
# Don't use this library directly; instead use "dragnn/python:mst_ops".
tf_gen_op_wrapper_py(
name = "mst_ops",
visibility = ["//dragnn/python:__pkg__"],
deps = [":mst_ops_op_lib"],
)
cc_library(
name = "mst_ops_cc",
srcs = [
"ops/mst_op_kernels.cc",
"ops/mst_ops.cc",
],
deps = [
":mst_solver",
"//syntaxnet:base",
"@org_tensorflow//tensorflow/core:framework_headers_lib",
"@org_tensorflow//tensorflow/core:lib",
],
alwayslink = 1,
)
Package for solving max-spanning-tree (MST) problems. The code here is intended
for NLP applications, but attempts to remain agnostic to particular NLP tasks
(such as dependency parsing).
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_MST_DISJOINT_SET_FOREST_H_
#define DRAGNN_MST_DISJOINT_SET_FOREST_H_
#include <stddef.h>
#include <type_traits>
#include <vector>
#include "tensorflow/core/platform/logging.h"
namespace syntaxnet {
namespace dragnn {
// An implementation of the disjoint-set forest data structure. The universe of
// elements is the dense range of indices [0,n). Thread-compatible.
//
// By default, this uses the path compression and union by rank optimizations,
// achieving near-constant runtime on all operations. However, the user may
// disable the union by rank optimization, which allows the user to control how
// roots are selected when a union occurs. When union by rank is disabled, the
// runtime of all operations increases to O(log n) amortized.
//
// Template args:
// Index: An unsigned integral type wide enough to hold n.
// kUseUnionByRank: Whether to use the union by rank optimization.
template <class Index, bool kUseUnionByRank = true>
class DisjointSetForest {
public:
static_assert(std::is_integral<Index>::value, "Index must be integral");
static_assert(!std::is_signed<Index>::value, "Index must be unsigned");
using IndexType = Index;
// Creates an empty forest.
DisjointSetForest() = default;
// Initializes this to hold the elements [0,|size|), each initially in its own
// singleton set. Replaces existing state, if any.
void Init(Index size);
// Returns the root of the set containing |element|, which uniquely identifies
// the set. Note that the root of a set may change as the set is merged with
// other sets; do not cache the return value of FindRoot(e) across calls to
// Union() or UnionOfRoots() that could merge the set containing e.
Index FindRoot(Index element);
// For convenience, returns true if |element1| and |element2| are in the same
// set. When performing a large batch of queries it may be more efficient to
// cache the value of FindRoot(), modulo caveats regarding caching above.
bool SameSet(Index element1, Index element2);
// Merges the sets rooted at |root1| and |root2|, which must be the roots of
// their respective sets. Either |root1| or |root2| will be the root of the
// merged set. If |kUseUnionByRank| is true, then it is unspecified whether
// |root1| or |root2| will be the root; otherwise, |root2| will be the root.
void UnionOfRoots(Index root1, Index root2);
// As above, but for convenience finds the root of |element1| and |element2|.
void Union(Index element1, Index element2);
// The number of elements in this.
Index size() const { return size_; }
private:
// The number of elements in the universe underlying the sets.
Index size_ = 0;
// The parent of each element, where self-loops are roots.
std::vector<Index> parents_;
// The rank of each element, for the union by rank optimization. Only used if
// |kUseUnionByRank| is true.
std::vector<Index> ranks_;
};
// Implementation details below.
template <class Index, bool kUseUnionByRank>
void DisjointSetForest<Index, kUseUnionByRank>::Init(Index size) {
size_ = size;
parents_.resize(size_);
if (kUseUnionByRank) ranks_.resize(size_);
// Create singleton sets.
for (Index i = 0; i < size_; ++i) {
parents_[i] = i;
if (kUseUnionByRank) ranks_[i] = 0;
}
}
template <class Index, bool kUseUnionByRank>
Index DisjointSetForest<Index, kUseUnionByRank>::FindRoot(Index element) {
DCHECK_LT(element, size());
Index *const __restrict parents = parents_.data();
// Walk up to the root of the |element|. Unroll the first two comparisons
// because path compression ensures most FindRoot() calls end there. In
// addition, if a root is found within the first two comparisons, then the
// path compression updates can be skipped.
Index current = element;
Index parent = parents[current];
if (current == parent) return current; // |element| is a root
current = parent;
parent = parents[current];
if (current == parent) return current; // |element| is the child of a root
do { // otherwise, continue upwards until root
current = parent;
parent = parents[current];
} while (current != parent);
const Index root = current;
// Apply path compression on the traversed nodes.
current = element;
parent = parents[current]; // not root, thanks to unrolling above
do {
parents[current] = root;
current = parent;
parent = parents[current];
} while (parent != root);
return root;
}
template <class Index, bool kUseUnionByRank>
bool DisjointSetForest<Index, kUseUnionByRank>::SameSet(Index element1,
Index element2) {
return FindRoot(element1) == FindRoot(element2);
}
template <class Index, bool kUseUnionByRank>
void DisjointSetForest<Index, kUseUnionByRank>::UnionOfRoots(Index root1,
Index root2) {
DCHECK_LT(root1, size());
DCHECK_LT(root2, size());
DCHECK_EQ(root1, parents_[root1]);
DCHECK_EQ(root2, parents_[root2]);
if (root1 == root2) return; // already merged
Index *const __restrict parents = parents_.data();
if (kUseUnionByRank) {
// Attach the lesser-rank root to the higher-rank root.
Index *const __restrict ranks = ranks_.data();
const Index rank1 = ranks[root1];
const Index rank2 = ranks[root2];
if (rank2 < rank1) {
parents[root2] = root1;
} else if (rank1 < rank2) {
parents[root1] = root2;
} else {
// Equal ranks; choose one arbitrarily and promote its rank.
parents[root1] = root2;
ranks[root2] = rank2 + 1;
}
} else {
// Always make |root2| the root of the merged set.
parents[root1] = root2;
}
}
template <class Index, bool kUseUnionByRank>
void DisjointSetForest<Index, kUseUnionByRank>::Union(Index element1,
Index element2) {
UnionOfRoots(FindRoot(element1), FindRoot(element2));
}
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_MST_DISJOINT_SET_FOREST_H_
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/mst/disjoint_set_forest.h"
#include <stddef.h>
#include <set>
#include <utility>
#include <vector>
#include "syntaxnet/base.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace {
// Testing rig.
//
// Template args:
// Forest: An instantiation of the DisjointSetForest<> template.
template <class Forest>
class DisjointSetForestTest : public ::testing::Test {
protected:
using Index = typename Forest::IndexType;
// Expects that the |expected_sets| and |forest| match.
void ExpectSets(const std::set<std::set<Index>> &expected_sets,
Forest *forest) {
std::set<std::pair<Index, Index>> expected_pairs;
for (const auto &expected_set : expected_sets) {
for (auto it = expected_set.begin(); it != expected_set.end(); ++it) {
for (auto jt = expected_set.begin(); jt != expected_set.end(); ++jt) {
expected_pairs.emplace(*it, *jt);
}
}
}
for (Index lhs = 0; lhs < forest->size(); ++lhs) {
for (Index rhs = 0; rhs < forest->size(); ++rhs) {
if (expected_pairs.find({lhs, rhs}) != expected_pairs.end()) {
EXPECT_EQ(forest->FindRoot(lhs), forest->FindRoot(rhs));
EXPECT_TRUE(forest->SameSet(lhs, rhs));
} else {
EXPECT_NE(forest->FindRoot(lhs), forest->FindRoot(rhs));
EXPECT_FALSE(forest->SameSet(lhs, rhs));
}
}
}
}
};
using Forests = ::testing::Types<
DisjointSetForest<uint8, false>, DisjointSetForest<uint8, true>,
DisjointSetForest<uint16, false>, DisjointSetForest<uint16, true>,
DisjointSetForest<uint32, false>, DisjointSetForest<uint32, true>,
DisjointSetForest<uint64, false>, DisjointSetForest<uint64, true>>;
TYPED_TEST_CASE(DisjointSetForestTest, Forests);
TYPED_TEST(DisjointSetForestTest, DefaultEmpty) {
TypeParam forest;
EXPECT_EQ(0, forest.size());
}
TYPED_TEST(DisjointSetForestTest, InitEmpty) {
TypeParam forest;
forest.Init(0);
EXPECT_EQ(0, forest.size());
}
TYPED_TEST(DisjointSetForestTest, Populated) {
TypeParam forest;
forest.Init(5);
EXPECT_EQ(5, forest.size());
this->ExpectSets({{0}, {1}, {2}, {3}, {4}}, &forest);
forest.UnionOfRoots(1, 2);
this->ExpectSets({{0}, {1, 2}, {3}, {4}}, &forest);
forest.Union(1, 2);
this->ExpectSets({{0}, {1, 2}, {3}, {4}}, &forest);
forest.UnionOfRoots(0, 4);
this->ExpectSets({{0, 4}, {1, 2}, {3}}, &forest);
forest.Union(3, 4);
this->ExpectSets({{0, 3, 4}, {1, 2}}, &forest);
forest.Union(0, 3);
this->ExpectSets({{0, 3, 4}, {1, 2}}, &forest);
forest.Union(2, 0);
this->ExpectSets({{0, 1, 2, 3, 4}}, &forest);
forest.Union(1, 3);
this->ExpectSets({{0, 1, 2, 3, 4}}, &forest);
}
// Testing rig for checking that when union by rank is disabled, the root of a
// merged set can be controlled.
class DisjointSetForestNoUnionByRankTest : public ::testing::Test {
protected:
using Forest = DisjointSetForest<uint32, false>;
// Expects that the roots of the |forest| match |expected_roots|.
void ExpectRoots(const std::vector<uint32> &expected_roots, Forest *forest) {
ASSERT_EQ(expected_roots.size(), forest->size());
for (uint32 i = 0; i < forest->size(); ++i) {
EXPECT_EQ(expected_roots[i], forest->FindRoot(i));
}
}
};
TEST_F(DisjointSetForestNoUnionByRankTest, ManuallySpecifyRoot) {
Forest forest;
forest.Init(5);
ExpectRoots({0, 1, 2, 3, 4}, &forest);
forest.UnionOfRoots(0, 1); // 1 is the root
ExpectRoots({1, 1, 2, 3, 4}, &forest);
forest.Union(4, 3); // 3 is the root
ExpectRoots({1, 1, 2, 3, 3}, &forest);
forest.Union(0, 2); // 2 is the root
ExpectRoots({2, 2, 2, 3, 3}, &forest);
forest.Union(3, 3); // no effect
ExpectRoots({2, 2, 2, 3, 3}, &forest);
forest.Union(4, 0); // 2 is the root
ExpectRoots({2, 2, 2, 2, 2}, &forest);
}
} // namespace
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_MST_MST_SOLVER_H_
#define DRAGNN_MST_MST_SOLVER_H_
#include <stddef.h>
#include <algorithm>
#include <cmath>
#include <limits>
#include <type_traits>
#include <utility>
#include <vector>
#include "dragnn/mst/disjoint_set_forest.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace syntaxnet {
namespace dragnn {
// Maximum spanning tree solver for directed graphs. Thread-compatible.
//
// The solver operates on a digraph of n nodes and m arcs and outputs a maximum
// spanning tree rooted at any node. Scores can be associated with arcs and
// root selections, and the score of a tree is the sum of the relevant arc and
// root-selection scores.
//
// The implementation is based on:
//
// R.E. Tarjan. 1977. Finding Optimum Branchings. Networks 7(1), pp. 25-35.
// [In particular, see Section 4 "a modification for dense graphs"]
//
// which itself is an improvement of the Chu-Liu-Edmonds algorithm. Note also
// the correction in:
//
// P.M. Camerini, L. Fratta, F. Maffioli. 1979. A Note on Finding Optimum
// Branchings. Networks 9(4), pp. 309-312.
//
// The solver runs in O(n^2) time, which is optimal for dense digraphs but slow
// for sparse digraphs where O(m + n log n) can be achieved. The solver uses
// O(n^2) space to store the digraph, which is also optimal for dense digraphs.
//
// Although this algorithm has an inferior asymptotic runtime on sparse graphs,
// it avoids high-constant-overhead data structures like Fibonacci heaps, which
// are required in the asymptotically faster algorithms. Therefore, this solver
// may still be competitive on small sparse graphs.
//
// TODO(googleuser): If we start running on large sparse graphs, implement the
// following, which runs in O(m + n log n):
//
// H.N. Gabow, Z. Galil, T. Spencer, and R.E. Tarjan. 1986. Efficient
// algorithms for finding minimum spanning trees in undirected and directed
// graphs. Combinatorica, 6(2), pp. 109-122.
//
// Template args:
// Index: An unsigned integral type wide enough to hold 2n.
// Score: A signed arithmetic (integral or floating-point) type.
template <class Index, class Score>
class MstSolver {
public:
static_assert(std::is_integral<Index>::value, "Index must be integral");
static_assert(!std::is_signed<Index>::value, "Index must be unsigned");
static_assert(std::is_arithmetic<Score>::value, "Score must be arithmetic");
static_assert(std::is_signed<Score>::value, "Score must be signed");
using IndexType = Index;
using ScoreType = Score;
// Creates an empty solver. Call Init() before use.
MstSolver() = default;
// Initializes this for a digraph with |num_nodes| nodes, or returns non-OK on
// error. Discards existing state; call AddArc() and AddRoot() to add arcs
// and root selections. If |forest| is true, then this solves for a maximum
// spanning forest (i.e., a set of disjoint trees that span the digraph).
tensorflow::Status Init(bool forest, Index num_nodes);
// Adds an arc from the |source| node to the |target| node with the |score|.
// The |source| and |target| must be distinct node indices in [0,n), and the
// |score| must be finite. Calling this multiple times on the same |source|
// and |target| overwrites the score instead of adding parallel arcs.
void AddArc(Index source, Index target, Score score);
// As above, but adds a root selection for the |root| node with the |score|.
void AddRoot(Index root, Score score);
// Populates |argmax| with the maximum directed spanning tree of the current
// digraph, or returns non-OK on error. The |argmax| array must contain at
// least n elements. On success, argmax[t] is the source of the arc directed
// into t, or t itself if t is a root.
//
// NB: If multiple spanning trees achieve the maximum score, |argmax| will be
// set to one of the maximal trees, but it is unspecified which one.
tensorflow::Status Solve(tensorflow::gtl::MutableArraySlice<Index> argmax);
private:
// Implementation notes:
//
// The solver does not operate on the "original" digraph as specified by the
// user, but a "transformed" digraph that differs as follows:
//
// * The transformed digraph adds an "artificial root" node at index 0 and
// offsets all original node indices by +1 to make room. For each root
// selection, the artificial root has one outbound arc directed into the
// candidate root that carries the root-selection score. The artificial
// root has no inbound arcs.
//
// * When solving for a spanning tree (i.e., when |forest_| is false), the
// outbound arcs of the artificial root are penalized to ensure that the
// artificial root has exactly one child.
//
// In the remainder of this file, all mentions of nodes, arcs, etc., refer to
// the transformed digraph unless otherwise specified.
//
// The algorithm is divided into two phases, the "contraction phase" and the
// "expansion phase". The contraction phase finds the arcs that make up the
// maximum spanning tree by applying a series of "contractions" which further
// modify the digraph. The expansion phase "expands" these modifications and
// recovers the maximum spanning tree in the original digraph.
//
// During the contraction phase, the algorithm selects the best inbound arc
// for each node. These arcs can form cycles, which are "contracted" by
// removing the cycle nodes and replacing them with a new contracted node.
// Since each contraction removes 2 or more cycle nodes and adds 1 contracted
// node, at most n-1 contractions will occur. (The digraph initially contains
// n+1 nodes, but one is the artificial root, which cannot form a cycle).
//
// When contracting a cycle, nodes are not explicitly removed and replaced.
// Instead, a contracted node is appended to the digraph and the cycle nodes
// are remapped to the contracted node, which implicitly removes and replaces
// the cycle. As a result, each contraction actually increases the size of
// the digraph, up to a maximum of 2n nodes. One advantage of adding and
// remapping nodes is that it is convenient to recover the argmax spanning
// tree during the expansion phase.
//
// Note that contractions can be nested, because the best inbound arc for a
// contracted node may itelf form a cycle. During the expansion phase, the
// algorithm picks a root of the hierarchy of contracted nodes, breaks the
// cycle it represents, and repeats until all cycles are broken.
// Constants, as enums to avoid the need for static variable definitions.
enum Constants : Index {
// An index reserved for "null" values.
kNullIndex = std::numeric_limits<Index>::max(),
};
// A possibly-nonexistent arc in the digraph.
struct Arc {
// Creates a nonexistent arc.
Arc() = default;
// Returns true if this arc exists.
bool Exists() const { return target != 0; }
// Returns true if this is a root-selection arc.
bool IsRoot() const { return source == 0; }
// Returns a string representation of this arc.
string DebugString() const {
if (!Exists()) return "[null]";
if (IsRoot()) {
return tensorflow::strings::StrCat("[*->", target, "=", score, "]");
}
return tensorflow::strings::StrCat("[", source, "->", target, "=", score,
"]");
}
// Score of this arc.
Score score;
// Source of this arc in the initial digraph.
Index source;
// Target of this arc in the initial digraph, or 0 if this is nonexistent.
Index target = 0;
};
// Returns the index, in |arcs_|, of the arc from |source| to |target|. The
// |source| must be one of the initial n+1 nodes.
size_t ArcIndex(size_t source, size_t target) const;
// Penalizes the root arc scores to ensure that this finds a tree, or does
// nothing if |forest_| is true. Must be called before ContractionPhase().
void MaybePenalizeRootScoresForTree();
// Returns the maximum inbound arc of the |node|, or null if there is none.
const Arc *MaximumInboundArc(Index node) const;
// Merges the inbound arcs of the |cycle_node| into the inbound arcs of the
// |contracted_node|. Arcs are merged as follows:
// * If the source and target of the arc belong to the same strongly-connected
// component, it is ignored.
// * If exactly one of the nodes had an arc from some source, then on exit the
// |contracted_node| has that arc.
// * If both of the nodes had an arc from the same source, then on exit the
// |contracted_node| has the better-scoring arc.
// The |score_offset| is added to the arc scores of the |cycle_node| before
// they are merged into the |contracted_node|.
void MergeInboundArcs(Index cycle_node, Score score_offset,
Index contracted_node);
// Contracts the cycle in |argmax_arcs_| that contains the |node|.
void ContractCycle(Index node);
// Runs the contraction phase of the solver, or returns non-OK on error. This
// phase finds the best inbound arc for each node, contracting cycles as they
// are formed. Stops when every node has selected an inbound arc and there
// are no cycles.
tensorflow::Status ContractionPhase();
// Runs the expansion phase of the solver, or returns non-OK on error. This
// phase expands each contracted node, breaks cycles, and populates |argmax|
// with the maximum spanning tree.
tensorflow::Status ExpansionPhase(
tensorflow::gtl::MutableArraySlice<Index> argmax);
// If true, solve for a spanning forest instead of a spanning tree.
bool forest_ = false;
// The number of nodes in the original digraph; i.e., n.
Index num_original_nodes_ = 0;
// The number of nodes in the initial digraph; i.e., n+1.
Index num_initial_nodes_ = 0;
// The maximum number of possible nodes in the digraph; i.e., 2n.
Index num_possible_nodes_ = 0;
// The number of nodes in the current digraph, which grows from n+1 to 2n.
Index num_current_nodes_ = 0;
// Column-major |num_initial_nodes_| x |num_current_nodes_| matrix of arcs,
// where rows and columns correspond to source and target nodes. Columns are
// added as cycles are contracted into new nodes.
//
// TODO(googleuser): It is possible to squeeze the nonexistent arcs out of each
// column and run the algorithm with each column being a sorted list (sorted
// by source node). This is in fact the suggested representation in Tarjan
// (1977). This won't improve the asymptotic runtime but still might improve
// speed in practice. I haven't done this because it adds complexity versus
// checking Arc::Exists() in a few loops. Try this out when we can benchmark
// this on real data.
std::vector<Arc> arcs_;
// Disjoint-set forests tracking the weakly-connected and strongly-connected
// components of the initial digraph, based on the arcs in |argmax_arcs_|.
// Weakly-connected components are used to detect cycles; strongly-connected
// components are used to detect self-loops.
DisjointSetForest<Index> weak_components_;
DisjointSetForest<Index> strong_components_;
// A disjoint-set forest that maps each node to the top-most contracted node
// that contains it. Nodes that have not been contracted map to themselves.
// NB: This disjoint-set forest does not use union by rank so we can control
// the outcome of a set union. There will only be O(n) operations on this
// instance, so the increased O(log n) cost of each operation is acceptable.
DisjointSetForest<Index, false> contracted_nodes_;
// An array that represents the history of cycle contractions, as follows:
// * If contracted_into_[t] is |kNullIndex|, then t is deleted.
// * If contracted_into_[t] is 0, then t is a "root" contracted node; i.e., t
// has not been contracted into another node.
// * Otherwise, contracted_into_[t] is the node into which t was contracted.
std::vector<Index> contracted_into_;
// The maximum inbound arc for each node. The first element is null because
// the artificial root has no inbound arcs.
std::vector<const Arc *> argmax_arcs_;
// Workspace for ContractCycle(), which records the nodes and arcs in the
// cycle being contracted.
std::vector<std::pair<Index, const Arc *>> cycle_;
};
// Implementation details below.
template <class Index, class Score>
tensorflow::Status MstSolver<Index, Score>::Init(bool forest, Index num_nodes) {
if (num_nodes <= 0) {
return tensorflow::errors::InvalidArgument("Non-positive number of nodes: ",
num_nodes);
}
// Upcast to size_t to avoid overflow.
if (2 * static_cast<size_t>(num_nodes) >= static_cast<size_t>(kNullIndex)) {
return tensorflow::errors::InvalidArgument("Too many nodes: ", num_nodes);
}
forest_ = forest;
num_original_nodes_ = num_nodes;
num_initial_nodes_ = num_original_nodes_ + 1;
num_possible_nodes_ = 2 * num_original_nodes_;
num_current_nodes_ = num_initial_nodes_;
// Allocate the full n+1 x 2n matrix, but start with a n+1 x n+1 prefix.
const size_t num_initial_arcs = static_cast<size_t>(num_initial_nodes_) *
static_cast<size_t>(num_initial_nodes_);
const size_t num_possible_arcs = static_cast<size_t>(num_initial_nodes_) *
static_cast<size_t>(num_possible_nodes_);
arcs_.reserve(num_possible_arcs);
arcs_.assign(num_initial_arcs, {});
weak_components_.Init(num_initial_nodes_);
strong_components_.Init(num_initial_nodes_);
contracted_nodes_.Init(num_possible_nodes_);
contracted_into_.assign(num_possible_nodes_, 0);
argmax_arcs_.assign(num_possible_nodes_, nullptr);
// This doesn't need to be cleared now; it will be cleared before use.
cycle_.reserve(num_original_nodes_);
return tensorflow::Status::OK();
}
template <class Index, class Score>
void MstSolver<Index, Score>::AddArc(Index source, Index target, Score score) {
DCHECK_NE(source, target);
DCHECK(std::isfinite(score));
Arc &arc = arcs_[ArcIndex(source + 1, target + 1)];
arc.score = score;
arc.source = source + 1;
arc.target = target + 1;
}
template <class Index, class Score>
void MstSolver<Index, Score>::AddRoot(Index root, Score score) {
DCHECK(std::isfinite(score));
Arc &arc = arcs_[ArcIndex(0, root + 1)];
arc.score = score;
arc.source = 0;
arc.target = root + 1;
}
template <class Index, class Score>
tensorflow::Status MstSolver<Index, Score>::Solve(
tensorflow::gtl::MutableArraySlice<Index> argmax) {
MaybePenalizeRootScoresForTree();
TF_RETURN_IF_ERROR(ContractionPhase());
TF_RETURN_IF_ERROR(ExpansionPhase(argmax));
return tensorflow::Status::OK();
}
template <class Index, class Score>
inline size_t MstSolver<Index, Score>::ArcIndex(size_t source,
size_t target) const {
DCHECK_LT(source, num_initial_nodes_);
DCHECK_LT(target, num_current_nodes_);
return source + target * static_cast<size_t>(num_initial_nodes_);
}
template <class Index, class Score>
void MstSolver<Index, Score>::MaybePenalizeRootScoresForTree() {
if (forest_) return;
DCHECK_EQ(num_current_nodes_, num_initial_nodes_)
<< "Root penalties must be applied before starting the algorithm.";
// Find the minimum and maximum arc scores. These allow us to bound the range
// of possible tree scores.
Score max_score = std::numeric_limits<Score>::lowest();
Score min_score = std::numeric_limits<Score>::max();
for (const Arc &arc : arcs_) {
if (!arc.Exists()) continue;
max_score = std::max(max_score, arc.score);
min_score = std::min(min_score, arc.score);
}
// Nothing to do, no existing arcs.
if (max_score < min_score) return;
// A spanning tree or forest contains n arcs. The penalty below ensures that
// every structure with one root has a higher score than every structure with
// two roots, and so on.
const Score root_penalty = 1 + num_initial_nodes_ * (max_score - min_score);
for (Index root = 1; root < num_initial_nodes_; ++root) {
Arc &arc = arcs_[ArcIndex(0, root)];
if (!arc.Exists()) continue;
arc.score -= root_penalty;
}
}
template <class Index, class Score>
const typename MstSolver<Index, Score>::Arc *
MstSolver<Index, Score>::MaximumInboundArc(Index node) const {
const Arc *__restrict arc = &arcs_[ArcIndex(0, node)];
const Arc *arc_end = arc + num_initial_nodes_;
Score max_score = std::numeric_limits<Score>::lowest();
const Arc *argmax_arc = nullptr;
for (; arc < arc_end; ++arc) {
if (!arc->Exists()) continue;
const Score score = arc->score;
if (max_score <= score) {
max_score = score;
argmax_arc = arc;
}
}
return argmax_arc;
}
template <class Index, class Score>
void MstSolver<Index, Score>::MergeInboundArcs(Index cycle_node,
Score score_offset,
Index contracted_node) {
const Arc *__restrict cycle_arc = &arcs_[ArcIndex(0, cycle_node)];
const Arc *cycle_arc_end = cycle_arc + num_initial_nodes_;
Arc *__restrict contracted_arc = &arcs_[ArcIndex(0, contracted_node)];
for (; cycle_arc < cycle_arc_end; ++cycle_arc, ++contracted_arc) {
if (!cycle_arc->Exists()) continue; // nothing to merge
// Skip self-loops; they are useless because they cannot be used to break
// the cycle represented by the |contracted_node|.
if (strong_components_.SameSet(cycle_arc->source, cycle_arc->target)) {
continue;
}
// Merge the |cycle_arc| into the |contracted_arc|.
const Score cycle_score = cycle_arc->score + score_offset;
if (!contracted_arc->Exists() || contracted_arc->score < cycle_score) {
contracted_arc->score = cycle_score;
contracted_arc->source = cycle_arc->source;
contracted_arc->target = cycle_arc->target;
}
}
}
template <class Index, class Score>
void MstSolver<Index, Score>::ContractCycle(Index node) {
// Append a new node for the contracted cycle.
const Index contracted_node = num_current_nodes_++;
DCHECK_LE(num_current_nodes_, num_possible_nodes_);
arcs_.resize(arcs_.size() + num_initial_nodes_);
// We make two passes through the cycle. The first pass updates everything
// except the |arcs_|, and the second pass updates the |arcs_|. The |arcs_|
// must be updated in a second pass because MergeInboundArcs() requires that
// the |strong_components_| are updated with the newly-contracted cycle.
cycle_.clear();
Index cycle_node = node;
do {
// Gather the nodes and arcs in |cycle_| for the second pass.
const Arc *cycle_arc = argmax_arcs_[cycle_node];
DCHECK(!cycle_arc->IsRoot()) << cycle_arc->DebugString();
cycle_.emplace_back(cycle_node, cycle_arc);
// Mark the cycle nodes as members of a strongly-connected component.
strong_components_.Union(cycle_arc->source, cycle_arc->target);
// Mark the cycle nodes as members of the new contracted node. Juggling is
// required because |contracted_nodes_| also determines the next cycle node.
const Index next_node = contracted_nodes_.FindRoot(cycle_arc->source);
contracted_nodes_.UnionOfRoots(cycle_node, contracted_node);
contracted_into_[cycle_node] = contracted_node;
cycle_node = next_node;
// When the cycle repeats, |cycle_node| will be equal to |contracted_node|,
// not |node|, because the first iteration of this loop mapped |node| to
// |contracted_node| in |contracted_nodes_|.
} while (cycle_node != contracted_node);
// Merge the inbound arcs of each cycle node into the |contracted_node|.
for (const auto &node_and_arc : cycle_) {
// Set the |score_offset| to the cost of breaking the cycle by replacing the
// arc currently directed into the |cycle_node|.
const Index cycle_node = node_and_arc.first;
const Score score_offset = -node_and_arc.second->score;
MergeInboundArcs(cycle_node, score_offset, contracted_node);
}
}
template <class Index, class Score>
tensorflow::Status MstSolver<Index, Score>::ContractionPhase() {
// Skip the artificial root since it has no inbound arcs.
for (Index target = 1; target < num_current_nodes_; ++target) {
// Find the maximum inbound arc for the current |target|, if any.
const Arc *arc = MaximumInboundArc(target);
if (arc == nullptr) {
return tensorflow::errors::FailedPrecondition("Infeasible digraph");
}
argmax_arcs_[target] = arc;
// The articifial root cannot be part of a cycle, so we do not need to check
// for cycles or even update its membership in the connected components.
if (arc->IsRoot()) continue;
// Since every node has at most one selected inbound arc, cycles can be
// detected using weakly-connected components.
const Index source_component = weak_components_.FindRoot(arc->source);
const Index target_component = weak_components_.FindRoot(arc->target);
if (source_component == target_component) {
// Cycle detected; contract it into a new node.
ContractCycle(target);
} else {
// No cycles, just update the weakly-connected components.
weak_components_.UnionOfRoots(source_component, target_component);
}
}
return tensorflow::Status::OK();
}
template <class Index, class Score>
tensorflow::Status MstSolver<Index, Score>::ExpansionPhase(
tensorflow::gtl::MutableArraySlice<Index> argmax) {
if (argmax.size() < num_original_nodes_) {
return tensorflow::errors::InvalidArgument(
"Argmax array too small: ", num_original_nodes_,
" elements required, but got ", argmax.size());
}
// Select and expand a root contracted node until no contracted nodes remain.
// Thanks to the (topological) order in which contracted nodes are appended,
// root contracted nodes are easily enumerated using a backward scan. After
// this loop, entries [1,n] of |argmax_arcs_| provide the arcs of the maximum
// spanning tree.
for (Index i = num_current_nodes_ - 1; i >= num_initial_nodes_; --i) {
if (contracted_into_[i] == kNullIndex) continue; // already deleted
const Index root = i; // if not deleted, must be a root due to toposorting
// Copy the cycle-breaking arc to its specified target.
const Arc *arc = argmax_arcs_[root];
argmax_arcs_[arc->target] = arc;
// The |arc| not only breaks the cycle associated with the |root|, but also
// breaks every nested cycle between the |root| and the target of the |arc|.
// Delete the contracted nodes corresponding to all broken cycles.
Index node = contracted_into_[arc->target];
while (node != kNullIndex && node != root) {
const Index parent = contracted_into_[node];
contracted_into_[node] = kNullIndex;
node = parent;
}
}
// Copy the spanning tree from |argmax_arcs_| to |argmax|. Also count roots
// for validation below.
Index num_roots = 0;
for (Index target = 0; target < num_original_nodes_; ++target) {
const Arc &arc = *argmax_arcs_[target + 1];
DCHECK_EQ(arc.target, target + 1) << arc.DebugString();
if (arc.IsRoot()) {
++num_roots;
argmax[target] = target;
} else {
argmax[target] = arc.source - 1;
}
}
DCHECK_GE(num_roots, 1);
// Even when |forest_| is false, |num_roots| can still be more than 1. While
// the root score penalty discourages structures with multiple root arcs, it
// is not a hard constraint. For example, if the original digraph contained
// one root selection per node and no other arcs, the solver would incorrectly
// produce an all-root structure in spite of the root score penalty. As this
// example illustrates, however, |num_roots| will be more than 1 if and only
// if the original digraph is infeasible for trees.
if (!forest_ && num_roots != 1) {
return tensorflow::errors::FailedPrecondition("Infeasible digraph");
}
return tensorflow::Status::OK();
}
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_MST_MST_SOLVER_H_
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/mst/mst_solver.h"
#include <time.h>
#include <random>
#include <set>
#include <vector>
#include "dragnn/mst/spanning_tree_iterator.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace {
using ::testing::Contains;
// Returns the random seed, or 0 for a weak random seed.
int64 GetSeed() {
return 1; // use a deterministic seed
}
// Returns the number of trials to run for each random comparison.
int64 GetNumTrials() {
return 3;
}
// Testing rig. Runs a comparison between a brute-force MST solver and the
// MstSolver<> on random digraphs. When the first test parameter is true,
// solves for forests instead of trees. The second test parameter defines the
// size of the test digraph.
class MstSolverRandomComparisonTest
: public ::testing::TestWithParam<::testing::tuple<bool, uint32>> {
protected:
// Use integer scores so score comparisons are exact.
using Solver = MstSolver<uint32, int32>;
// An array providing a source node for each node. Roots are self-loops.
using SourceList = SpanningTreeIterator::SourceList;
// A row-major n x n matrix whose i,j entry gives the score of the arc from i
// to j, and whose i,i entry gives the score of selecting i as a root.
using ScoreMatrix = std::vector<int32>;
// Returns true if this should be a forest.
bool forest() const { return ::testing::get<0>(GetParam()); }
// Returns the number of nodes for digraphs.
uint32 num_nodes() const { return ::testing::get<1>(GetParam()); }
// Returns the score of the arcs in |sources| based on the |scores|.
int32 ScoreArcs(const ScoreMatrix &scores, const SourceList &sources) const {
CHECK_EQ(num_nodes() * num_nodes(), scores.size());
int32 score = 0;
for (uint32 target = 0; target < num_nodes(); ++target) {
const uint32 source = sources[target];
score += scores[target + source * num_nodes()];
}
return score;
}
// Returns the score of the maximum spanning tree (or forest, if the first
// test parameter is true) of the dense digraph defined by the |scores|, and
// sets |argmax_trees| to contain all maximal trees.
int32 RunBruteForceMstSolver(const ScoreMatrix &scores,
std::set<SourceList> *argmax_trees) {
CHECK_EQ(num_nodes() * num_nodes(), scores.size());
int32 max_score;
argmax_trees->clear();
iterator_.ForEachTree(num_nodes(), [&](const SourceList &sources) {
const int32 score = ScoreArcs(scores, sources);
if (argmax_trees->empty() || max_score < score) {
max_score = score;
argmax_trees->clear();
argmax_trees->insert(sources);
} else if (max_score == score) {
argmax_trees->insert(sources);
}
});
return max_score;
}
// As above, but uses the |solver_| and extracts only one |argmax_tree|.
int32 RunMstSolver(const ScoreMatrix &scores, SourceList *argmax_tree) {
CHECK_EQ(num_nodes() * num_nodes(), scores.size());
TF_CHECK_OK(solver_.Init(forest(), num_nodes()));
// Add all roots and arcs.
for (uint32 source = 0; source < num_nodes(); ++source) {
for (uint32 target = 0; target < num_nodes(); ++target) {
const int32 score = scores[target + source * num_nodes()];
if (source == target) {
solver_.AddRoot(target, score);
} else {
solver_.AddArc(source, target, score);
}
}
}
// Solve for the max spanning tree.
argmax_tree->resize(num_nodes());
TF_CHECK_OK(solver_.Solve(argmax_tree));
return ScoreArcs(scores, *argmax_tree);
}
// Returns a random ScoreMatrix spanning num_nodes() nodes.
ScoreMatrix RandomScores() {
ScoreMatrix scores(num_nodes() * num_nodes());
for (int32 &value : scores) value = static_cast<int32>(prng_() % 201) - 100;
return scores;
}
// Runs a comparison between MstSolver and BruteForceMst on random digraphs of
// num_nodes() nodes, for the specified number of trials.
void RunComparison() {
// Seed the PRNG, possibly non-deterministically. Log the seed value so the
// test results can be reproduced, even when the seed is non-deterministic.
uint32 seed = GetSeed();
if (seed == 0) seed = time(nullptr);
prng_.seed(seed);
LOG(INFO) << "seed = " << seed;
const int num_trials = GetNumTrials();
for (int trial = 0; trial < num_trials; ++trial) {
const ScoreMatrix scores = RandomScores();
std::set<SourceList> expected_argmax_trees;
const int32 expected_max_score =
RunBruteForceMstSolver(scores, &expected_argmax_trees);
SourceList actual_argmax_tree;
const int32 actual_max_score = RunMstSolver(scores, &actual_argmax_tree);
// In case of ties, MstSolver will find a maximal spanning tree, but we
// don't know which one.
EXPECT_EQ(expected_max_score, actual_max_score);
ASSERT_THAT(expected_argmax_trees, Contains(actual_argmax_tree));
}
}
// Tree iterator for brute-force solver.
SpanningTreeIterator iterator_{forest()};
// MstSolver<> instance used by the test. Reused across all MST invocations
// to exercise reuse.
Solver solver_;
// Pseudo-random number generator.
std::mt19937 prng_;
};
INSTANTIATE_TEST_CASE_P(AllowForest, MstSolverRandomComparisonTest,
::testing::Combine(::testing::Bool(),
::testing::Range<uint32>(1, 9)));
TEST_P(MstSolverRandomComparisonTest, Comparison) { RunComparison(); }
} // namespace
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/mst/mst_solver.h"
#include <limits>
#include <utility>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "syntaxnet/base.h"
#include <gmock/gmock.h>
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace {
using ::testing::HasSubstr;
// Testing rig.
//
// Template args:
// Solver: An instantiation of the MstSolver<> template.
template<class Solver>
class MstSolverTest : public ::testing::Test {
protected:
using Index = typename Solver::IndexType;
using Score = typename Solver::ScoreType;
// Adds directed arcs for all |num_nodes| nodes to the |solver_| with the
// |score|.
void AddAllArcs(Index num_nodes, Score score) {
for (Index source = 0; source < num_nodes; ++source) {
for (Index target = 0; target < num_nodes; ++target) {
if (source == target) continue;
solver_.AddArc(source, target, score);
}
}
}
// Adds root selections for all |num_nodes| nodes to the |solver_| with the
// |score|.
void AddAllRoots(Index num_nodes, Score score) {
for (Index root = 0; root < num_nodes; ++root) {
solver_.AddRoot(root, score);
}
}
// Runs the |solver_| using an argmax array of size |argmax_array_size| and
// expects it to fail with an error message that matches |error_substr|.
void SolveAndExpectError(int argmax_array_size,
const string &error_message_substr) {
std::vector<Index> argmax(argmax_array_size);
EXPECT_THAT(solver_.Solve(&argmax),
test::IsErrorWithSubstr(error_message_substr));
}
// As above, but expects success. Does not assert anything about the solution
// produced by the solver.
void SolveAndExpectOk(int argmax_array_size) {
std::vector<Index> argmax(argmax_array_size);
TF_EXPECT_OK(solver_.Solve(&argmax));
}
// As above, but expects the solution to be |expected_argmax| and infers the
// argmax array size.
void SolveAndExpectArgmax(const std::vector<Index> &expected_argmax) {
std::vector<Index> actual_argmax(expected_argmax.size());
TF_ASSERT_OK(solver_.Solve(&actual_argmax));
EXPECT_EQ(expected_argmax, actual_argmax);
}
// MstSolver<> instance used by the test. Reused across all MST problems in
// each test to exercise reuse.
Solver solver_;
};
using Solvers =
::testing::Types<MstSolver<uint8, int16>, MstSolver<uint16, int32>,
MstSolver<uint32, int64>, MstSolver<uint16, float>,
MstSolver<uint32, double>>;
TYPED_TEST_CASE(MstSolverTest, Solvers);
TYPED_TEST(MstSolverTest, FailIfNoNodes) {
for (const bool forest : {false, true}) {
EXPECT_THAT(this->solver_.Init(forest, 0),
test::IsErrorWithSubstr("Non-positive number of nodes"));
}
}
TYPED_TEST(MstSolverTest, FailIfTooManyNodes) {
// Set to a value that would overflow when doubled.
const auto kNumNodes =
(std::numeric_limits<typename TypeParam::IndexType>::max() / 2) + 10;
for (const bool forest : {false, true}) {
EXPECT_THAT(this->solver_.Init(forest, kNumNodes),
test::IsErrorWithSubstr("Too many nodes"));
}
}
TYPED_TEST(MstSolverTest, InfeasibleIfNoRootsNoArcs) {
const int kNumNodes = 10;
for (const bool forest : {false, true}) {
TF_ASSERT_OK(this->solver_.Init(forest, kNumNodes));
this->SolveAndExpectError(kNumNodes, "Infeasible digraph");
}
}
TYPED_TEST(MstSolverTest, InfeasibleIfNoRootsAllArcs) {
const int kNumNodes = 10;
for (const bool forest : {false, true}) {
TF_ASSERT_OK(this->solver_.Init(forest, kNumNodes));
this->AddAllArcs(kNumNodes, 0);
this->SolveAndExpectError(kNumNodes, "Infeasible digraph");
}
}
TYPED_TEST(MstSolverTest, FeasibleForForestOnlyIfAllRootsNoArcs) {
const int kNumNodes = 10;
for (const bool forest : {false, true}) {
TF_ASSERT_OK(this->solver_.Init(forest, kNumNodes));
this->AddAllRoots(kNumNodes, 0);
if (forest) {
this->SolveAndExpectOk(kNumNodes); // all roots is a valid forest
} else {
this->SolveAndExpectError(kNumNodes, "Infeasible digraph");
}
}
}
TYPED_TEST(MstSolverTest, FeasibleIfAllRootsAllArcs) {
const int kNumNodes = 10;
for (const bool forest : {false, true}) {
TF_ASSERT_OK(this->solver_.Init(forest, kNumNodes));
this->AddAllRoots(kNumNodes, 0);
this->AddAllArcs(kNumNodes, 0);
this->SolveAndExpectOk(kNumNodes);
}
}
TYPED_TEST(MstSolverTest, FailIfArgmaxArrayTooSmall) {
const int kNumNodes = 10;
for (const bool forest : {false, true}) {
TF_ASSERT_OK(this->solver_.Init(forest, kNumNodes));
this->AddAllRoots(kNumNodes, 0);
this->AddAllArcs(kNumNodes, 0);
this->SolveAndExpectError(kNumNodes - 1, // too small
"Argmax array too small");
}
}
TYPED_TEST(MstSolverTest, OkIfArgmaxArrayTooLarge) {
const int kNumNodes = 10;
for (const bool forest : {false, true}) {
TF_ASSERT_OK(this->solver_.Init(forest, kNumNodes));
this->AddAllRoots(kNumNodes, 0);
this->AddAllArcs(kNumNodes, 0);
this->SolveAndExpectOk(kNumNodes + 1); // too large
}
}
TYPED_TEST(MstSolverTest, SolveForAllRootsForestOnly) {
const int kNumNodes = 10;
const bool forest = true;
TF_ASSERT_OK(this->solver_.Init(forest, kNumNodes));
this->AddAllRoots(kNumNodes, 1); // favor all root selections
this->AddAllArcs(kNumNodes, 0);
this->SolveAndExpectArgmax({0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
}
TYPED_TEST(MstSolverTest, SolveForLeftToRightChain) {
const int kNumNodes = 10;
for (const bool forest : {false, true}) {
TF_ASSERT_OK(this->solver_.Init(forest, kNumNodes));
this->AddAllRoots(kNumNodes, 0);
this->AddAllArcs(kNumNodes, 0);
for (int target = 1; target < kNumNodes; ++target) {
this->solver_.AddArc(target - 1, target, 1); // favor left-to-right chain
}
this->SolveAndExpectArgmax({0, 0, 1, 2, 3, 4, 5, 6, 7, 8});
}
}
TYPED_TEST(MstSolverTest, SolveForRightToLeftChain) {
const int kNumNodes = 10;
for (const bool forest : {false, true}) {
TF_ASSERT_OK(this->solver_.Init(forest, kNumNodes));
this->AddAllRoots(kNumNodes, 0);
this->AddAllArcs(kNumNodes, 0);
for (int source = 1; source < kNumNodes; ++source) {
this->solver_.AddArc(source, source - 1, 1); // favor right-to-left chain
}
this->SolveAndExpectArgmax({1, 2, 3, 4, 5, 6, 7, 8, 9, 9});
}
}
TYPED_TEST(MstSolverTest, SolveForAllFromFirstTree) {
const int kNumNodes = 10;
for (const bool forest : {false, true}) {
TF_ASSERT_OK(this->solver_.Init(forest, kNumNodes));
this->AddAllRoots(kNumNodes, 0);
this->AddAllArcs(kNumNodes, 0);
for (int target = 1; target < kNumNodes; ++target) {
this->solver_.AddArc(0, target, 1); // favor first -> target
}
this->SolveAndExpectArgmax({0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
}
}
TYPED_TEST(MstSolverTest, SolveForAllFromLastTree) {
const int kNumNodes = 10;
for (const bool forest : {false, true}) {
TF_ASSERT_OK(this->solver_.Init(forest, kNumNodes));
this->AddAllRoots(kNumNodes, 0);
this->AddAllArcs(kNumNodes, 0);
for (int target = 0; target + 1 < kNumNodes; ++target) {
this->solver_.AddArc(9, target, 1); // favor last -> target
}
this->SolveAndExpectArgmax({9, 9, 9, 9, 9, 9, 9, 9, 9, 9});
}
}
TYPED_TEST(MstSolverTest, SolveForBinaryTree) {
const int kNumNodes = 15;
for (const bool forest : {false, true}) {
TF_ASSERT_OK(this->solver_.Init(forest, kNumNodes));
this->AddAllRoots(kNumNodes, 0);
this->AddAllArcs(kNumNodes, 0);
for (int target = 1; target < kNumNodes; ++target) {
this->solver_.AddArc((target - 1) / 2, target, 1); // like a binary heap
}
this->SolveAndExpectArgmax({0,
0, 0,
1, 1, 2, 2,
3, 3, 4, 4, 5, 5, 6, 6});
}
}
} // namespace
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <cmath>
#include <limits>
#include <type_traits>
#include <vector>
#include "dragnn/mst/mst_solver.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
namespace syntaxnet {
namespace dragnn {
// Op kernel implementation that wraps the |MstSolver|.
template <class Index, class Score>
class MaximumSpanningTreeOpKernel : public tensorflow::OpKernel {
public:
explicit MaximumSpanningTreeOpKernel(
tensorflow::OpKernelConstruction *context)
: tensorflow::OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("forest", &forest_));
}
void Compute(tensorflow::OpKernelContext *context) override {
const tensorflow::Tensor &num_nodes_tensor = context->input(0);
const tensorflow::Tensor &scores_tensor = context->input(1);
// Check ranks.
OP_REQUIRES(context, num_nodes_tensor.dims() == 1,
tensorflow::errors::InvalidArgument(
"num_nodes must be a vector, got shape ",
num_nodes_tensor.shape().DebugString()));
OP_REQUIRES(context, scores_tensor.dims() == 3,
tensorflow::errors::InvalidArgument(
"scores must be rank 3, got shape ",
scores_tensor.shape().DebugString()));
// Batch size and input dimension (B and M in the op docstring).
const int64 batch_size = scores_tensor.shape().dim_size(0);
const int64 input_dim = scores_tensor.shape().dim_size(1);
// Check shapes.
const tensorflow::TensorShape shape_b({batch_size});
const tensorflow::TensorShape shape_bxm({batch_size, input_dim});
const tensorflow::TensorShape shape_bxmxm(
{batch_size, input_dim, input_dim});
OP_REQUIRES(
context, num_nodes_tensor.shape() == shape_b,
tensorflow::errors::InvalidArgument(
"num_nodes misshapen: got ", num_nodes_tensor.shape().DebugString(),
" but expected ", shape_b.DebugString()));
OP_REQUIRES(
context, scores_tensor.shape() == shape_bxmxm,
tensorflow::errors::InvalidArgument(
"scores misshapen: got ", scores_tensor.shape().DebugString(),
" but expected ", shape_bxmxm.DebugString()));
// Create outputs.
tensorflow::Tensor *max_scores_tensor = nullptr;
tensorflow::Tensor *argmax_sources_tensor = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, shape_b, &max_scores_tensor));
OP_REQUIRES_OK(context, context->allocate_output(1, shape_bxm,
&argmax_sources_tensor));
// Acquire shaped and typed references.
const BatchedSizes num_nodes_b = num_nodes_tensor.vec<int32>();
const BatchedScores scores_bxmxm = scores_tensor.tensor<Score, 3>();
BatchedMaxima max_scores_b = max_scores_tensor->vec<Score>();
BatchedSources argmax_sources_bxm = argmax_sources_tensor->matrix<int32>();
// Solve the batch of MST problems in parallel. Set a high cycles per unit
// to encourage finer sharding.
constexpr int64 kCyclesPerUnit = 1000 * 1000 * 1000;
std::vector<tensorflow::Status> statuses(batch_size);
context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor(
batch_size, kCyclesPerUnit, [&](int64 begin, int64 end) {
for (int64 problem = begin; problem < end; ++problem) {
statuses[problem] = RunSolver(problem, num_nodes_b, scores_bxmxm,
max_scores_b, argmax_sources_bxm);
}
});
for (const tensorflow::Status &status : statuses) {
OP_REQUIRES_OK(context, status);
}
}
private:
using BatchedSizes = typename tensorflow::TTypes<int32>::ConstVec;
using BatchedScores = typename tensorflow::TTypes<Score, 3>::ConstTensor;
using BatchedMaxima = typename tensorflow::TTypes<Score>::Vec;
using BatchedSources = typename tensorflow::TTypes<int32>::Matrix;
// Solves for the maximum spanning tree of the digraph defined by the values
// at index |problem| in |num_nodes_b| and |scores_bxmxm|. On success, sets
// the values at index |problem| in |max_scores_b| and |argmax_sources_bxm|.
// On error, returns non-OK.
tensorflow::Status RunSolver(int problem, BatchedSizes num_nodes_b,
BatchedScores scores_bxmxm,
BatchedMaxima max_scores_b,
BatchedSources argmax_sources_bxm) const {
// Check digraph size overflow.
const int32 num_nodes = num_nodes_b(problem);
const int32 input_dim = argmax_sources_bxm.dimension(1);
if (num_nodes > input_dim) {
return tensorflow::errors::InvalidArgument(
"number of nodes in digraph ", problem,
" overflows input dimension: got ", num_nodes,
" but expected <= ", input_dim);
}
if (num_nodes >= std::numeric_limits<Index>::max()) {
return tensorflow::errors::InvalidArgument(
"number of nodes in digraph ", problem, " overflows index type: got ",
num_nodes, " but expected < ", std::numeric_limits<Index>::max());
}
const Index num_nodes_index = static_cast<Index>(num_nodes);
MstSolver<Index, Score> solver;
TF_RETURN_IF_ERROR(solver.Init(forest_, num_nodes_index));
// Populate the solver with arcs and root selections. Note that non-finite
// scores are treated as nonexistent arcs or roots.
for (Index target = 0; target < num_nodes_index; ++target) {
for (Index source = 0; source < num_nodes_index; ++source) {
const Score score = scores_bxmxm(problem, target, source);
if (!std::isfinite(score)) continue;
if (source == target) { // root
solver.AddRoot(target, score);
} else { // arc
solver.AddArc(source, target, score);
}
}
}
std::vector<Index> argmax(num_nodes);
TF_RETURN_IF_ERROR(solver.Solve(&argmax));
// Output the tree and accumulate its score.
Score max_score = 0;
for (Index target = 0; target < num_nodes_index; ++target) {
const Index source = argmax[target];
argmax_sources_bxm(problem, target) = source;
max_score += scores_bxmxm(problem, target, source);
}
max_scores_b(problem) = max_score;
// Pad the source list with -1.
for (int32 i = num_nodes; i < input_dim; ++i) {
argmax_sources_bxm(problem, i) = -1;
}
return tensorflow::Status::OK();
}
private:
bool forest_ = false;
};
// Use Index=uint16, which allows digraphs containing up to 32,767 nodes.
REGISTER_KERNEL_BUILDER(Name("MaximumSpanningTree")
.Device(tensorflow::DEVICE_CPU)
.TypeConstraint<int32>("T"),
MaximumSpanningTreeOpKernel<uint16, int32>);
REGISTER_KERNEL_BUILDER(Name("MaximumSpanningTree")
.Device(tensorflow::DEVICE_CPU)
.TypeConstraint<float>("T"),
MaximumSpanningTreeOpKernel<uint16, float>);
REGISTER_KERNEL_BUILDER(Name("MaximumSpanningTree")
.Device(tensorflow::DEVICE_CPU)
.TypeConstraint<double>("T"),
MaximumSpanningTreeOpKernel<uint16, double>);
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace syntaxnet {
namespace dragnn {
REGISTER_OP("MaximumSpanningTree")
.Attr("T: {int32, float, double}")
.Attr("forest: bool = false")
.Input("num_nodes: int32")
.Input("scores: T")
.Output("max_scores: T")
.Output("argmax_sources: int32")
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
tensorflow::shape_inference::ShapeHandle num_nodes;
tensorflow::shape_inference::ShapeHandle scores;
TF_RETURN_IF_ERROR(context->WithRank(context->input(0), 1, &num_nodes));
TF_RETURN_IF_ERROR(context->WithRank(context->input(1), 3, &scores));
// Extract dimensions while asserting that they match.
tensorflow::shape_inference::DimensionHandle batch_size; // aka "B"
TF_RETURN_IF_ERROR(context->Merge(context->Dim(num_nodes, 0),
context->Dim(scores, 0), &batch_size));
tensorflow::shape_inference::DimensionHandle max_nodes; // aka "M"
TF_RETURN_IF_ERROR(context->Merge(context->Dim(scores, 1),
context->Dim(scores, 2), &max_nodes));
context->set_output(0, context->Vector(batch_size));
context->set_output(1, context->Matrix(batch_size, max_nodes));
return tensorflow::Status::OK();
})
.Doc(R"doc(
Finds the maximum directed spanning tree of a digraph.
Given a batch of digraphs with scored arcs and root selections, solves for the
maximum spanning tree of each digraph, where the score of a tree is defined as
the sum of the scores of the arcs and roots making up the tree.
Returns the score of the maximum spanning tree of each digraph, as well as the
arcs and roots in that tree. Each digraph in a batch may contain a different
number of nodes, so the sizes of the digraphs must be provided as an input.
Note that this operation is only differentiable w.r.t. its |scores| input and
its |max_scores| output.
forest: If true, solves for a maximum spanning forest instead of a maximum
spanning tree, where a spanning forest is a set of disjoint trees that
span the nodes of the digraph.
num_nodes: [B] vector where entry b is number of nodes in the b'th digraph.
scores: [B,M,M] tensor where entry b,t,s is the score of the arc from s to t in
the b'th digraph, if s!=t, or the score of selecting t as a root in the
b'th digraph, if s==t. Requires that M is >= num_nodes[b], for all b,
and ignores entries b,s,t where s or t is >= num_nodes[b]. Arcs or root
selections with non-finite score are treated as nonexistent.
max_scores: [B] vector where entry b is the score of the maximum spanning tree
of the b'th digraph.
argmax_sources: [B,M] matrix where entry b,t is the source of the arc inbound to
t in the maximum spanning tree of the b'th digraph, or t if t is
a root. Entries b,t where t is >= num_nodes[b] are set to -1.
)doc");
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/mst/spanning_tree_iterator.h"
namespace syntaxnet {
namespace dragnn {
SpanningTreeIterator::SpanningTreeIterator(bool forest) : forest_(forest) {}
bool SpanningTreeIterator::HasCycle(const SourceList &sources) {
// Flags for whether each node has already been searched.
searched_.assign(sources.size(), false);
// Flags for whether the search is currently visiting each node.
visiting_.assign(sources.size(), false);
// Search upwards from each node to find cycles.
for (uint32 initial_node = 0; initial_node < sources.size(); ++initial_node) {
// Search upwards to try to find a cycle.
uint32 current_node = initial_node;
while (true) {
if (searched_[current_node]) break; // already searched
if (visiting_[current_node]) return true; // revisiting implies cycle
visiting_[current_node] = true; // mark as being currently visited
const uint32 source_node = sources[current_node];
if (source_node == current_node) break; // self-loops are roots
current_node = source_node; // advance upwards
}
// No cycle; search upwards again to update flags.
current_node = initial_node;
while (true) {
if (searched_[current_node]) break; // already searched
searched_[current_node] = true;
visiting_[current_node] = false;
const uint32 source_node = sources[current_node];
if (source_node == current_node) break; // self-loops are roots
current_node = source_node; // advance upwards
}
}
return false;
}
uint32 SpanningTreeIterator::NumRoots(const SourceList &sources) {
uint32 num_roots = 0;
for (uint32 node = 0; node < sources.size(); ++node) {
num_roots += (node == sources[node]);
}
return num_roots;
}
bool SpanningTreeIterator::NextSourceList(SourceList *sources) {
const uint32 num_nodes = sources->size();
for (uint32 i = 0; i < num_nodes; ++i) {
const uint32 new_source = ++(*sources)[i];
if (new_source < num_nodes) return true; // absorbed in this digit
(*sources)[i] = 0; // overflowed this digit, carry to next digit
}
return false; // overflowed the last digit
}
bool SpanningTreeIterator::NextTree(SourceList *sources) {
// Iterate source lists, skipping non-trees.
while (NextSourceList(sources)) {
// Check the number of roots.
const uint32 num_roots = NumRoots(*sources);
if (forest_) {
if (num_roots == 0) continue;
} else {
if (num_roots != 1) continue;
}
// Check for cycles.
if (HasCycle(*sources)) continue;
// Acyclic and rooted, therefore tree.
return true;
}
return false;
}
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_MST_SPANNING_TREE_ITERATOR_H_
#define DRAGNN_MST_SPANNING_TREE_ITERATOR_H_
#include <vector>
#include "syntaxnet/base.h"
namespace syntaxnet {
namespace dragnn {
// A class that iterates over all possible spanning trees of a complete digraph.
// Thread-compatible. Useful for brute-force comparison tests.
//
// TODO(googleuser): Try using Prufer sequences, which are more efficient to
// enumerate as there are no non-trees to filter out.
class SpanningTreeIterator {
public:
// An array that provides the source of the inbound arc for each node. Roots
// are represented as self-loops.
using SourceList = std::vector<uint32>;
// Creates a spanning tree iterator. If |forest| is true, then this iterates
// over forests instead of trees (i.e., multiple roots are allowed).
explicit SpanningTreeIterator(bool forest);
// Applies the |functor| to all spanning trees (or forests, if |forest_| is
// true) of a complete digraph containing |num_nodes| nodes. Each tree is
// passed to the |functor| as a SourceList.
template <class Functor>
void ForEachTree(uint32 num_nodes, Functor functor) {
// Conveniently, the all-zero vector represents a valid tree.
SourceList sources(num_nodes, 0);
do {
functor(sources);
} while (NextTree(&sources));
}
private:
// Returns true if the |sources| contains a cycle.
bool HasCycle(const SourceList &sources);
// Returns the number of roots in the |sources|.
static uint32 NumRoots(const SourceList &sources);
// Advances |sources| to the next source list, or returns false if there are
// no more source lists.
static bool NextSourceList(SourceList *sources);
// Advances |sources| to the next tree (or forest, if |forest_| is true), or
// returns false if there are no more trees.
bool NextTree(SourceList *sources);
// If true, iterate over spanning forests instead of spanning trees.
const bool forest_;
// Workspaces used by the search in HasCycle().
std::vector<bool> searched_;
std::vector<bool> visiting_;
};
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_MST_SPANNING_TREE_ITERATOR_H_
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/mst/spanning_tree_iterator.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace {
// Testing rig. When the bool parameter is true, iterates over spanning forests
// instead of spanning trees.
class SpanningTreeIteratorTest : public ::testing::TestWithParam<bool> {
protected:
using SourceList = SpanningTreeIterator::SourceList;
// Returns |base|^|exponent|. Computes the value as an integer to avoid
// rounding issues.
static int Pow(int base, int exponent) {
double real_product = 1.0;
int product = 1;
for (int i = 0; i < exponent; ++i) {
product *= base;
real_product *= base;
}
CHECK_EQ(product, real_product) << "Overflow detected.";
return product;
}
// Expects that the number of possible spanning trees for a complete digraph
// of |num_nodes| nodes is |expected_num_trees|.
void ExpectNumTrees(int num_nodes, int expected_num_trees) {
int actual_num_trees = 0;
iterator_.ForEachTree(
num_nodes, [&](const SourceList &sources) { ++actual_num_trees; });
LOG(INFO) << "num_nodes=" << num_nodes
<< " expected_num_trees=" << expected_num_trees
<< " actual_num_trees=" << actual_num_trees;
EXPECT_EQ(expected_num_trees, actual_num_trees);
}
// Expects that the set of possible spanning trees for a complete digraph of
// |num_nodes| nodes is |expected_trees|.
void ExpectTrees(int num_nodes, const std::set<SourceList> &expected_trees) {
std::set<SourceList> actual_trees;
iterator_.ForEachTree(num_nodes, [&](const SourceList &sources) {
CHECK(actual_trees.insert(sources).second);
});
EXPECT_EQ(expected_trees, actual_trees);
}
// Instance for tests. Shared across assertions in a test to exercise reuse.
SpanningTreeIterator iterator_{GetParam()};
};
INSTANTIATE_TEST_CASE_P(AllowForest, SpanningTreeIteratorTest,
::testing::Bool());
TEST_P(SpanningTreeIteratorTest, NumberOfTrees) {
// According to Cayley's formula, the number of undirected spanning trees on a
// complete graph of n nodes is n^{n-2}:
// https://en.wikipedia.org/wiki/Cayley%27s_formula
//
// To count the number of directed spanning trees, note that each undirected
// spanning tree gives rise to n directed spanning trees: choose one of the n
// nodes as the root, and then orient arcs outwards. Therefore, the number of
// directed spanning trees on a complete digraph of n nodes is n^{n-1}.
//
// To count the number of directed spanning forests, consider undirected
// spanning trees on a complete graph of n+1 nodes. Arbitrarily select one
// node as the artificial root, orient arcs outwards, and then delete the
// artificial root and its outbound arcs. The result is a directed spanning
// forest on n nodes. Therefore, the number of directed spanning forests on a
// complete digraph of n nodes is (n+1)^{n-1}.
for (int num_nodes = 1; num_nodes <= 7; ++num_nodes) {
if (GetParam()) { // forest
ExpectNumTrees(num_nodes, Pow(num_nodes + 1, num_nodes - 1));
} else { // tree
ExpectNumTrees(num_nodes, Pow(num_nodes, num_nodes - 1));
}
}
}
TEST_P(SpanningTreeIteratorTest, OneNodeDigraph) {
ExpectTrees(1, {{0}});
}
TEST_P(SpanningTreeIteratorTest, TwoNodeDigraph) {
if (GetParam()) { // forest
ExpectTrees(2, {{0, 0}, {0, 1}, {1, 1}}); // {0, 1} is two-root structure
} else { // tree
ExpectTrees(2, {{0, 0}, {1, 1}});
}
}
TEST_P(SpanningTreeIteratorTest, ThreeNodeDigraph) {
if (GetParam()) { // forest
ExpectTrees(3, {{0, 0, 0},
{0, 0, 1},
{0, 0, 2}, // 2-root
{0, 1, 0}, // 2-root
{0, 1, 1}, // 2-root
{0, 1, 2}, // 3-root
{0, 2, 0},
{0, 2, 2}, // 2-root
{1, 1, 0},
{1, 1, 1},
{1, 1, 2}, // 2-root
{1, 2, 2},
{2, 0, 2},
{2, 1, 1},
{2, 1, 2}, // 2-root
{2, 2, 2}});
} else { // tree
ExpectTrees(3, {{0, 0, 0},
{0, 0, 1},
{0, 2, 0},
{1, 1, 0},
{1, 1, 1},
{1, 2, 2},
{2, 0, 2},
{2, 1, 1},
{2, 2, 2}});
}
}
} // namespace
} // namespace dragnn
} // namespace syntaxnet
......@@ -2,48 +2,63 @@ package(default_visibility = ["//visibility:public"])
load(
"//syntaxnet:syntaxnet.bzl",
"tf_proto_library",
"tf_proto_library_cc",
"tf_proto_library_py",
)
# Protos.
tf_proto_library(
tf_proto_library_cc(
name = "data_proto",
srcs = ["data.proto"],
)
tf_proto_library(
tf_proto_library_cc(
name = "trace_proto",
srcs = ["trace.proto"],
deps = [
":data_proto",
],
protodeps = [":data_proto"],
)
tf_proto_library(
tf_proto_library_cc(
name = "cell_trace_proto",
srcs = ["cell_trace.proto"],
protodeps = [":trace_proto"],
)
tf_proto_library_cc(
name = "spec_proto",
srcs = ["spec.proto"],
)
tf_proto_library(
tf_proto_library_cc(
name = "runtime_proto",
srcs = ["runtime.proto"],
deps = [":spec_proto"],
protodeps = [":spec_proto"],
)
tf_proto_library_cc(
name = "export_proto",
srcs = ["export.proto"],
protodeps = [":spec_proto"],
)
tf_proto_library_py(
name = "data_py_pb2",
name = "data_pb2",
srcs = ["data.proto"],
)
tf_proto_library_py(
name = "trace_py_pb2",
name = "trace_pb2",
srcs = ["trace.proto"],
deps = [":data_py_pb2"],
protodeps = [":data_pb2"],
)
tf_proto_library_py(
name = "spec_py_pb2",
name = "spec_pb2",
srcs = ["spec.proto"],
)
tf_proto_library_py(
name = "export_pb2",
srcs = ["export.proto"],
)
syntax = "proto2";
import "dragnn/protos/trace.proto";
package syntaxnet.dragnn.runtime;
// Trace of a network cell computation (e.g., an LSTM cell).
// NEXT ID: 4
message CellTrace {
extend ComponentStepTrace {
// Cell computations that occurred in the step. It's possible that there is
// more than one cell per component (e.g., a bi-LSTM component).
repeated CellTrace step_trace_extension = 169167178;
}
// Name of the cell.
optional string name = 1;
// Tensors making up the cell. Note that this only includes local variables
// (e.g., activation vectors), not global constants (e.g., weight matrices).
repeated CellTensorTrace tensor = 2;
// Operations making up the cell. Note that the operation inputs may refer to
// global constants that are not present in |tensor|.
repeated CellOperationTrace operation = 3;
}
// Trace of a tensor in a cell computation.
// NEXT ID: 7
message CellTensorTrace {
// Possible orderings of the dimensions.
enum Order {
ORDER_UNKNOWN = 0; // unspecified or unknown
ORDER_ROW_MAJOR = 1; // row-major: dimension 0 has largest stride
ORDER_COLUMN_MAJOR = 2; // column-major: dimension 0 has smallest stride
}
// Name of the tensor (e.g., "annotation/inference_rnn/split:1").
optional string name = 1;
// Data type of the tensor (e.g., "DT_FLOAT").
optional string type = 2;
// Dimensions of the tensor (e.g., [1, 65]).
repeated int32 dimension = 3;
// Alignment-padded dimensions of the tensor (e.g., [1, 96]).
repeated int32 aligned_dimension = 4;
// Ordering of the tensor values.
optional Order order = 5 [default = ORDER_UNKNOWN];
// Block of alignment-padded values. For simplicity, values of all types are
// converted to double (via C++ conversion rules). Use |aligned_dimension| to
// traverse the values, but note that |dimension| bounds the valid region.
repeated double value = 6;
}
// Trace of an operation in a cell computation.
// NEXT ID: 6
message CellOperationTrace {
// Name of the operation (e.g., "annotation/inference_rnn/MatMul").
optional string name = 1;
// High-level type of the operation (e.g., "MatMul").
optional string type = 2;
// Kernel that implements the operation, if applicable (e.g., "AvxFltMatMul").
optional string kernel = 3;
// Names of input tensors of the operation, in order.
repeated string input = 4;
// Names of output tensors of the operation, in order.
repeated string output = 5;
}
// DRAGNN data proto. See go/dragnn-design for more information.
// DRAGNN data proto.
syntax = "proto2";
......
syntax = "proto2";
import "dragnn/protos/spec.proto";
package syntaxnet.dragnn.runtime;
// Specification of a subgraph of TF nodes that make up a network cell.
//
// Roughly speaking, a "cell" consists of the "pure math" parts of a DRAGNN
// component, and is intended to be exported to a NN compiler. The set of
// operations that make up a cell may change over time, but currently the
// boundaries of a cell are:
//
// Inputs:
// * Fixed feature IDs.
// * Linked feature embeddings, before pass_through_embedding_matrix().
// * Recurrent context tensors.
//
// Outputs:
// * Network unit layers.
message CellSubgraphSpec {
// An input to the subgraph.
message Input {
// Possible types of input.
enum Type {
TYPE_UNKNOWN = 0;
// An input derived from a fixed or linked feature.
TYPE_FEATURE = 1;
// An input that refers to an output of the previous iteration of the
// transition loop. The input must have the same name as the output to
// which it refers. On the first iteration, its value is zero.
//
// This is used by, e.g., LSTMNetwork, which reads its cell state from the
// context_tensor_arrays instead of from a linked feature.
TYPE_RECURRENT = 2;
}
// Logical name of the input (e.g., "lstm_c", "linked_feature_0"). Must be
// unique among the inputs of the cell.
optional string name = 1;
// Tensor containing the input (e.g., "annotation/rnn/split:1"). Must be
// unique among the inputs of the cell.
optional string tensor = 2;
// Type of input.
optional Type type = 3 [default = TYPE_UNKNOWN];
}
// An output of the subgraph.
message Output {
// Logical name of the output (e.g., "lstm_c", "layer_0"). Must be unique
// among the outputs of the cell.
optional string name = 1;
// Tensor containing the output (e.g., "annotation/rnn/split:1"). Need not
// be unique; duplicate outputs for the same tensor are treated as aliases.
optional string tensor = 2;
}
// Inputs of the subgraph.
repeated Input input = 1;
// Outputs of the subgraph.
repeated Output output = 2;
}
// Additional information to compile a component.
//
// NEXT ID: 3
message CompilationSpec {
extend ComponentSpec {
optional CompilationSpec component_spec_extension = 174770970;
}
// A unique name of the entire DRAGNN model where this component is used.
optional string model_name = 1;
// The subgraph specification for this component.
optional CellSubgraphSpec cell_subgraph_spec = 2;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment