Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
80178fc6
"docs/en/model_zoo.md" did not exist on "7f3a16a3e30267d90aa911e5714686b7112da0a6"
Unverified
Commit
80178fc6
authored
May 11, 2018
by
Mark Omernick
Committed by
GitHub
May 11, 2018
Browse files
Merge pull request #4153 from terryykoo/master
Export @195097388.
parents
a84e1ef9
edea2b67
Changes
145
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2316 additions
and
19 deletions
+2316
-19
research/syntaxnet/dragnn/core/test/mock_compute_session.h
research/syntaxnet/dragnn/core/test/mock_compute_session.h
+3
-2
research/syntaxnet/dragnn/core/util/BUILD
research/syntaxnet/dragnn/core/util/BUILD
+9
-0
research/syntaxnet/dragnn/core/util/label.h
research/syntaxnet/dragnn/core/util/label.h
+45
-0
research/syntaxnet/dragnn/io/BUILD
research/syntaxnet/dragnn/io/BUILD
+3
-3
research/syntaxnet/dragnn/mst/BUILD
research/syntaxnet/dragnn/mst/BUILD
+116
-0
research/syntaxnet/dragnn/mst/README.md
research/syntaxnet/dragnn/mst/README.md
+3
-0
research/syntaxnet/dragnn/mst/disjoint_set_forest.h
research/syntaxnet/dragnn/mst/disjoint_set_forest.h
+183
-0
research/syntaxnet/dragnn/mst/disjoint_set_forest_test.cc
research/syntaxnet/dragnn/mst/disjoint_set_forest_test.cc
+150
-0
research/syntaxnet/dragnn/mst/mst_solver.h
research/syntaxnet/dragnn/mst/mst_solver.h
+587
-0
research/syntaxnet/dragnn/mst/mst_solver_random_comparison_test.cc
...syntaxnet/dragnn/mst/mst_solver_random_comparison_test.cc
+183
-0
research/syntaxnet/dragnn/mst/mst_solver_test.cc
research/syntaxnet/dragnn/mst/mst_solver_test.cc
+255
-0
research/syntaxnet/dragnn/mst/ops/mst_op_kernels.cc
research/syntaxnet/dragnn/mst/ops/mst_op_kernels.cc
+193
-0
research/syntaxnet/dragnn/mst/ops/mst_ops.cc
research/syntaxnet/dragnn/mst/ops/mst_ops.cc
+78
-0
research/syntaxnet/dragnn/mst/spanning_tree_iterator.cc
research/syntaxnet/dragnn/mst/spanning_tree_iterator.cc
+97
-0
research/syntaxnet/dragnn/mst/spanning_tree_iterator.h
research/syntaxnet/dragnn/mst/spanning_tree_iterator.h
+79
-0
research/syntaxnet/dragnn/mst/spanning_tree_iterator_test.cc
research/syntaxnet/dragnn/mst/spanning_tree_iterator_test.cc
+143
-0
research/syntaxnet/dragnn/protos/BUILD
research/syntaxnet/dragnn/protos/BUILD
+28
-13
research/syntaxnet/dragnn/protos/cell_trace.proto
research/syntaxnet/dragnn/protos/cell_trace.proto
+76
-0
research/syntaxnet/dragnn/protos/data.proto
research/syntaxnet/dragnn/protos/data.proto
+2
-1
research/syntaxnet/dragnn/protos/export.proto
research/syntaxnet/dragnn/protos/export.proto
+83
-0
No files found.
research/syntaxnet/dragnn/core/test/mock_compute_session.h
View file @
80178fc6
...
...
@@ -62,13 +62,14 @@ class MockComputeSession : public ComputeSession {
MOCK_METHOD2
(
GetTranslatedLinkFeatures
,
std
::
vector
<
LinkFeatures
>
(
const
string
&
component_name
,
int
channel_id
));
MOCK_METHOD1
(
EmitOracleLabels
,
std
::
vector
<
std
::
vector
<
int
>>
(
const
string
&
component_name
));
MOCK_METHOD1
(
EmitOracleLabels
,
std
::
vector
<
std
::
vector
<
std
::
vector
<
Label
>>>
(
const
string
&
component_name
));
MOCK_METHOD1
(
IsTerminal
,
bool
(
const
string
&
component_name
));
MOCK_METHOD1
(
FinalizeData
,
void
(
const
string
&
component_name
));
MOCK_METHOD0
(
GetSerializedPredictions
,
std
::
vector
<
string
>
());
MOCK_METHOD0
(
GetTraceProtos
,
std
::
vector
<
MasterTrace
>
());
MOCK_METHOD1
(
SetInputData
,
void
(
const
std
::
vector
<
string
>
&
data
));
MOCK_METHOD0
(
GetInputBatchCache
,
InputBatchCache
*
());
MOCK_METHOD0
(
ResetSession
,
void
());
MOCK_METHOD1
(
SetTracing
,
void
(
bool
tracing_on
));
MOCK_CONST_METHOD0
(
Id
,
int
());
...
...
research/syntaxnet/dragnn/core/util/BUILD
0 → 100644
View file @
80178fc6
package
(
default_visibility
=
[
"//visibility:public"
],
features
=
[
"-layering_check"
],
)
cc_library
(
name
=
"label"
,
hdrs
=
[
"label.h"
],
)
research/syntaxnet/dragnn/core/util/label.h
0 → 100644
View file @
80178fc6
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_CORE_UTIL_LABEL_H_
#define DRAGNN_CORE_UTIL_LABEL_H_
#include <cmath>
namespace
syntaxnet
{
namespace
dragnn
{
// Stores label information.
struct
Label
{
Label
(
int
label_id
,
float
label_probability
)
:
id
(
label_id
),
probability
(
label_probability
)
{}
explicit
Label
(
int
label_id
)
:
id
(
label_id
)
{}
// Two Labels are equal if the ids match and the probabilities are within an
// epsilon of one another.
bool
operator
==
(
const
Label
&
label
)
const
{
return
(
id
==
label
.
id
)
&&
std
::
fabs
(
probability
-
label
.
probability
)
<
0.00001
;
}
// Label id and probability.
int
id
;
float
probability
=
1.0
;
};
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_CORE_UTIL_LABEL_H_
research/syntaxnet/dragnn/io/BUILD
View file @
80178fc6
...
...
@@ -8,7 +8,7 @@ cc_library(
":syntaxnet_sentence"
,
"//dragnn/core/interfaces:input_batch"
,
"//syntaxnet:base"
,
"//syntaxnet:sentence_proto"
,
"//syntaxnet:sentence_proto
_cc
"
,
],
)
...
...
@@ -16,7 +16,7 @@ cc_library(
name
=
"syntaxnet_sentence"
,
hdrs
=
[
"syntaxnet_sentence.h"
],
deps
=
[
"//syntaxnet:sentence_proto"
,
"//syntaxnet:sentence_proto
_cc
"
,
"//syntaxnet:workspace"
,
],
)
...
...
@@ -27,7 +27,7 @@ cc_test(
deps
=
[
":sentence_input_batch"
,
"//dragnn/core/test:generic"
,
"//syntaxnet:sentence_proto"
,
"//syntaxnet:sentence_proto
_cc
"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:test"
,
],
...
...
research/syntaxnet/dragnn/mst/BUILD
0 → 100644
View file @
80178fc6
package
(
default_visibility
=
[
"//visibility:public"
])
cc_library
(
name
=
"disjoint_set_forest"
,
hdrs
=
[
"disjoint_set_forest.h"
],
deps
=
[
"@org_tensorflow//tensorflow/core:lib"
,
],
)
cc_test
(
name
=
"disjoint_set_forest_test"
,
size
=
"small"
,
srcs
=
[
"disjoint_set_forest_test.cc"
],
deps
=
[
":disjoint_set_forest"
,
"//syntaxnet:base"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"spanning_tree_iterator"
,
testonly
=
1
,
srcs
=
[
"spanning_tree_iterator.cc"
],
hdrs
=
[
"spanning_tree_iterator.h"
],
deps
=
[
"//syntaxnet:base"
,
],
)
cc_test
(
name
=
"spanning_tree_iterator_test"
,
size
=
"small"
,
srcs
=
[
"spanning_tree_iterator_test.cc"
],
deps
=
[
":spanning_tree_iterator"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"mst_solver"
,
hdrs
=
[
"mst_solver.h"
],
deps
=
[
":disjoint_set_forest"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
cc_test
(
name
=
"mst_solver_test"
,
size
=
"small"
,
srcs
=
[
"mst_solver_test.cc"
],
deps
=
[
":mst_solver"
,
"//dragnn/core/test:generic"
,
"//syntaxnet:base"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_test
(
name
=
"mst_solver_random_comparison_test"
,
size
=
"small"
,
timeout
=
"long"
,
srcs
=
[
"mst_solver_random_comparison_test.cc"
],
tags
=
[
"manual"
,
# exclude from :all, since this is expensive
],
deps
=
[
":mst_solver"
,
":spanning_tree_iterator"
,
"//syntaxnet:base"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
load
(
"@org_tensorflow//tensorflow:tensorflow.bzl"
,
"tf_gen_op_libs"
,
"tf_gen_op_wrapper_py"
,
)
tf_gen_op_libs
(
op_lib_names
=
[
"mst_ops"
],
)
# Don't use this library directly; instead use "dragnn/python:mst_ops".
tf_gen_op_wrapper_py
(
name
=
"mst_ops"
,
visibility
=
[
"//dragnn/python:__pkg__"
],
deps
=
[
":mst_ops_op_lib"
],
)
cc_library
(
name
=
"mst_ops_cc"
,
srcs
=
[
"ops/mst_op_kernels.cc"
,
"ops/mst_ops.cc"
,
],
deps
=
[
":mst_solver"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:framework_headers_lib"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
alwayslink
=
1
,
)
research/syntaxnet/dragnn/mst/README.md
0 → 100644
View file @
80178fc6
Package for solving max-spanning-tree (MST) problems. The code here is intended
for NLP applications, but attempts to remain agnostic to particular NLP tasks
(such as dependency parsing).
research/syntaxnet/dragnn/mst/disjoint_set_forest.h
0 → 100644
View file @
80178fc6
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_MST_DISJOINT_SET_FOREST_H_
#define DRAGNN_MST_DISJOINT_SET_FOREST_H_
#include <stddef.h>
#include <type_traits>
#include <vector>
#include "tensorflow/core/platform/logging.h"
namespace
syntaxnet
{
namespace
dragnn
{
// An implementation of the disjoint-set forest data structure. The universe of
// elements is the dense range of indices [0,n). Thread-compatible.
//
// By default, this uses the path compression and union by rank optimizations,
// achieving near-constant runtime on all operations. However, the user may
// disable the union by rank optimization, which allows the user to control how
// roots are selected when a union occurs. When union by rank is disabled, the
// runtime of all operations increases to O(log n) amortized.
//
// Template args:
// Index: An unsigned integral type wide enough to hold n.
// kUseUnionByRank: Whether to use the union by rank optimization.
template
<
class
Index
,
bool
kUseUnionByRank
=
true
>
class
DisjointSetForest
{
public:
static_assert
(
std
::
is_integral
<
Index
>::
value
,
"Index must be integral"
);
static_assert
(
!
std
::
is_signed
<
Index
>::
value
,
"Index must be unsigned"
);
using
IndexType
=
Index
;
// Creates an empty forest.
DisjointSetForest
()
=
default
;
// Initializes this to hold the elements [0,|size|), each initially in its own
// singleton set. Replaces existing state, if any.
void
Init
(
Index
size
);
// Returns the root of the set containing |element|, which uniquely identifies
// the set. Note that the root of a set may change as the set is merged with
// other sets; do not cache the return value of FindRoot(e) across calls to
// Union() or UnionOfRoots() that could merge the set containing e.
Index
FindRoot
(
Index
element
);
// For convenience, returns true if |element1| and |element2| are in the same
// set. When performing a large batch of queries it may be more efficient to
// cache the value of FindRoot(), modulo caveats regarding caching above.
bool
SameSet
(
Index
element1
,
Index
element2
);
// Merges the sets rooted at |root1| and |root2|, which must be the roots of
// their respective sets. Either |root1| or |root2| will be the root of the
// merged set. If |kUseUnionByRank| is true, then it is unspecified whether
// |root1| or |root2| will be the root; otherwise, |root2| will be the root.
void
UnionOfRoots
(
Index
root1
,
Index
root2
);
// As above, but for convenience finds the root of |element1| and |element2|.
void
Union
(
Index
element1
,
Index
element2
);
// The number of elements in this.
Index
size
()
const
{
return
size_
;
}
private:
// The number of elements in the universe underlying the sets.
Index
size_
=
0
;
// The parent of each element, where self-loops are roots.
std
::
vector
<
Index
>
parents_
;
// The rank of each element, for the union by rank optimization. Only used if
// |kUseUnionByRank| is true.
std
::
vector
<
Index
>
ranks_
;
};
// Implementation details below.
template
<
class
Index
,
bool
kUseUnionByRank
>
void
DisjointSetForest
<
Index
,
kUseUnionByRank
>::
Init
(
Index
size
)
{
size_
=
size
;
parents_
.
resize
(
size_
);
if
(
kUseUnionByRank
)
ranks_
.
resize
(
size_
);
// Create singleton sets.
for
(
Index
i
=
0
;
i
<
size_
;
++
i
)
{
parents_
[
i
]
=
i
;
if
(
kUseUnionByRank
)
ranks_
[
i
]
=
0
;
}
}
template
<
class
Index
,
bool
kUseUnionByRank
>
Index
DisjointSetForest
<
Index
,
kUseUnionByRank
>::
FindRoot
(
Index
element
)
{
DCHECK_LT
(
element
,
size
());
Index
*
const
__restrict
parents
=
parents_
.
data
();
// Walk up to the root of the |element|. Unroll the first two comparisons
// because path compression ensures most FindRoot() calls end there. In
// addition, if a root is found within the first two comparisons, then the
// path compression updates can be skipped.
Index
current
=
element
;
Index
parent
=
parents
[
current
];
if
(
current
==
parent
)
return
current
;
// |element| is a root
current
=
parent
;
parent
=
parents
[
current
];
if
(
current
==
parent
)
return
current
;
// |element| is the child of a root
do
{
// otherwise, continue upwards until root
current
=
parent
;
parent
=
parents
[
current
];
}
while
(
current
!=
parent
);
const
Index
root
=
current
;
// Apply path compression on the traversed nodes.
current
=
element
;
parent
=
parents
[
current
];
// not root, thanks to unrolling above
do
{
parents
[
current
]
=
root
;
current
=
parent
;
parent
=
parents
[
current
];
}
while
(
parent
!=
root
);
return
root
;
}
template
<
class
Index
,
bool
kUseUnionByRank
>
bool
DisjointSetForest
<
Index
,
kUseUnionByRank
>::
SameSet
(
Index
element1
,
Index
element2
)
{
return
FindRoot
(
element1
)
==
FindRoot
(
element2
);
}
template
<
class
Index
,
bool
kUseUnionByRank
>
void
DisjointSetForest
<
Index
,
kUseUnionByRank
>::
UnionOfRoots
(
Index
root1
,
Index
root2
)
{
DCHECK_LT
(
root1
,
size
());
DCHECK_LT
(
root2
,
size
());
DCHECK_EQ
(
root1
,
parents_
[
root1
]);
DCHECK_EQ
(
root2
,
parents_
[
root2
]);
if
(
root1
==
root2
)
return
;
// already merged
Index
*
const
__restrict
parents
=
parents_
.
data
();
if
(
kUseUnionByRank
)
{
// Attach the lesser-rank root to the higher-rank root.
Index
*
const
__restrict
ranks
=
ranks_
.
data
();
const
Index
rank1
=
ranks
[
root1
];
const
Index
rank2
=
ranks
[
root2
];
if
(
rank2
<
rank1
)
{
parents
[
root2
]
=
root1
;
}
else
if
(
rank1
<
rank2
)
{
parents
[
root1
]
=
root2
;
}
else
{
// Equal ranks; choose one arbitrarily and promote its rank.
parents
[
root1
]
=
root2
;
ranks
[
root2
]
=
rank2
+
1
;
}
}
else
{
// Always make |root2| the root of the merged set.
parents
[
root1
]
=
root2
;
}
}
template
<
class
Index
,
bool
kUseUnionByRank
>
void
DisjointSetForest
<
Index
,
kUseUnionByRank
>::
Union
(
Index
element1
,
Index
element2
)
{
UnionOfRoots
(
FindRoot
(
element1
),
FindRoot
(
element2
));
}
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_MST_DISJOINT_SET_FOREST_H_
research/syntaxnet/dragnn/mst/disjoint_set_forest_test.cc
0 → 100644
View file @
80178fc6
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/mst/disjoint_set_forest.h"
#include <stddef.h>
#include <set>
#include <utility>
#include <vector>
#include "syntaxnet/base.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
{
// Testing rig.
//
// Template args:
// Forest: An instantiation of the DisjointSetForest<> template.
template
<
class
Forest
>
class
DisjointSetForestTest
:
public
::
testing
::
Test
{
protected:
using
Index
=
typename
Forest
::
IndexType
;
// Expects that the |expected_sets| and |forest| match.
void
ExpectSets
(
const
std
::
set
<
std
::
set
<
Index
>>
&
expected_sets
,
Forest
*
forest
)
{
std
::
set
<
std
::
pair
<
Index
,
Index
>>
expected_pairs
;
for
(
const
auto
&
expected_set
:
expected_sets
)
{
for
(
auto
it
=
expected_set
.
begin
();
it
!=
expected_set
.
end
();
++
it
)
{
for
(
auto
jt
=
expected_set
.
begin
();
jt
!=
expected_set
.
end
();
++
jt
)
{
expected_pairs
.
emplace
(
*
it
,
*
jt
);
}
}
}
for
(
Index
lhs
=
0
;
lhs
<
forest
->
size
();
++
lhs
)
{
for
(
Index
rhs
=
0
;
rhs
<
forest
->
size
();
++
rhs
)
{
if
(
expected_pairs
.
find
({
lhs
,
rhs
})
!=
expected_pairs
.
end
())
{
EXPECT_EQ
(
forest
->
FindRoot
(
lhs
),
forest
->
FindRoot
(
rhs
));
EXPECT_TRUE
(
forest
->
SameSet
(
lhs
,
rhs
));
}
else
{
EXPECT_NE
(
forest
->
FindRoot
(
lhs
),
forest
->
FindRoot
(
rhs
));
EXPECT_FALSE
(
forest
->
SameSet
(
lhs
,
rhs
));
}
}
}
}
};
using
Forests
=
::
testing
::
Types
<
DisjointSetForest
<
uint8
,
false
>
,
DisjointSetForest
<
uint8
,
true
>
,
DisjointSetForest
<
uint16
,
false
>
,
DisjointSetForest
<
uint16
,
true
>
,
DisjointSetForest
<
uint32
,
false
>
,
DisjointSetForest
<
uint32
,
true
>
,
DisjointSetForest
<
uint64
,
false
>
,
DisjointSetForest
<
uint64
,
true
>>
;
TYPED_TEST_CASE
(
DisjointSetForestTest
,
Forests
);
TYPED_TEST
(
DisjointSetForestTest
,
DefaultEmpty
)
{
TypeParam
forest
;
EXPECT_EQ
(
0
,
forest
.
size
());
}
TYPED_TEST
(
DisjointSetForestTest
,
InitEmpty
)
{
TypeParam
forest
;
forest
.
Init
(
0
);
EXPECT_EQ
(
0
,
forest
.
size
());
}
TYPED_TEST
(
DisjointSetForestTest
,
Populated
)
{
TypeParam
forest
;
forest
.
Init
(
5
);
EXPECT_EQ
(
5
,
forest
.
size
());
this
->
ExpectSets
({{
0
},
{
1
},
{
2
},
{
3
},
{
4
}},
&
forest
);
forest
.
UnionOfRoots
(
1
,
2
);
this
->
ExpectSets
({{
0
},
{
1
,
2
},
{
3
},
{
4
}},
&
forest
);
forest
.
Union
(
1
,
2
);
this
->
ExpectSets
({{
0
},
{
1
,
2
},
{
3
},
{
4
}},
&
forest
);
forest
.
UnionOfRoots
(
0
,
4
);
this
->
ExpectSets
({{
0
,
4
},
{
1
,
2
},
{
3
}},
&
forest
);
forest
.
Union
(
3
,
4
);
this
->
ExpectSets
({{
0
,
3
,
4
},
{
1
,
2
}},
&
forest
);
forest
.
Union
(
0
,
3
);
this
->
ExpectSets
({{
0
,
3
,
4
},
{
1
,
2
}},
&
forest
);
forest
.
Union
(
2
,
0
);
this
->
ExpectSets
({{
0
,
1
,
2
,
3
,
4
}},
&
forest
);
forest
.
Union
(
1
,
3
);
this
->
ExpectSets
({{
0
,
1
,
2
,
3
,
4
}},
&
forest
);
}
// Testing rig for checking that when union by rank is disabled, the root of a
// merged set can be controlled.
class
DisjointSetForestNoUnionByRankTest
:
public
::
testing
::
Test
{
protected:
using
Forest
=
DisjointSetForest
<
uint32
,
false
>
;
// Expects that the roots of the |forest| match |expected_roots|.
void
ExpectRoots
(
const
std
::
vector
<
uint32
>
&
expected_roots
,
Forest
*
forest
)
{
ASSERT_EQ
(
expected_roots
.
size
(),
forest
->
size
());
for
(
uint32
i
=
0
;
i
<
forest
->
size
();
++
i
)
{
EXPECT_EQ
(
expected_roots
[
i
],
forest
->
FindRoot
(
i
));
}
}
};
TEST_F
(
DisjointSetForestNoUnionByRankTest
,
ManuallySpecifyRoot
)
{
Forest
forest
;
forest
.
Init
(
5
);
ExpectRoots
({
0
,
1
,
2
,
3
,
4
},
&
forest
);
forest
.
UnionOfRoots
(
0
,
1
);
// 1 is the root
ExpectRoots
({
1
,
1
,
2
,
3
,
4
},
&
forest
);
forest
.
Union
(
4
,
3
);
// 3 is the root
ExpectRoots
({
1
,
1
,
2
,
3
,
3
},
&
forest
);
forest
.
Union
(
0
,
2
);
// 2 is the root
ExpectRoots
({
2
,
2
,
2
,
3
,
3
},
&
forest
);
forest
.
Union
(
3
,
3
);
// no effect
ExpectRoots
({
2
,
2
,
2
,
3
,
3
},
&
forest
);
forest
.
Union
(
4
,
0
);
// 2 is the root
ExpectRoots
({
2
,
2
,
2
,
2
,
2
},
&
forest
);
}
}
// namespace
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/mst/mst_solver.h
0 → 100644
View file @
80178fc6
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_MST_MST_SOLVER_H_
#define DRAGNN_MST_MST_SOLVER_H_
#include <stddef.h>
#include <algorithm>
#include <cmath>
#include <limits>
#include <type_traits>
#include <utility>
#include <vector>
#include "dragnn/mst/disjoint_set_forest.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace
syntaxnet
{
namespace
dragnn
{
// Maximum spanning tree solver for directed graphs. Thread-compatible.
//
// The solver operates on a digraph of n nodes and m arcs and outputs a maximum
// spanning tree rooted at any node. Scores can be associated with arcs and
// root selections, and the score of a tree is the sum of the relevant arc and
// root-selection scores.
//
// The implementation is based on:
//
// R.E. Tarjan. 1977. Finding Optimum Branchings. Networks 7(1), pp. 25-35.
// [In particular, see Section 4 "a modification for dense graphs"]
//
// which itself is an improvement of the Chu-Liu-Edmonds algorithm. Note also
// the correction in:
//
// P.M. Camerini, L. Fratta, F. Maffioli. 1979. A Note on Finding Optimum
// Branchings. Networks 9(4), pp. 309-312.
//
// The solver runs in O(n^2) time, which is optimal for dense digraphs but slow
// for sparse digraphs where O(m + n log n) can be achieved. The solver uses
// O(n^2) space to store the digraph, which is also optimal for dense digraphs.
//
// Although this algorithm has an inferior asymptotic runtime on sparse graphs,
// it avoids high-constant-overhead data structures like Fibonacci heaps, which
// are required in the asymptotically faster algorithms. Therefore, this solver
// may still be competitive on small sparse graphs.
//
// TODO(googleuser): If we start running on large sparse graphs, implement the
// following, which runs in O(m + n log n):
//
// H.N. Gabow, Z. Galil, T. Spencer, and R.E. Tarjan. 1986. Efficient
// algorithms for finding minimum spanning trees in undirected and directed
// graphs. Combinatorica, 6(2), pp. 109-122.
//
// Template args:
// Index: An unsigned integral type wide enough to hold 2n.
// Score: A signed arithmetic (integral or floating-point) type.
template
<
class
Index
,
class
Score
>
class
MstSolver
{
public:
static_assert
(
std
::
is_integral
<
Index
>::
value
,
"Index must be integral"
);
static_assert
(
!
std
::
is_signed
<
Index
>::
value
,
"Index must be unsigned"
);
static_assert
(
std
::
is_arithmetic
<
Score
>::
value
,
"Score must be arithmetic"
);
static_assert
(
std
::
is_signed
<
Score
>::
value
,
"Score must be signed"
);
using
IndexType
=
Index
;
using
ScoreType
=
Score
;
// Creates an empty solver. Call Init() before use.
MstSolver
()
=
default
;
// Initializes this for a digraph with |num_nodes| nodes, or returns non-OK on
// error. Discards existing state; call AddArc() and AddRoot() to add arcs
// and root selections. If |forest| is true, then this solves for a maximum
// spanning forest (i.e., a set of disjoint trees that span the digraph).
tensorflow
::
Status
Init
(
bool
forest
,
Index
num_nodes
);
// Adds an arc from the |source| node to the |target| node with the |score|.
// The |source| and |target| must be distinct node indices in [0,n), and the
// |score| must be finite. Calling this multiple times on the same |source|
// and |target| overwrites the score instead of adding parallel arcs.
void
AddArc
(
Index
source
,
Index
target
,
Score
score
);
// As above, but adds a root selection for the |root| node with the |score|.
void
AddRoot
(
Index
root
,
Score
score
);
// Populates |argmax| with the maximum directed spanning tree of the current
// digraph, or returns non-OK on error. The |argmax| array must contain at
// least n elements. On success, argmax[t] is the source of the arc directed
// into t, or t itself if t is a root.
//
// NB: If multiple spanning trees achieve the maximum score, |argmax| will be
// set to one of the maximal trees, but it is unspecified which one.
tensorflow
::
Status
Solve
(
tensorflow
::
gtl
::
MutableArraySlice
<
Index
>
argmax
);
private:
// Implementation notes:
//
// The solver does not operate on the "original" digraph as specified by the
// user, but a "transformed" digraph that differs as follows:
//
// * The transformed digraph adds an "artificial root" node at index 0 and
// offsets all original node indices by +1 to make room. For each root
// selection, the artificial root has one outbound arc directed into the
// candidate root that carries the root-selection score. The artificial
// root has no inbound arcs.
//
// * When solving for a spanning tree (i.e., when |forest_| is false), the
// outbound arcs of the artificial root are penalized to ensure that the
// artificial root has exactly one child.
//
// In the remainder of this file, all mentions of nodes, arcs, etc., refer to
// the transformed digraph unless otherwise specified.
//
// The algorithm is divided into two phases, the "contraction phase" and the
// "expansion phase". The contraction phase finds the arcs that make up the
// maximum spanning tree by applying a series of "contractions" which further
// modify the digraph. The expansion phase "expands" these modifications and
// recovers the maximum spanning tree in the original digraph.
//
// During the contraction phase, the algorithm selects the best inbound arc
// for each node. These arcs can form cycles, which are "contracted" by
// removing the cycle nodes and replacing them with a new contracted node.
// Since each contraction removes 2 or more cycle nodes and adds 1 contracted
// node, at most n-1 contractions will occur. (The digraph initially contains
// n+1 nodes, but one is the artificial root, which cannot form a cycle).
//
// When contracting a cycle, nodes are not explicitly removed and replaced.
// Instead, a contracted node is appended to the digraph and the cycle nodes
// are remapped to the contracted node, which implicitly removes and replaces
// the cycle. As a result, each contraction actually increases the size of
// the digraph, up to a maximum of 2n nodes. One advantage of adding and
// remapping nodes is that it is convenient to recover the argmax spanning
// tree during the expansion phase.
//
// Note that contractions can be nested, because the best inbound arc for a
// contracted node may itelf form a cycle. During the expansion phase, the
// algorithm picks a root of the hierarchy of contracted nodes, breaks the
// cycle it represents, and repeats until all cycles are broken.
// Constants, as enums to avoid the need for static variable definitions.
enum
Constants
:
Index
{
// An index reserved for "null" values.
kNullIndex
=
std
::
numeric_limits
<
Index
>::
max
(),
};
// A possibly-nonexistent arc in the digraph.
struct
Arc
{
// Creates a nonexistent arc.
Arc
()
=
default
;
// Returns true if this arc exists.
bool
Exists
()
const
{
return
target
!=
0
;
}
// Returns true if this is a root-selection arc.
bool
IsRoot
()
const
{
return
source
==
0
;
}
// Returns a string representation of this arc.
string
DebugString
()
const
{
if
(
!
Exists
())
return
"[null]"
;
if
(
IsRoot
())
{
return
tensorflow
::
strings
::
StrCat
(
"[*->"
,
target
,
"="
,
score
,
"]"
);
}
return
tensorflow
::
strings
::
StrCat
(
"["
,
source
,
"->"
,
target
,
"="
,
score
,
"]"
);
}
// Score of this arc.
Score
score
;
// Source of this arc in the initial digraph.
Index
source
;
// Target of this arc in the initial digraph, or 0 if this is nonexistent.
Index
target
=
0
;
};
// Returns the index, in |arcs_|, of the arc from |source| to |target|. The
// |source| must be one of the initial n+1 nodes.
size_t
ArcIndex
(
size_t
source
,
size_t
target
)
const
;
// Penalizes the root arc scores to ensure that this finds a tree, or does
// nothing if |forest_| is true. Must be called before ContractionPhase().
void
MaybePenalizeRootScoresForTree
();
// Returns the maximum inbound arc of the |node|, or null if there is none.
const
Arc
*
MaximumInboundArc
(
Index
node
)
const
;
// Merges the inbound arcs of the |cycle_node| into the inbound arcs of the
// |contracted_node|. Arcs are merged as follows:
// * If the source and target of the arc belong to the same strongly-connected
// component, it is ignored.
// * If exactly one of the nodes had an arc from some source, then on exit the
// |contracted_node| has that arc.
// * If both of the nodes had an arc from the same source, then on exit the
// |contracted_node| has the better-scoring arc.
// The |score_offset| is added to the arc scores of the |cycle_node| before
// they are merged into the |contracted_node|.
void
MergeInboundArcs
(
Index
cycle_node
,
Score
score_offset
,
Index
contracted_node
);
// Contracts the cycle in |argmax_arcs_| that contains the |node|.
void
ContractCycle
(
Index
node
);
// Runs the contraction phase of the solver, or returns non-OK on error. This
// phase finds the best inbound arc for each node, contracting cycles as they
// are formed. Stops when every node has selected an inbound arc and there
// are no cycles.
tensorflow
::
Status
ContractionPhase
();
// Runs the expansion phase of the solver, or returns non-OK on error. This
// phase expands each contracted node, breaks cycles, and populates |argmax|
// with the maximum spanning tree.
tensorflow
::
Status
ExpansionPhase
(
tensorflow
::
gtl
::
MutableArraySlice
<
Index
>
argmax
);
// If true, solve for a spanning forest instead of a spanning tree.
bool
forest_
=
false
;
// The number of nodes in the original digraph; i.e., n.
Index
num_original_nodes_
=
0
;
// The number of nodes in the initial digraph; i.e., n+1.
Index
num_initial_nodes_
=
0
;
// The maximum number of possible nodes in the digraph; i.e., 2n.
Index
num_possible_nodes_
=
0
;
// The number of nodes in the current digraph, which grows from n+1 to 2n.
Index
num_current_nodes_
=
0
;
// Column-major |num_initial_nodes_| x |num_current_nodes_| matrix of arcs,
// where rows and columns correspond to source and target nodes. Columns are
// added as cycles are contracted into new nodes.
//
// TODO(googleuser): It is possible to squeeze the nonexistent arcs out of each
// column and run the algorithm with each column being a sorted list (sorted
// by source node). This is in fact the suggested representation in Tarjan
// (1977). This won't improve the asymptotic runtime but still might improve
// speed in practice. I haven't done this because it adds complexity versus
// checking Arc::Exists() in a few loops. Try this out when we can benchmark
// this on real data.
std
::
vector
<
Arc
>
arcs_
;
// Disjoint-set forests tracking the weakly-connected and strongly-connected
// components of the initial digraph, based on the arcs in |argmax_arcs_|.
// Weakly-connected components are used to detect cycles; strongly-connected
// components are used to detect self-loops.
DisjointSetForest
<
Index
>
weak_components_
;
DisjointSetForest
<
Index
>
strong_components_
;
// A disjoint-set forest that maps each node to the top-most contracted node
// that contains it. Nodes that have not been contracted map to themselves.
// NB: This disjoint-set forest does not use union by rank so we can control
// the outcome of a set union. There will only be O(n) operations on this
// instance, so the increased O(log n) cost of each operation is acceptable.
DisjointSetForest
<
Index
,
false
>
contracted_nodes_
;
// An array that represents the history of cycle contractions, as follows:
// * If contracted_into_[t] is |kNullIndex|, then t is deleted.
// * If contracted_into_[t] is 0, then t is a "root" contracted node; i.e., t
// has not been contracted into another node.
// * Otherwise, contracted_into_[t] is the node into which t was contracted.
std
::
vector
<
Index
>
contracted_into_
;
// The maximum inbound arc for each node. The first element is null because
// the artificial root has no inbound arcs.
std
::
vector
<
const
Arc
*>
argmax_arcs_
;
// Workspace for ContractCycle(), which records the nodes and arcs in the
// cycle being contracted.
std
::
vector
<
std
::
pair
<
Index
,
const
Arc
*>>
cycle_
;
};
// Implementation details below.
template
<
class
Index
,
class
Score
>
tensorflow
::
Status
MstSolver
<
Index
,
Score
>::
Init
(
bool
forest
,
Index
num_nodes
)
{
if
(
num_nodes
<=
0
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Non-positive number of nodes: "
,
num_nodes
);
}
// Upcast to size_t to avoid overflow.
if
(
2
*
static_cast
<
size_t
>
(
num_nodes
)
>=
static_cast
<
size_t
>
(
kNullIndex
))
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Too many nodes: "
,
num_nodes
);
}
forest_
=
forest
;
num_original_nodes_
=
num_nodes
;
num_initial_nodes_
=
num_original_nodes_
+
1
;
num_possible_nodes_
=
2
*
num_original_nodes_
;
num_current_nodes_
=
num_initial_nodes_
;
// Allocate the full n+1 x 2n matrix, but start with a n+1 x n+1 prefix.
const
size_t
num_initial_arcs
=
static_cast
<
size_t
>
(
num_initial_nodes_
)
*
static_cast
<
size_t
>
(
num_initial_nodes_
);
const
size_t
num_possible_arcs
=
static_cast
<
size_t
>
(
num_initial_nodes_
)
*
static_cast
<
size_t
>
(
num_possible_nodes_
);
arcs_
.
reserve
(
num_possible_arcs
);
arcs_
.
assign
(
num_initial_arcs
,
{});
weak_components_
.
Init
(
num_initial_nodes_
);
strong_components_
.
Init
(
num_initial_nodes_
);
contracted_nodes_
.
Init
(
num_possible_nodes_
);
contracted_into_
.
assign
(
num_possible_nodes_
,
0
);
argmax_arcs_
.
assign
(
num_possible_nodes_
,
nullptr
);
// This doesn't need to be cleared now; it will be cleared before use.
cycle_
.
reserve
(
num_original_nodes_
);
return
tensorflow
::
Status
::
OK
();
}
template
<
class
Index
,
class
Score
>
void
MstSolver
<
Index
,
Score
>::
AddArc
(
Index
source
,
Index
target
,
Score
score
)
{
DCHECK_NE
(
source
,
target
);
DCHECK
(
std
::
isfinite
(
score
));
Arc
&
arc
=
arcs_
[
ArcIndex
(
source
+
1
,
target
+
1
)];
arc
.
score
=
score
;
arc
.
source
=
source
+
1
;
arc
.
target
=
target
+
1
;
}
template
<
class
Index
,
class
Score
>
void
MstSolver
<
Index
,
Score
>::
AddRoot
(
Index
root
,
Score
score
)
{
DCHECK
(
std
::
isfinite
(
score
));
Arc
&
arc
=
arcs_
[
ArcIndex
(
0
,
root
+
1
)];
arc
.
score
=
score
;
arc
.
source
=
0
;
arc
.
target
=
root
+
1
;
}
template
<
class
Index
,
class
Score
>
tensorflow
::
Status
MstSolver
<
Index
,
Score
>::
Solve
(
tensorflow
::
gtl
::
MutableArraySlice
<
Index
>
argmax
)
{
MaybePenalizeRootScoresForTree
();
TF_RETURN_IF_ERROR
(
ContractionPhase
());
TF_RETURN_IF_ERROR
(
ExpansionPhase
(
argmax
));
return
tensorflow
::
Status
::
OK
();
}
template
<
class
Index
,
class
Score
>
inline
size_t
MstSolver
<
Index
,
Score
>::
ArcIndex
(
size_t
source
,
size_t
target
)
const
{
DCHECK_LT
(
source
,
num_initial_nodes_
);
DCHECK_LT
(
target
,
num_current_nodes_
);
return
source
+
target
*
static_cast
<
size_t
>
(
num_initial_nodes_
);
}
template
<
class
Index
,
class
Score
>
void
MstSolver
<
Index
,
Score
>::
MaybePenalizeRootScoresForTree
()
{
if
(
forest_
)
return
;
DCHECK_EQ
(
num_current_nodes_
,
num_initial_nodes_
)
<<
"Root penalties must be applied before starting the algorithm."
;
// Find the minimum and maximum arc scores. These allow us to bound the range
// of possible tree scores.
Score
max_score
=
std
::
numeric_limits
<
Score
>::
lowest
();
Score
min_score
=
std
::
numeric_limits
<
Score
>::
max
();
for
(
const
Arc
&
arc
:
arcs_
)
{
if
(
!
arc
.
Exists
())
continue
;
max_score
=
std
::
max
(
max_score
,
arc
.
score
);
min_score
=
std
::
min
(
min_score
,
arc
.
score
);
}
// Nothing to do, no existing arcs.
if
(
max_score
<
min_score
)
return
;
// A spanning tree or forest contains n arcs. The penalty below ensures that
// every structure with one root has a higher score than every structure with
// two roots, and so on.
const
Score
root_penalty
=
1
+
num_initial_nodes_
*
(
max_score
-
min_score
);
for
(
Index
root
=
1
;
root
<
num_initial_nodes_
;
++
root
)
{
Arc
&
arc
=
arcs_
[
ArcIndex
(
0
,
root
)];
if
(
!
arc
.
Exists
())
continue
;
arc
.
score
-=
root_penalty
;
}
}
template
<
class
Index
,
class
Score
>
const
typename
MstSolver
<
Index
,
Score
>::
Arc
*
MstSolver
<
Index
,
Score
>::
MaximumInboundArc
(
Index
node
)
const
{
const
Arc
*
__restrict
arc
=
&
arcs_
[
ArcIndex
(
0
,
node
)];
const
Arc
*
arc_end
=
arc
+
num_initial_nodes_
;
Score
max_score
=
std
::
numeric_limits
<
Score
>::
lowest
();
const
Arc
*
argmax_arc
=
nullptr
;
for
(;
arc
<
arc_end
;
++
arc
)
{
if
(
!
arc
->
Exists
())
continue
;
const
Score
score
=
arc
->
score
;
if
(
max_score
<=
score
)
{
max_score
=
score
;
argmax_arc
=
arc
;
}
}
return
argmax_arc
;
}
template
<
class
Index
,
class
Score
>
void
MstSolver
<
Index
,
Score
>::
MergeInboundArcs
(
Index
cycle_node
,
Score
score_offset
,
Index
contracted_node
)
{
const
Arc
*
__restrict
cycle_arc
=
&
arcs_
[
ArcIndex
(
0
,
cycle_node
)];
const
Arc
*
cycle_arc_end
=
cycle_arc
+
num_initial_nodes_
;
Arc
*
__restrict
contracted_arc
=
&
arcs_
[
ArcIndex
(
0
,
contracted_node
)];
for
(;
cycle_arc
<
cycle_arc_end
;
++
cycle_arc
,
++
contracted_arc
)
{
if
(
!
cycle_arc
->
Exists
())
continue
;
// nothing to merge
// Skip self-loops; they are useless because they cannot be used to break
// the cycle represented by the |contracted_node|.
if
(
strong_components_
.
SameSet
(
cycle_arc
->
source
,
cycle_arc
->
target
))
{
continue
;
}
// Merge the |cycle_arc| into the |contracted_arc|.
const
Score
cycle_score
=
cycle_arc
->
score
+
score_offset
;
if
(
!
contracted_arc
->
Exists
()
||
contracted_arc
->
score
<
cycle_score
)
{
contracted_arc
->
score
=
cycle_score
;
contracted_arc
->
source
=
cycle_arc
->
source
;
contracted_arc
->
target
=
cycle_arc
->
target
;
}
}
}
template
<
class
Index
,
class
Score
>
void
MstSolver
<
Index
,
Score
>::
ContractCycle
(
Index
node
)
{
// Append a new node for the contracted cycle.
const
Index
contracted_node
=
num_current_nodes_
++
;
DCHECK_LE
(
num_current_nodes_
,
num_possible_nodes_
);
arcs_
.
resize
(
arcs_
.
size
()
+
num_initial_nodes_
);
// We make two passes through the cycle. The first pass updates everything
// except the |arcs_|, and the second pass updates the |arcs_|. The |arcs_|
// must be updated in a second pass because MergeInboundArcs() requires that
// the |strong_components_| are updated with the newly-contracted cycle.
cycle_
.
clear
();
Index
cycle_node
=
node
;
do
{
// Gather the nodes and arcs in |cycle_| for the second pass.
const
Arc
*
cycle_arc
=
argmax_arcs_
[
cycle_node
];
DCHECK
(
!
cycle_arc
->
IsRoot
())
<<
cycle_arc
->
DebugString
();
cycle_
.
emplace_back
(
cycle_node
,
cycle_arc
);
// Mark the cycle nodes as members of a strongly-connected component.
strong_components_
.
Union
(
cycle_arc
->
source
,
cycle_arc
->
target
);
// Mark the cycle nodes as members of the new contracted node. Juggling is
// required because |contracted_nodes_| also determines the next cycle node.
const
Index
next_node
=
contracted_nodes_
.
FindRoot
(
cycle_arc
->
source
);
contracted_nodes_
.
UnionOfRoots
(
cycle_node
,
contracted_node
);
contracted_into_
[
cycle_node
]
=
contracted_node
;
cycle_node
=
next_node
;
// When the cycle repeats, |cycle_node| will be equal to |contracted_node|,
// not |node|, because the first iteration of this loop mapped |node| to
// |contracted_node| in |contracted_nodes_|.
}
while
(
cycle_node
!=
contracted_node
);
// Merge the inbound arcs of each cycle node into the |contracted_node|.
for
(
const
auto
&
node_and_arc
:
cycle_
)
{
// Set the |score_offset| to the cost of breaking the cycle by replacing the
// arc currently directed into the |cycle_node|.
const
Index
cycle_node
=
node_and_arc
.
first
;
const
Score
score_offset
=
-
node_and_arc
.
second
->
score
;
MergeInboundArcs
(
cycle_node
,
score_offset
,
contracted_node
);
}
}
template
<
class
Index
,
class
Score
>
tensorflow
::
Status
MstSolver
<
Index
,
Score
>::
ContractionPhase
()
{
// Skip the artificial root since it has no inbound arcs.
for
(
Index
target
=
1
;
target
<
num_current_nodes_
;
++
target
)
{
// Find the maximum inbound arc for the current |target|, if any.
const
Arc
*
arc
=
MaximumInboundArc
(
target
);
if
(
arc
==
nullptr
)
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"Infeasible digraph"
);
}
argmax_arcs_
[
target
]
=
arc
;
// The articifial root cannot be part of a cycle, so we do not need to check
// for cycles or even update its membership in the connected components.
if
(
arc
->
IsRoot
())
continue
;
// Since every node has at most one selected inbound arc, cycles can be
// detected using weakly-connected components.
const
Index
source_component
=
weak_components_
.
FindRoot
(
arc
->
source
);
const
Index
target_component
=
weak_components_
.
FindRoot
(
arc
->
target
);
if
(
source_component
==
target_component
)
{
// Cycle detected; contract it into a new node.
ContractCycle
(
target
);
}
else
{
// No cycles, just update the weakly-connected components.
weak_components_
.
UnionOfRoots
(
source_component
,
target_component
);
}
}
return
tensorflow
::
Status
::
OK
();
}
template
<
class
Index
,
class
Score
>
tensorflow
::
Status
MstSolver
<
Index
,
Score
>::
ExpansionPhase
(
tensorflow
::
gtl
::
MutableArraySlice
<
Index
>
argmax
)
{
if
(
argmax
.
size
()
<
num_original_nodes_
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"Argmax array too small: "
,
num_original_nodes_
,
" elements required, but got "
,
argmax
.
size
());
}
// Select and expand a root contracted node until no contracted nodes remain.
// Thanks to the (topological) order in which contracted nodes are appended,
// root contracted nodes are easily enumerated using a backward scan. After
// this loop, entries [1,n] of |argmax_arcs_| provide the arcs of the maximum
// spanning tree.
for
(
Index
i
=
num_current_nodes_
-
1
;
i
>=
num_initial_nodes_
;
--
i
)
{
if
(
contracted_into_
[
i
]
==
kNullIndex
)
continue
;
// already deleted
const
Index
root
=
i
;
// if not deleted, must be a root due to toposorting
// Copy the cycle-breaking arc to its specified target.
const
Arc
*
arc
=
argmax_arcs_
[
root
];
argmax_arcs_
[
arc
->
target
]
=
arc
;
// The |arc| not only breaks the cycle associated with the |root|, but also
// breaks every nested cycle between the |root| and the target of the |arc|.
// Delete the contracted nodes corresponding to all broken cycles.
Index
node
=
contracted_into_
[
arc
->
target
];
while
(
node
!=
kNullIndex
&&
node
!=
root
)
{
const
Index
parent
=
contracted_into_
[
node
];
contracted_into_
[
node
]
=
kNullIndex
;
node
=
parent
;
}
}
// Copy the spanning tree from |argmax_arcs_| to |argmax|. Also count roots
// for validation below.
Index
num_roots
=
0
;
for
(
Index
target
=
0
;
target
<
num_original_nodes_
;
++
target
)
{
const
Arc
&
arc
=
*
argmax_arcs_
[
target
+
1
];
DCHECK_EQ
(
arc
.
target
,
target
+
1
)
<<
arc
.
DebugString
();
if
(
arc
.
IsRoot
())
{
++
num_roots
;
argmax
[
target
]
=
target
;
}
else
{
argmax
[
target
]
=
arc
.
source
-
1
;
}
}
DCHECK_GE
(
num_roots
,
1
);
// Even when |forest_| is false, |num_roots| can still be more than 1. While
// the root score penalty discourages structures with multiple root arcs, it
// is not a hard constraint. For example, if the original digraph contained
// one root selection per node and no other arcs, the solver would incorrectly
// produce an all-root structure in spite of the root score penalty. As this
// example illustrates, however, |num_roots| will be more than 1 if and only
// if the original digraph is infeasible for trees.
if
(
!
forest_
&&
num_roots
!=
1
)
{
return
tensorflow
::
errors
::
FailedPrecondition
(
"Infeasible digraph"
);
}
return
tensorflow
::
Status
::
OK
();
}
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_MST_MST_SOLVER_H_
research/syntaxnet/dragnn/mst/mst_solver_random_comparison_test.cc
0 → 100644
View file @
80178fc6
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/mst/mst_solver.h"
#include <time.h>
#include <random>
#include <set>
#include <vector>
#include "dragnn/mst/spanning_tree_iterator.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
{
using
::
testing
::
Contains
;
// Returns the random seed, or 0 for a weak random seed.
int64
GetSeed
()
{
return
1
;
// use a deterministic seed
}
// Returns the number of trials to run for each random comparison.
int64
GetNumTrials
()
{
return
3
;
}
// Testing rig. Runs a comparison between a brute-force MST solver and the
// MstSolver<> on random digraphs. When the first test parameter is true,
// solves for forests instead of trees. The second test parameter defines the
// size of the test digraph.
class
MstSolverRandomComparisonTest
:
public
::
testing
::
TestWithParam
<::
testing
::
tuple
<
bool
,
uint32
>>
{
protected:
// Use integer scores so score comparisons are exact.
using
Solver
=
MstSolver
<
uint32
,
int32
>
;
// An array providing a source node for each node. Roots are self-loops.
using
SourceList
=
SpanningTreeIterator
::
SourceList
;
// A row-major n x n matrix whose i,j entry gives the score of the arc from i
// to j, and whose i,i entry gives the score of selecting i as a root.
using
ScoreMatrix
=
std
::
vector
<
int32
>
;
// Returns true if this should be a forest.
bool
forest
()
const
{
return
::
testing
::
get
<
0
>
(
GetParam
());
}
// Returns the number of nodes for digraphs.
uint32
num_nodes
()
const
{
return
::
testing
::
get
<
1
>
(
GetParam
());
}
// Returns the score of the arcs in |sources| based on the |scores|.
int32
ScoreArcs
(
const
ScoreMatrix
&
scores
,
const
SourceList
&
sources
)
const
{
CHECK_EQ
(
num_nodes
()
*
num_nodes
(),
scores
.
size
());
int32
score
=
0
;
for
(
uint32
target
=
0
;
target
<
num_nodes
();
++
target
)
{
const
uint32
source
=
sources
[
target
];
score
+=
scores
[
target
+
source
*
num_nodes
()];
}
return
score
;
}
// Returns the score of the maximum spanning tree (or forest, if the first
// test parameter is true) of the dense digraph defined by the |scores|, and
// sets |argmax_trees| to contain all maximal trees.
int32
RunBruteForceMstSolver
(
const
ScoreMatrix
&
scores
,
std
::
set
<
SourceList
>
*
argmax_trees
)
{
CHECK_EQ
(
num_nodes
()
*
num_nodes
(),
scores
.
size
());
int32
max_score
;
argmax_trees
->
clear
();
iterator_
.
ForEachTree
(
num_nodes
(),
[
&
](
const
SourceList
&
sources
)
{
const
int32
score
=
ScoreArcs
(
scores
,
sources
);
if
(
argmax_trees
->
empty
()
||
max_score
<
score
)
{
max_score
=
score
;
argmax_trees
->
clear
();
argmax_trees
->
insert
(
sources
);
}
else
if
(
max_score
==
score
)
{
argmax_trees
->
insert
(
sources
);
}
});
return
max_score
;
}
// As above, but uses the |solver_| and extracts only one |argmax_tree|.
int32
RunMstSolver
(
const
ScoreMatrix
&
scores
,
SourceList
*
argmax_tree
)
{
CHECK_EQ
(
num_nodes
()
*
num_nodes
(),
scores
.
size
());
TF_CHECK_OK
(
solver_
.
Init
(
forest
(),
num_nodes
()));
// Add all roots and arcs.
for
(
uint32
source
=
0
;
source
<
num_nodes
();
++
source
)
{
for
(
uint32
target
=
0
;
target
<
num_nodes
();
++
target
)
{
const
int32
score
=
scores
[
target
+
source
*
num_nodes
()];
if
(
source
==
target
)
{
solver_
.
AddRoot
(
target
,
score
);
}
else
{
solver_
.
AddArc
(
source
,
target
,
score
);
}
}
}
// Solve for the max spanning tree.
argmax_tree
->
resize
(
num_nodes
());
TF_CHECK_OK
(
solver_
.
Solve
(
argmax_tree
));
return
ScoreArcs
(
scores
,
*
argmax_tree
);
}
// Returns a random ScoreMatrix spanning num_nodes() nodes.
ScoreMatrix
RandomScores
()
{
ScoreMatrix
scores
(
num_nodes
()
*
num_nodes
());
for
(
int32
&
value
:
scores
)
value
=
static_cast
<
int32
>
(
prng_
()
%
201
)
-
100
;
return
scores
;
}
// Runs a comparison between MstSolver and BruteForceMst on random digraphs of
// num_nodes() nodes, for the specified number of trials.
void
RunComparison
()
{
// Seed the PRNG, possibly non-deterministically. Log the seed value so the
// test results can be reproduced, even when the seed is non-deterministic.
uint32
seed
=
GetSeed
();
if
(
seed
==
0
)
seed
=
time
(
nullptr
);
prng_
.
seed
(
seed
);
LOG
(
INFO
)
<<
"seed = "
<<
seed
;
const
int
num_trials
=
GetNumTrials
();
for
(
int
trial
=
0
;
trial
<
num_trials
;
++
trial
)
{
const
ScoreMatrix
scores
=
RandomScores
();
std
::
set
<
SourceList
>
expected_argmax_trees
;
const
int32
expected_max_score
=
RunBruteForceMstSolver
(
scores
,
&
expected_argmax_trees
);
SourceList
actual_argmax_tree
;
const
int32
actual_max_score
=
RunMstSolver
(
scores
,
&
actual_argmax_tree
);
// In case of ties, MstSolver will find a maximal spanning tree, but we
// don't know which one.
EXPECT_EQ
(
expected_max_score
,
actual_max_score
);
ASSERT_THAT
(
expected_argmax_trees
,
Contains
(
actual_argmax_tree
));
}
}
// Tree iterator for brute-force solver.
SpanningTreeIterator
iterator_
{
forest
()};
// MstSolver<> instance used by the test. Reused across all MST invocations
// to exercise reuse.
Solver
solver_
;
// Pseudo-random number generator.
std
::
mt19937
prng_
;
};
INSTANTIATE_TEST_CASE_P
(
AllowForest
,
MstSolverRandomComparisonTest
,
::
testing
::
Combine
(
::
testing
::
Bool
(),
::
testing
::
Range
<
uint32
>
(
1
,
9
)));
TEST_P
(
MstSolverRandomComparisonTest
,
Comparison
)
{
RunComparison
();
}
}
// namespace
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/mst/mst_solver_test.cc
0 → 100644
View file @
80178fc6
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/mst/mst_solver.h"
#include <limits>
#include <utility>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "syntaxnet/base.h"
#include <gmock/gmock.h>
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
{
using
::
testing
::
HasSubstr
;
// Testing rig.
//
// Template args:
// Solver: An instantiation of the MstSolver<> template.
template
<
class
Solver
>
class
MstSolverTest
:
public
::
testing
::
Test
{
protected:
using
Index
=
typename
Solver
::
IndexType
;
using
Score
=
typename
Solver
::
ScoreType
;
// Adds directed arcs for all |num_nodes| nodes to the |solver_| with the
// |score|.
void
AddAllArcs
(
Index
num_nodes
,
Score
score
)
{
for
(
Index
source
=
0
;
source
<
num_nodes
;
++
source
)
{
for
(
Index
target
=
0
;
target
<
num_nodes
;
++
target
)
{
if
(
source
==
target
)
continue
;
solver_
.
AddArc
(
source
,
target
,
score
);
}
}
}
// Adds root selections for all |num_nodes| nodes to the |solver_| with the
// |score|.
void
AddAllRoots
(
Index
num_nodes
,
Score
score
)
{
for
(
Index
root
=
0
;
root
<
num_nodes
;
++
root
)
{
solver_
.
AddRoot
(
root
,
score
);
}
}
// Runs the |solver_| using an argmax array of size |argmax_array_size| and
// expects it to fail with an error message that matches |error_substr|.
void
SolveAndExpectError
(
int
argmax_array_size
,
const
string
&
error_message_substr
)
{
std
::
vector
<
Index
>
argmax
(
argmax_array_size
);
EXPECT_THAT
(
solver_
.
Solve
(
&
argmax
),
test
::
IsErrorWithSubstr
(
error_message_substr
));
}
// As above, but expects success. Does not assert anything about the solution
// produced by the solver.
void
SolveAndExpectOk
(
int
argmax_array_size
)
{
std
::
vector
<
Index
>
argmax
(
argmax_array_size
);
TF_EXPECT_OK
(
solver_
.
Solve
(
&
argmax
));
}
// As above, but expects the solution to be |expected_argmax| and infers the
// argmax array size.
void
SolveAndExpectArgmax
(
const
std
::
vector
<
Index
>
&
expected_argmax
)
{
std
::
vector
<
Index
>
actual_argmax
(
expected_argmax
.
size
());
TF_ASSERT_OK
(
solver_
.
Solve
(
&
actual_argmax
));
EXPECT_EQ
(
expected_argmax
,
actual_argmax
);
}
// MstSolver<> instance used by the test. Reused across all MST problems in
// each test to exercise reuse.
Solver
solver_
;
};
using
Solvers
=
::
testing
::
Types
<
MstSolver
<
uint8
,
int16
>
,
MstSolver
<
uint16
,
int32
>
,
MstSolver
<
uint32
,
int64
>
,
MstSolver
<
uint16
,
float
>
,
MstSolver
<
uint32
,
double
>>
;
TYPED_TEST_CASE
(
MstSolverTest
,
Solvers
);
TYPED_TEST
(
MstSolverTest
,
FailIfNoNodes
)
{
for
(
const
bool
forest
:
{
false
,
true
})
{
EXPECT_THAT
(
this
->
solver_
.
Init
(
forest
,
0
),
test
::
IsErrorWithSubstr
(
"Non-positive number of nodes"
));
}
}
TYPED_TEST
(
MstSolverTest
,
FailIfTooManyNodes
)
{
// Set to a value that would overflow when doubled.
const
auto
kNumNodes
=
(
std
::
numeric_limits
<
typename
TypeParam
::
IndexType
>::
max
()
/
2
)
+
10
;
for
(
const
bool
forest
:
{
false
,
true
})
{
EXPECT_THAT
(
this
->
solver_
.
Init
(
forest
,
kNumNodes
),
test
::
IsErrorWithSubstr
(
"Too many nodes"
));
}
}
TYPED_TEST
(
MstSolverTest
,
InfeasibleIfNoRootsNoArcs
)
{
const
int
kNumNodes
=
10
;
for
(
const
bool
forest
:
{
false
,
true
})
{
TF_ASSERT_OK
(
this
->
solver_
.
Init
(
forest
,
kNumNodes
));
this
->
SolveAndExpectError
(
kNumNodes
,
"Infeasible digraph"
);
}
}
TYPED_TEST
(
MstSolverTest
,
InfeasibleIfNoRootsAllArcs
)
{
const
int
kNumNodes
=
10
;
for
(
const
bool
forest
:
{
false
,
true
})
{
TF_ASSERT_OK
(
this
->
solver_
.
Init
(
forest
,
kNumNodes
));
this
->
AddAllArcs
(
kNumNodes
,
0
);
this
->
SolveAndExpectError
(
kNumNodes
,
"Infeasible digraph"
);
}
}
TYPED_TEST
(
MstSolverTest
,
FeasibleForForestOnlyIfAllRootsNoArcs
)
{
const
int
kNumNodes
=
10
;
for
(
const
bool
forest
:
{
false
,
true
})
{
TF_ASSERT_OK
(
this
->
solver_
.
Init
(
forest
,
kNumNodes
));
this
->
AddAllRoots
(
kNumNodes
,
0
);
if
(
forest
)
{
this
->
SolveAndExpectOk
(
kNumNodes
);
// all roots is a valid forest
}
else
{
this
->
SolveAndExpectError
(
kNumNodes
,
"Infeasible digraph"
);
}
}
}
TYPED_TEST
(
MstSolverTest
,
FeasibleIfAllRootsAllArcs
)
{
const
int
kNumNodes
=
10
;
for
(
const
bool
forest
:
{
false
,
true
})
{
TF_ASSERT_OK
(
this
->
solver_
.
Init
(
forest
,
kNumNodes
));
this
->
AddAllRoots
(
kNumNodes
,
0
);
this
->
AddAllArcs
(
kNumNodes
,
0
);
this
->
SolveAndExpectOk
(
kNumNodes
);
}
}
TYPED_TEST
(
MstSolverTest
,
FailIfArgmaxArrayTooSmall
)
{
const
int
kNumNodes
=
10
;
for
(
const
bool
forest
:
{
false
,
true
})
{
TF_ASSERT_OK
(
this
->
solver_
.
Init
(
forest
,
kNumNodes
));
this
->
AddAllRoots
(
kNumNodes
,
0
);
this
->
AddAllArcs
(
kNumNodes
,
0
);
this
->
SolveAndExpectError
(
kNumNodes
-
1
,
// too small
"Argmax array too small"
);
}
}
TYPED_TEST
(
MstSolverTest
,
OkIfArgmaxArrayTooLarge
)
{
const
int
kNumNodes
=
10
;
for
(
const
bool
forest
:
{
false
,
true
})
{
TF_ASSERT_OK
(
this
->
solver_
.
Init
(
forest
,
kNumNodes
));
this
->
AddAllRoots
(
kNumNodes
,
0
);
this
->
AddAllArcs
(
kNumNodes
,
0
);
this
->
SolveAndExpectOk
(
kNumNodes
+
1
);
// too large
}
}
TYPED_TEST
(
MstSolverTest
,
SolveForAllRootsForestOnly
)
{
const
int
kNumNodes
=
10
;
const
bool
forest
=
true
;
TF_ASSERT_OK
(
this
->
solver_
.
Init
(
forest
,
kNumNodes
));
this
->
AddAllRoots
(
kNumNodes
,
1
);
// favor all root selections
this
->
AddAllArcs
(
kNumNodes
,
0
);
this
->
SolveAndExpectArgmax
({
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
});
}
TYPED_TEST
(
MstSolverTest
,
SolveForLeftToRightChain
)
{
const
int
kNumNodes
=
10
;
for
(
const
bool
forest
:
{
false
,
true
})
{
TF_ASSERT_OK
(
this
->
solver_
.
Init
(
forest
,
kNumNodes
));
this
->
AddAllRoots
(
kNumNodes
,
0
);
this
->
AddAllArcs
(
kNumNodes
,
0
);
for
(
int
target
=
1
;
target
<
kNumNodes
;
++
target
)
{
this
->
solver_
.
AddArc
(
target
-
1
,
target
,
1
);
// favor left-to-right chain
}
this
->
SolveAndExpectArgmax
({
0
,
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
});
}
}
TYPED_TEST
(
MstSolverTest
,
SolveForRightToLeftChain
)
{
const
int
kNumNodes
=
10
;
for
(
const
bool
forest
:
{
false
,
true
})
{
TF_ASSERT_OK
(
this
->
solver_
.
Init
(
forest
,
kNumNodes
));
this
->
AddAllRoots
(
kNumNodes
,
0
);
this
->
AddAllArcs
(
kNumNodes
,
0
);
for
(
int
source
=
1
;
source
<
kNumNodes
;
++
source
)
{
this
->
solver_
.
AddArc
(
source
,
source
-
1
,
1
);
// favor right-to-left chain
}
this
->
SolveAndExpectArgmax
({
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
9
});
}
}
TYPED_TEST
(
MstSolverTest
,
SolveForAllFromFirstTree
)
{
const
int
kNumNodes
=
10
;
for
(
const
bool
forest
:
{
false
,
true
})
{
TF_ASSERT_OK
(
this
->
solver_
.
Init
(
forest
,
kNumNodes
));
this
->
AddAllRoots
(
kNumNodes
,
0
);
this
->
AddAllArcs
(
kNumNodes
,
0
);
for
(
int
target
=
1
;
target
<
kNumNodes
;
++
target
)
{
this
->
solver_
.
AddArc
(
0
,
target
,
1
);
// favor first -> target
}
this
->
SolveAndExpectArgmax
({
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
});
}
}
TYPED_TEST
(
MstSolverTest
,
SolveForAllFromLastTree
)
{
const
int
kNumNodes
=
10
;
for
(
const
bool
forest
:
{
false
,
true
})
{
TF_ASSERT_OK
(
this
->
solver_
.
Init
(
forest
,
kNumNodes
));
this
->
AddAllRoots
(
kNumNodes
,
0
);
this
->
AddAllArcs
(
kNumNodes
,
0
);
for
(
int
target
=
0
;
target
+
1
<
kNumNodes
;
++
target
)
{
this
->
solver_
.
AddArc
(
9
,
target
,
1
);
// favor last -> target
}
this
->
SolveAndExpectArgmax
({
9
,
9
,
9
,
9
,
9
,
9
,
9
,
9
,
9
,
9
});
}
}
TYPED_TEST
(
MstSolverTest
,
SolveForBinaryTree
)
{
const
int
kNumNodes
=
15
;
for
(
const
bool
forest
:
{
false
,
true
})
{
TF_ASSERT_OK
(
this
->
solver_
.
Init
(
forest
,
kNumNodes
));
this
->
AddAllRoots
(
kNumNodes
,
0
);
this
->
AddAllArcs
(
kNumNodes
,
0
);
for
(
int
target
=
1
;
target
<
kNumNodes
;
++
target
)
{
this
->
solver_
.
AddArc
((
target
-
1
)
/
2
,
target
,
1
);
// like a binary heap
}
this
->
SolveAndExpectArgmax
({
0
,
0
,
0
,
1
,
1
,
2
,
2
,
3
,
3
,
4
,
4
,
5
,
5
,
6
,
6
});
}
}
}
// namespace
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/mst/ops/mst_op_kernels.cc
0 → 100644
View file @
80178fc6
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <cmath>
#include <limits>
#include <type_traits>
#include <vector>
#include "dragnn/mst/mst_solver.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
namespace
syntaxnet
{
namespace
dragnn
{
// Op kernel implementation that wraps the |MstSolver|.
template
<
class
Index
,
class
Score
>
class
MaximumSpanningTreeOpKernel
:
public
tensorflow
::
OpKernel
{
public:
explicit
MaximumSpanningTreeOpKernel
(
tensorflow
::
OpKernelConstruction
*
context
)
:
tensorflow
::
OpKernel
(
context
)
{
OP_REQUIRES_OK
(
context
,
context
->
GetAttr
(
"forest"
,
&
forest_
));
}
void
Compute
(
tensorflow
::
OpKernelContext
*
context
)
override
{
const
tensorflow
::
Tensor
&
num_nodes_tensor
=
context
->
input
(
0
);
const
tensorflow
::
Tensor
&
scores_tensor
=
context
->
input
(
1
);
// Check ranks.
OP_REQUIRES
(
context
,
num_nodes_tensor
.
dims
()
==
1
,
tensorflow
::
errors
::
InvalidArgument
(
"num_nodes must be a vector, got shape "
,
num_nodes_tensor
.
shape
().
DebugString
()));
OP_REQUIRES
(
context
,
scores_tensor
.
dims
()
==
3
,
tensorflow
::
errors
::
InvalidArgument
(
"scores must be rank 3, got shape "
,
scores_tensor
.
shape
().
DebugString
()));
// Batch size and input dimension (B and M in the op docstring).
const
int64
batch_size
=
scores_tensor
.
shape
().
dim_size
(
0
);
const
int64
input_dim
=
scores_tensor
.
shape
().
dim_size
(
1
);
// Check shapes.
const
tensorflow
::
TensorShape
shape_b
({
batch_size
});
const
tensorflow
::
TensorShape
shape_bxm
({
batch_size
,
input_dim
});
const
tensorflow
::
TensorShape
shape_bxmxm
(
{
batch_size
,
input_dim
,
input_dim
});
OP_REQUIRES
(
context
,
num_nodes_tensor
.
shape
()
==
shape_b
,
tensorflow
::
errors
::
InvalidArgument
(
"num_nodes misshapen: got "
,
num_nodes_tensor
.
shape
().
DebugString
(),
" but expected "
,
shape_b
.
DebugString
()));
OP_REQUIRES
(
context
,
scores_tensor
.
shape
()
==
shape_bxmxm
,
tensorflow
::
errors
::
InvalidArgument
(
"scores misshapen: got "
,
scores_tensor
.
shape
().
DebugString
(),
" but expected "
,
shape_bxmxm
.
DebugString
()));
// Create outputs.
tensorflow
::
Tensor
*
max_scores_tensor
=
nullptr
;
tensorflow
::
Tensor
*
argmax_sources_tensor
=
nullptr
;
OP_REQUIRES_OK
(
context
,
context
->
allocate_output
(
0
,
shape_b
,
&
max_scores_tensor
));
OP_REQUIRES_OK
(
context
,
context
->
allocate_output
(
1
,
shape_bxm
,
&
argmax_sources_tensor
));
// Acquire shaped and typed references.
const
BatchedSizes
num_nodes_b
=
num_nodes_tensor
.
vec
<
int32
>
();
const
BatchedScores
scores_bxmxm
=
scores_tensor
.
tensor
<
Score
,
3
>
();
BatchedMaxima
max_scores_b
=
max_scores_tensor
->
vec
<
Score
>
();
BatchedSources
argmax_sources_bxm
=
argmax_sources_tensor
->
matrix
<
int32
>
();
// Solve the batch of MST problems in parallel. Set a high cycles per unit
// to encourage finer sharding.
constexpr
int64
kCyclesPerUnit
=
1000
*
1000
*
1000
;
std
::
vector
<
tensorflow
::
Status
>
statuses
(
batch_size
);
context
->
device
()
->
tensorflow_cpu_worker_threads
()
->
workers
->
ParallelFor
(
batch_size
,
kCyclesPerUnit
,
[
&
](
int64
begin
,
int64
end
)
{
for
(
int64
problem
=
begin
;
problem
<
end
;
++
problem
)
{
statuses
[
problem
]
=
RunSolver
(
problem
,
num_nodes_b
,
scores_bxmxm
,
max_scores_b
,
argmax_sources_bxm
);
}
});
for
(
const
tensorflow
::
Status
&
status
:
statuses
)
{
OP_REQUIRES_OK
(
context
,
status
);
}
}
private:
using
BatchedSizes
=
typename
tensorflow
::
TTypes
<
int32
>::
ConstVec
;
using
BatchedScores
=
typename
tensorflow
::
TTypes
<
Score
,
3
>::
ConstTensor
;
using
BatchedMaxima
=
typename
tensorflow
::
TTypes
<
Score
>::
Vec
;
using
BatchedSources
=
typename
tensorflow
::
TTypes
<
int32
>::
Matrix
;
// Solves for the maximum spanning tree of the digraph defined by the values
// at index |problem| in |num_nodes_b| and |scores_bxmxm|. On success, sets
// the values at index |problem| in |max_scores_b| and |argmax_sources_bxm|.
// On error, returns non-OK.
tensorflow
::
Status
RunSolver
(
int
problem
,
BatchedSizes
num_nodes_b
,
BatchedScores
scores_bxmxm
,
BatchedMaxima
max_scores_b
,
BatchedSources
argmax_sources_bxm
)
const
{
// Check digraph size overflow.
const
int32
num_nodes
=
num_nodes_b
(
problem
);
const
int32
input_dim
=
argmax_sources_bxm
.
dimension
(
1
);
if
(
num_nodes
>
input_dim
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"number of nodes in digraph "
,
problem
,
" overflows input dimension: got "
,
num_nodes
,
" but expected <= "
,
input_dim
);
}
if
(
num_nodes
>=
std
::
numeric_limits
<
Index
>::
max
())
{
return
tensorflow
::
errors
::
InvalidArgument
(
"number of nodes in digraph "
,
problem
,
" overflows index type: got "
,
num_nodes
,
" but expected < "
,
std
::
numeric_limits
<
Index
>::
max
());
}
const
Index
num_nodes_index
=
static_cast
<
Index
>
(
num_nodes
);
MstSolver
<
Index
,
Score
>
solver
;
TF_RETURN_IF_ERROR
(
solver
.
Init
(
forest_
,
num_nodes_index
));
// Populate the solver with arcs and root selections. Note that non-finite
// scores are treated as nonexistent arcs or roots.
for
(
Index
target
=
0
;
target
<
num_nodes_index
;
++
target
)
{
for
(
Index
source
=
0
;
source
<
num_nodes_index
;
++
source
)
{
const
Score
score
=
scores_bxmxm
(
problem
,
target
,
source
);
if
(
!
std
::
isfinite
(
score
))
continue
;
if
(
source
==
target
)
{
// root
solver
.
AddRoot
(
target
,
score
);
}
else
{
// arc
solver
.
AddArc
(
source
,
target
,
score
);
}
}
}
std
::
vector
<
Index
>
argmax
(
num_nodes
);
TF_RETURN_IF_ERROR
(
solver
.
Solve
(
&
argmax
));
// Output the tree and accumulate its score.
Score
max_score
=
0
;
for
(
Index
target
=
0
;
target
<
num_nodes_index
;
++
target
)
{
const
Index
source
=
argmax
[
target
];
argmax_sources_bxm
(
problem
,
target
)
=
source
;
max_score
+=
scores_bxmxm
(
problem
,
target
,
source
);
}
max_scores_b
(
problem
)
=
max_score
;
// Pad the source list with -1.
for
(
int32
i
=
num_nodes
;
i
<
input_dim
;
++
i
)
{
argmax_sources_bxm
(
problem
,
i
)
=
-
1
;
}
return
tensorflow
::
Status
::
OK
();
}
private:
bool
forest_
=
false
;
};
// Use Index=uint16, which allows digraphs containing up to 32,767 nodes.
REGISTER_KERNEL_BUILDER
(
Name
(
"MaximumSpanningTree"
)
.
Device
(
tensorflow
::
DEVICE_CPU
)
.
TypeConstraint
<
int32
>
(
"T"
),
MaximumSpanningTreeOpKernel
<
uint16
,
int32
>
);
REGISTER_KERNEL_BUILDER
(
Name
(
"MaximumSpanningTree"
)
.
Device
(
tensorflow
::
DEVICE_CPU
)
.
TypeConstraint
<
float
>
(
"T"
),
MaximumSpanningTreeOpKernel
<
uint16
,
float
>
);
REGISTER_KERNEL_BUILDER
(
Name
(
"MaximumSpanningTree"
)
.
Device
(
tensorflow
::
DEVICE_CPU
)
.
TypeConstraint
<
double
>
(
"T"
),
MaximumSpanningTreeOpKernel
<
uint16
,
double
>
);
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/mst/ops/mst_ops.cc
0 → 100644
View file @
80178fc6
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace
syntaxnet
{
namespace
dragnn
{
REGISTER_OP
(
"MaximumSpanningTree"
)
.
Attr
(
"T: {int32, float, double}"
)
.
Attr
(
"forest: bool = false"
)
.
Input
(
"num_nodes: int32"
)
.
Input
(
"scores: T"
)
.
Output
(
"max_scores: T"
)
.
Output
(
"argmax_sources: int32"
)
.
SetShapeFn
([](
tensorflow
::
shape_inference
::
InferenceContext
*
context
)
{
tensorflow
::
shape_inference
::
ShapeHandle
num_nodes
;
tensorflow
::
shape_inference
::
ShapeHandle
scores
;
TF_RETURN_IF_ERROR
(
context
->
WithRank
(
context
->
input
(
0
),
1
,
&
num_nodes
));
TF_RETURN_IF_ERROR
(
context
->
WithRank
(
context
->
input
(
1
),
3
,
&
scores
));
// Extract dimensions while asserting that they match.
tensorflow
::
shape_inference
::
DimensionHandle
batch_size
;
// aka "B"
TF_RETURN_IF_ERROR
(
context
->
Merge
(
context
->
Dim
(
num_nodes
,
0
),
context
->
Dim
(
scores
,
0
),
&
batch_size
));
tensorflow
::
shape_inference
::
DimensionHandle
max_nodes
;
// aka "M"
TF_RETURN_IF_ERROR
(
context
->
Merge
(
context
->
Dim
(
scores
,
1
),
context
->
Dim
(
scores
,
2
),
&
max_nodes
));
context
->
set_output
(
0
,
context
->
Vector
(
batch_size
));
context
->
set_output
(
1
,
context
->
Matrix
(
batch_size
,
max_nodes
));
return
tensorflow
::
Status
::
OK
();
})
.
Doc
(
R"doc(
Finds the maximum directed spanning tree of a digraph.
Given a batch of digraphs with scored arcs and root selections, solves for the
maximum spanning tree of each digraph, where the score of a tree is defined as
the sum of the scores of the arcs and roots making up the tree.
Returns the score of the maximum spanning tree of each digraph, as well as the
arcs and roots in that tree. Each digraph in a batch may contain a different
number of nodes, so the sizes of the digraphs must be provided as an input.
Note that this operation is only differentiable w.r.t. its |scores| input and
its |max_scores| output.
forest: If true, solves for a maximum spanning forest instead of a maximum
spanning tree, where a spanning forest is a set of disjoint trees that
span the nodes of the digraph.
num_nodes: [B] vector where entry b is number of nodes in the b'th digraph.
scores: [B,M,M] tensor where entry b,t,s is the score of the arc from s to t in
the b'th digraph, if s!=t, or the score of selecting t as a root in the
b'th digraph, if s==t. Requires that M is >= num_nodes[b], for all b,
and ignores entries b,s,t where s or t is >= num_nodes[b]. Arcs or root
selections with non-finite score are treated as nonexistent.
max_scores: [B] vector where entry b is the score of the maximum spanning tree
of the b'th digraph.
argmax_sources: [B,M] matrix where entry b,t is the source of the arc inbound to
t in the maximum spanning tree of the b'th digraph, or t if t is
a root. Entries b,t where t is >= num_nodes[b] are set to -1.
)doc"
);
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/mst/spanning_tree_iterator.cc
0 → 100644
View file @
80178fc6
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/mst/spanning_tree_iterator.h"
namespace
syntaxnet
{
namespace
dragnn
{
SpanningTreeIterator
::
SpanningTreeIterator
(
bool
forest
)
:
forest_
(
forest
)
{}
bool
SpanningTreeIterator
::
HasCycle
(
const
SourceList
&
sources
)
{
// Flags for whether each node has already been searched.
searched_
.
assign
(
sources
.
size
(),
false
);
// Flags for whether the search is currently visiting each node.
visiting_
.
assign
(
sources
.
size
(),
false
);
// Search upwards from each node to find cycles.
for
(
uint32
initial_node
=
0
;
initial_node
<
sources
.
size
();
++
initial_node
)
{
// Search upwards to try to find a cycle.
uint32
current_node
=
initial_node
;
while
(
true
)
{
if
(
searched_
[
current_node
])
break
;
// already searched
if
(
visiting_
[
current_node
])
return
true
;
// revisiting implies cycle
visiting_
[
current_node
]
=
true
;
// mark as being currently visited
const
uint32
source_node
=
sources
[
current_node
];
if
(
source_node
==
current_node
)
break
;
// self-loops are roots
current_node
=
source_node
;
// advance upwards
}
// No cycle; search upwards again to update flags.
current_node
=
initial_node
;
while
(
true
)
{
if
(
searched_
[
current_node
])
break
;
// already searched
searched_
[
current_node
]
=
true
;
visiting_
[
current_node
]
=
false
;
const
uint32
source_node
=
sources
[
current_node
];
if
(
source_node
==
current_node
)
break
;
// self-loops are roots
current_node
=
source_node
;
// advance upwards
}
}
return
false
;
}
uint32
SpanningTreeIterator
::
NumRoots
(
const
SourceList
&
sources
)
{
uint32
num_roots
=
0
;
for
(
uint32
node
=
0
;
node
<
sources
.
size
();
++
node
)
{
num_roots
+=
(
node
==
sources
[
node
]);
}
return
num_roots
;
}
bool
SpanningTreeIterator
::
NextSourceList
(
SourceList
*
sources
)
{
const
uint32
num_nodes
=
sources
->
size
();
for
(
uint32
i
=
0
;
i
<
num_nodes
;
++
i
)
{
const
uint32
new_source
=
++
(
*
sources
)[
i
];
if
(
new_source
<
num_nodes
)
return
true
;
// absorbed in this digit
(
*
sources
)[
i
]
=
0
;
// overflowed this digit, carry to next digit
}
return
false
;
// overflowed the last digit
}
bool
SpanningTreeIterator
::
NextTree
(
SourceList
*
sources
)
{
// Iterate source lists, skipping non-trees.
while
(
NextSourceList
(
sources
))
{
// Check the number of roots.
const
uint32
num_roots
=
NumRoots
(
*
sources
);
if
(
forest_
)
{
if
(
num_roots
==
0
)
continue
;
}
else
{
if
(
num_roots
!=
1
)
continue
;
}
// Check for cycles.
if
(
HasCycle
(
*
sources
))
continue
;
// Acyclic and rooted, therefore tree.
return
true
;
}
return
false
;
}
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/mst/spanning_tree_iterator.h
0 → 100644
View file @
80178fc6
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_MST_SPANNING_TREE_ITERATOR_H_
#define DRAGNN_MST_SPANNING_TREE_ITERATOR_H_
#include <vector>
#include "syntaxnet/base.h"
namespace
syntaxnet
{
namespace
dragnn
{
// A class that iterates over all possible spanning trees of a complete digraph.
// Thread-compatible. Useful for brute-force comparison tests.
//
// TODO(googleuser): Try using Prufer sequences, which are more efficient to
// enumerate as there are no non-trees to filter out.
class
SpanningTreeIterator
{
public:
// An array that provides the source of the inbound arc for each node. Roots
// are represented as self-loops.
using
SourceList
=
std
::
vector
<
uint32
>
;
// Creates a spanning tree iterator. If |forest| is true, then this iterates
// over forests instead of trees (i.e., multiple roots are allowed).
explicit
SpanningTreeIterator
(
bool
forest
);
// Applies the |functor| to all spanning trees (or forests, if |forest_| is
// true) of a complete digraph containing |num_nodes| nodes. Each tree is
// passed to the |functor| as a SourceList.
template
<
class
Functor
>
void
ForEachTree
(
uint32
num_nodes
,
Functor
functor
)
{
// Conveniently, the all-zero vector represents a valid tree.
SourceList
sources
(
num_nodes
,
0
);
do
{
functor
(
sources
);
}
while
(
NextTree
(
&
sources
));
}
private:
// Returns true if the |sources| contains a cycle.
bool
HasCycle
(
const
SourceList
&
sources
);
// Returns the number of roots in the |sources|.
static
uint32
NumRoots
(
const
SourceList
&
sources
);
// Advances |sources| to the next source list, or returns false if there are
// no more source lists.
static
bool
NextSourceList
(
SourceList
*
sources
);
// Advances |sources| to the next tree (or forest, if |forest_| is true), or
// returns false if there are no more trees.
bool
NextTree
(
SourceList
*
sources
);
// If true, iterate over spanning forests instead of spanning trees.
const
bool
forest_
;
// Workspaces used by the search in HasCycle().
std
::
vector
<
bool
>
searched_
;
std
::
vector
<
bool
>
visiting_
;
};
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_MST_SPANNING_TREE_ITERATOR_H_
research/syntaxnet/dragnn/mst/spanning_tree_iterator_test.cc
0 → 100644
View file @
80178fc6
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/mst/spanning_tree_iterator.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
{
// Testing rig. When the bool parameter is true, iterates over spanning forests
// instead of spanning trees.
class
SpanningTreeIteratorTest
:
public
::
testing
::
TestWithParam
<
bool
>
{
protected:
using
SourceList
=
SpanningTreeIterator
::
SourceList
;
// Returns |base|^|exponent|. Computes the value as an integer to avoid
// rounding issues.
static
int
Pow
(
int
base
,
int
exponent
)
{
double
real_product
=
1.0
;
int
product
=
1
;
for
(
int
i
=
0
;
i
<
exponent
;
++
i
)
{
product
*=
base
;
real_product
*=
base
;
}
CHECK_EQ
(
product
,
real_product
)
<<
"Overflow detected."
;
return
product
;
}
// Expects that the number of possible spanning trees for a complete digraph
// of |num_nodes| nodes is |expected_num_trees|.
void
ExpectNumTrees
(
int
num_nodes
,
int
expected_num_trees
)
{
int
actual_num_trees
=
0
;
iterator_
.
ForEachTree
(
num_nodes
,
[
&
](
const
SourceList
&
sources
)
{
++
actual_num_trees
;
});
LOG
(
INFO
)
<<
"num_nodes="
<<
num_nodes
<<
" expected_num_trees="
<<
expected_num_trees
<<
" actual_num_trees="
<<
actual_num_trees
;
EXPECT_EQ
(
expected_num_trees
,
actual_num_trees
);
}
// Expects that the set of possible spanning trees for a complete digraph of
// |num_nodes| nodes is |expected_trees|.
void
ExpectTrees
(
int
num_nodes
,
const
std
::
set
<
SourceList
>
&
expected_trees
)
{
std
::
set
<
SourceList
>
actual_trees
;
iterator_
.
ForEachTree
(
num_nodes
,
[
&
](
const
SourceList
&
sources
)
{
CHECK
(
actual_trees
.
insert
(
sources
).
second
);
});
EXPECT_EQ
(
expected_trees
,
actual_trees
);
}
// Instance for tests. Shared across assertions in a test to exercise reuse.
SpanningTreeIterator
iterator_
{
GetParam
()};
};
INSTANTIATE_TEST_CASE_P
(
AllowForest
,
SpanningTreeIteratorTest
,
::
testing
::
Bool
());
TEST_P
(
SpanningTreeIteratorTest
,
NumberOfTrees
)
{
// According to Cayley's formula, the number of undirected spanning trees on a
// complete graph of n nodes is n^{n-2}:
// https://en.wikipedia.org/wiki/Cayley%27s_formula
//
// To count the number of directed spanning trees, note that each undirected
// spanning tree gives rise to n directed spanning trees: choose one of the n
// nodes as the root, and then orient arcs outwards. Therefore, the number of
// directed spanning trees on a complete digraph of n nodes is n^{n-1}.
//
// To count the number of directed spanning forests, consider undirected
// spanning trees on a complete graph of n+1 nodes. Arbitrarily select one
// node as the artificial root, orient arcs outwards, and then delete the
// artificial root and its outbound arcs. The result is a directed spanning
// forest on n nodes. Therefore, the number of directed spanning forests on a
// complete digraph of n nodes is (n+1)^{n-1}.
for
(
int
num_nodes
=
1
;
num_nodes
<=
7
;
++
num_nodes
)
{
if
(
GetParam
())
{
// forest
ExpectNumTrees
(
num_nodes
,
Pow
(
num_nodes
+
1
,
num_nodes
-
1
));
}
else
{
// tree
ExpectNumTrees
(
num_nodes
,
Pow
(
num_nodes
,
num_nodes
-
1
));
}
}
}
TEST_P
(
SpanningTreeIteratorTest
,
OneNodeDigraph
)
{
ExpectTrees
(
1
,
{{
0
}});
}
TEST_P
(
SpanningTreeIteratorTest
,
TwoNodeDigraph
)
{
if
(
GetParam
())
{
// forest
ExpectTrees
(
2
,
{{
0
,
0
},
{
0
,
1
},
{
1
,
1
}});
// {0, 1} is two-root structure
}
else
{
// tree
ExpectTrees
(
2
,
{{
0
,
0
},
{
1
,
1
}});
}
}
TEST_P
(
SpanningTreeIteratorTest
,
ThreeNodeDigraph
)
{
if
(
GetParam
())
{
// forest
ExpectTrees
(
3
,
{{
0
,
0
,
0
},
{
0
,
0
,
1
},
{
0
,
0
,
2
},
// 2-root
{
0
,
1
,
0
},
// 2-root
{
0
,
1
,
1
},
// 2-root
{
0
,
1
,
2
},
// 3-root
{
0
,
2
,
0
},
{
0
,
2
,
2
},
// 2-root
{
1
,
1
,
0
},
{
1
,
1
,
1
},
{
1
,
1
,
2
},
// 2-root
{
1
,
2
,
2
},
{
2
,
0
,
2
},
{
2
,
1
,
1
},
{
2
,
1
,
2
},
// 2-root
{
2
,
2
,
2
}});
}
else
{
// tree
ExpectTrees
(
3
,
{{
0
,
0
,
0
},
{
0
,
0
,
1
},
{
0
,
2
,
0
},
{
1
,
1
,
0
},
{
1
,
1
,
1
},
{
1
,
2
,
2
},
{
2
,
0
,
2
},
{
2
,
1
,
1
},
{
2
,
2
,
2
}});
}
}
}
// namespace
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/protos/BUILD
View file @
80178fc6
...
...
@@ -2,48 +2,63 @@ package(default_visibility = ["//visibility:public"])
load
(
"//syntaxnet:syntaxnet.bzl"
,
"tf_proto_library"
,
"tf_proto_library
_cc
"
,
"tf_proto_library_py"
,
)
# Protos.
tf_proto_library
(
tf_proto_library
_cc
(
name
=
"data_proto"
,
srcs
=
[
"data.proto"
],
)
tf_proto_library
(
tf_proto_library
_cc
(
name
=
"trace_proto"
,
srcs
=
[
"trace.proto"
],
deps
=
[
":data_proto"
,
],
protodeps
=
[
":data_proto"
],
)
tf_proto_library
(
tf_proto_library_cc
(
name
=
"cell_trace_proto"
,
srcs
=
[
"cell_trace.proto"
],
protodeps
=
[
":trace_proto"
],
)
tf_proto_library_cc
(
name
=
"spec_proto"
,
srcs
=
[
"spec.proto"
],
)
tf_proto_library
(
tf_proto_library
_cc
(
name
=
"runtime_proto"
,
srcs
=
[
"runtime.proto"
],
deps
=
[
":spec_proto"
],
protodeps
=
[
":spec_proto"
],
)
tf_proto_library_cc
(
name
=
"export_proto"
,
srcs
=
[
"export.proto"
],
protodeps
=
[
":spec_proto"
],
)
tf_proto_library_py
(
name
=
"data_
py_
pb2"
,
name
=
"data_pb2"
,
srcs
=
[
"data.proto"
],
)
tf_proto_library_py
(
name
=
"trace_
py_
pb2"
,
name
=
"trace_pb2"
,
srcs
=
[
"trace.proto"
],
deps
=
[
":data_
py_
pb2"
],
proto
deps
=
[
":data_pb2"
],
)
tf_proto_library_py
(
name
=
"spec_
py_
pb2"
,
name
=
"spec_pb2"
,
srcs
=
[
"spec.proto"
],
)
tf_proto_library_py
(
name
=
"export_pb2"
,
srcs
=
[
"export.proto"
],
)
research/syntaxnet/dragnn/protos/cell_trace.proto
0 → 100644
View file @
80178fc6
syntax
=
"proto2"
;
import
"dragnn/protos/trace.proto"
;
package
syntaxnet
.
dragnn.runtime
;
// Trace of a network cell computation (e.g., an LSTM cell).
// NEXT ID: 4
message
CellTrace
{
extend
ComponentStepTrace
{
// Cell computations that occurred in the step. It's possible that there is
// more than one cell per component (e.g., a bi-LSTM component).
repeated
CellTrace
step_trace_extension
=
169167178
;
}
// Name of the cell.
optional
string
name
=
1
;
// Tensors making up the cell. Note that this only includes local variables
// (e.g., activation vectors), not global constants (e.g., weight matrices).
repeated
CellTensorTrace
tensor
=
2
;
// Operations making up the cell. Note that the operation inputs may refer to
// global constants that are not present in |tensor|.
repeated
CellOperationTrace
operation
=
3
;
}
// Trace of a tensor in a cell computation.
// NEXT ID: 7
message
CellTensorTrace
{
// Possible orderings of the dimensions.
enum
Order
{
ORDER_UNKNOWN
=
0
;
// unspecified or unknown
ORDER_ROW_MAJOR
=
1
;
// row-major: dimension 0 has largest stride
ORDER_COLUMN_MAJOR
=
2
;
// column-major: dimension 0 has smallest stride
}
// Name of the tensor (e.g., "annotation/inference_rnn/split:1").
optional
string
name
=
1
;
// Data type of the tensor (e.g., "DT_FLOAT").
optional
string
type
=
2
;
// Dimensions of the tensor (e.g., [1, 65]).
repeated
int32
dimension
=
3
;
// Alignment-padded dimensions of the tensor (e.g., [1, 96]).
repeated
int32
aligned_dimension
=
4
;
// Ordering of the tensor values.
optional
Order
order
=
5
[
default
=
ORDER_UNKNOWN
];
// Block of alignment-padded values. For simplicity, values of all types are
// converted to double (via C++ conversion rules). Use |aligned_dimension| to
// traverse the values, but note that |dimension| bounds the valid region.
repeated
double
value
=
6
;
}
// Trace of an operation in a cell computation.
// NEXT ID: 6
message
CellOperationTrace
{
// Name of the operation (e.g., "annotation/inference_rnn/MatMul").
optional
string
name
=
1
;
// High-level type of the operation (e.g., "MatMul").
optional
string
type
=
2
;
// Kernel that implements the operation, if applicable (e.g., "AvxFltMatMul").
optional
string
kernel
=
3
;
// Names of input tensors of the operation, in order.
repeated
string
input
=
4
;
// Names of output tensors of the operation, in order.
repeated
string
output
=
5
;
}
research/syntaxnet/dragnn/protos/data.proto
View file @
80178fc6
// DRAGNN data proto. See go/dragnn-design for more information.
// DRAGNN data proto.
syntax
=
"proto2"
;
...
...
research/syntaxnet/dragnn/protos/export.proto
0 → 100644
View file @
80178fc6
syntax
=
"proto2"
;
import
"dragnn/protos/spec.proto"
;
package
syntaxnet
.
dragnn.runtime
;
// Specification of a subgraph of TF nodes that make up a network cell.
//
// Roughly speaking, a "cell" consists of the "pure math" parts of a DRAGNN
// component, and is intended to be exported to a NN compiler. The set of
// operations that make up a cell may change over time, but currently the
// boundaries of a cell are:
//
// Inputs:
// * Fixed feature IDs.
// * Linked feature embeddings, before pass_through_embedding_matrix().
// * Recurrent context tensors.
//
// Outputs:
// * Network unit layers.
message
CellSubgraphSpec
{
// An input to the subgraph.
message
Input
{
// Possible types of input.
enum
Type
{
TYPE_UNKNOWN
=
0
;
// An input derived from a fixed or linked feature.
TYPE_FEATURE
=
1
;
// An input that refers to an output of the previous iteration of the
// transition loop. The input must have the same name as the output to
// which it refers. On the first iteration, its value is zero.
//
// This is used by, e.g., LSTMNetwork, which reads its cell state from the
// context_tensor_arrays instead of from a linked feature.
TYPE_RECURRENT
=
2
;
}
// Logical name of the input (e.g., "lstm_c", "linked_feature_0"). Must be
// unique among the inputs of the cell.
optional
string
name
=
1
;
// Tensor containing the input (e.g., "annotation/rnn/split:1"). Must be
// unique among the inputs of the cell.
optional
string
tensor
=
2
;
// Type of input.
optional
Type
type
=
3
[
default
=
TYPE_UNKNOWN
];
}
// An output of the subgraph.
message
Output
{
// Logical name of the output (e.g., "lstm_c", "layer_0"). Must be unique
// among the outputs of the cell.
optional
string
name
=
1
;
// Tensor containing the output (e.g., "annotation/rnn/split:1"). Need not
// be unique; duplicate outputs for the same tensor are treated as aliases.
optional
string
tensor
=
2
;
}
// Inputs of the subgraph.
repeated
Input
input
=
1
;
// Outputs of the subgraph.
repeated
Output
output
=
2
;
}
// Additional information to compile a component.
//
// NEXT ID: 3
message
CompilationSpec
{
extend
ComponentSpec
{
optional
CompilationSpec
component_spec_extension
=
174770970
;
}
// A unique name of the entire DRAGNN model where this component is used.
optional
string
model_name
=
1
;
// The subgraph specification for this component.
optional
CellSubgraphSpec
cell_subgraph_spec
=
2
;
}
Prev
1
2
3
4
5
6
7
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment