Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
80178fc6
Unverified
Commit
80178fc6
authored
May 11, 2018
by
Mark Omernick
Committed by
GitHub
May 11, 2018
Browse files
Merge pull request #4153 from terryykoo/master
Export @195097388.
parents
a84e1ef9
edea2b67
Changes
145
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2316 additions
and
19 deletions
+2316
-19
research/syntaxnet/dragnn/core/test/mock_compute_session.h
research/syntaxnet/dragnn/core/test/mock_compute_session.h
+3
-2
research/syntaxnet/dragnn/core/util/BUILD
research/syntaxnet/dragnn/core/util/BUILD
+9
-0
research/syntaxnet/dragnn/core/util/label.h
research/syntaxnet/dragnn/core/util/label.h
+45
-0
research/syntaxnet/dragnn/io/BUILD
research/syntaxnet/dragnn/io/BUILD
+3
-3
research/syntaxnet/dragnn/mst/BUILD
research/syntaxnet/dragnn/mst/BUILD
+116
-0
research/syntaxnet/dragnn/mst/README.md
research/syntaxnet/dragnn/mst/README.md
+3
-0
research/syntaxnet/dragnn/mst/disjoint_set_forest.h
research/syntaxnet/dragnn/mst/disjoint_set_forest.h
+183
-0
research/syntaxnet/dragnn/mst/disjoint_set_forest_test.cc
research/syntaxnet/dragnn/mst/disjoint_set_forest_test.cc
+150
-0
research/syntaxnet/dragnn/mst/mst_solver.h
research/syntaxnet/dragnn/mst/mst_solver.h
+587
-0
research/syntaxnet/dragnn/mst/mst_solver_random_comparison_test.cc
...syntaxnet/dragnn/mst/mst_solver_random_comparison_test.cc
+183
-0
research/syntaxnet/dragnn/mst/mst_solver_test.cc
research/syntaxnet/dragnn/mst/mst_solver_test.cc
+255
-0
research/syntaxnet/dragnn/mst/ops/mst_op_kernels.cc
research/syntaxnet/dragnn/mst/ops/mst_op_kernels.cc
+193
-0
research/syntaxnet/dragnn/mst/ops/mst_ops.cc
research/syntaxnet/dragnn/mst/ops/mst_ops.cc
+78
-0
research/syntaxnet/dragnn/mst/spanning_tree_iterator.cc
research/syntaxnet/dragnn/mst/spanning_tree_iterator.cc
+97
-0
research/syntaxnet/dragnn/mst/spanning_tree_iterator.h
research/syntaxnet/dragnn/mst/spanning_tree_iterator.h
+79
-0
research/syntaxnet/dragnn/mst/spanning_tree_iterator_test.cc
research/syntaxnet/dragnn/mst/spanning_tree_iterator_test.cc
+143
-0
research/syntaxnet/dragnn/protos/BUILD
research/syntaxnet/dragnn/protos/BUILD
+28
-13
research/syntaxnet/dragnn/protos/cell_trace.proto
research/syntaxnet/dragnn/protos/cell_trace.proto
+76
-0
research/syntaxnet/dragnn/protos/data.proto
research/syntaxnet/dragnn/protos/data.proto
+2
-1
research/syntaxnet/dragnn/protos/export.proto
research/syntaxnet/dragnn/protos/export.proto
+83
-0
No files found.
research/syntaxnet/dragnn/core/test/mock_compute_session.h
View file @
80178fc6
...
@@ -62,13 +62,14 @@ class MockComputeSession : public ComputeSession {
...
@@ -62,13 +62,14 @@ class MockComputeSession : public ComputeSession {
MOCK_METHOD2
(
GetTranslatedLinkFeatures
,
MOCK_METHOD2
(
GetTranslatedLinkFeatures
,
std
::
vector
<
LinkFeatures
>
(
const
string
&
component_name
,
std
::
vector
<
LinkFeatures
>
(
const
string
&
component_name
,
int
channel_id
));
int
channel_id
));
MOCK_METHOD1
(
EmitOracleLabels
,
MOCK_METHOD1
(
EmitOracleLabels
,
std
::
vector
<
std
::
vector
<
std
::
vector
<
Label
>>>
(
std
::
vector
<
std
::
vector
<
int
>>
(
const
string
&
component_name
));
const
string
&
component_name
));
MOCK_METHOD1
(
IsTerminal
,
bool
(
const
string
&
component_name
));
MOCK_METHOD1
(
IsTerminal
,
bool
(
const
string
&
component_name
));
MOCK_METHOD1
(
FinalizeData
,
void
(
const
string
&
component_name
));
MOCK_METHOD1
(
FinalizeData
,
void
(
const
string
&
component_name
));
MOCK_METHOD0
(
GetSerializedPredictions
,
std
::
vector
<
string
>
());
MOCK_METHOD0
(
GetSerializedPredictions
,
std
::
vector
<
string
>
());
MOCK_METHOD0
(
GetTraceProtos
,
std
::
vector
<
MasterTrace
>
());
MOCK_METHOD0
(
GetTraceProtos
,
std
::
vector
<
MasterTrace
>
());
MOCK_METHOD1
(
SetInputData
,
void
(
const
std
::
vector
<
string
>
&
data
));
MOCK_METHOD1
(
SetInputData
,
void
(
const
std
::
vector
<
string
>
&
data
));
MOCK_METHOD0
(
GetInputBatchCache
,
InputBatchCache
*
());
MOCK_METHOD0
(
ResetSession
,
void
());
MOCK_METHOD0
(
ResetSession
,
void
());
MOCK_METHOD1
(
SetTracing
,
void
(
bool
tracing_on
));
MOCK_METHOD1
(
SetTracing
,
void
(
bool
tracing_on
));
MOCK_CONST_METHOD0
(
Id
,
int
());
MOCK_CONST_METHOD0
(
Id
,
int
());
...
...
research/syntaxnet/dragnn/core/util/BUILD
0 → 100644
View file @
80178fc6
package
(
default_visibility
=
[
"//visibility:public"
],
features
=
[
"-layering_check"
],
)
cc_library
(
name
=
"label"
,
hdrs
=
[
"label.h"
],
)
research/syntaxnet/dragnn/core/util/label.h
0 → 100644
View file @
80178fc6
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_CORE_UTIL_LABEL_H_
#define DRAGNN_CORE_UTIL_LABEL_H_
#include <cmath>
namespace
syntaxnet
{
namespace
dragnn
{
// Stores label information.
struct
Label
{
Label
(
int
label_id
,
float
label_probability
)
:
id
(
label_id
),
probability
(
label_probability
)
{}
explicit
Label
(
int
label_id
)
:
id
(
label_id
)
{}
// Two Labels are equal if the ids match and the probabilities are within an
// epsilon of one another.
bool
operator
==
(
const
Label
&
label
)
const
{
return
(
id
==
label
.
id
)
&&
std
::
fabs
(
probability
-
label
.
probability
)
<
0.00001
;
}
// Label id and probability.
int
id
;
float
probability
=
1.0
;
};
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_CORE_UTIL_LABEL_H_
research/syntaxnet/dragnn/io/BUILD
View file @
80178fc6
...
@@ -8,7 +8,7 @@ cc_library(
...
@@ -8,7 +8,7 @@ cc_library(
":syntaxnet_sentence"
,
":syntaxnet_sentence"
,
"//dragnn/core/interfaces:input_batch"
,
"//dragnn/core/interfaces:input_batch"
,
"//syntaxnet:base"
,
"//syntaxnet:base"
,
"//syntaxnet:sentence_proto"
,
"//syntaxnet:sentence_proto
_cc
"
,
],
],
)
)
...
@@ -16,7 +16,7 @@ cc_library(
...
@@ -16,7 +16,7 @@ cc_library(
name
=
"syntaxnet_sentence"
,
name
=
"syntaxnet_sentence"
,
hdrs
=
[
"syntaxnet_sentence.h"
],
hdrs
=
[
"syntaxnet_sentence.h"
],
deps
=
[
deps
=
[
"//syntaxnet:sentence_proto"
,
"//syntaxnet:sentence_proto
_cc
"
,
"//syntaxnet:workspace"
,
"//syntaxnet:workspace"
,
],
],
)
)
...
@@ -27,7 +27,7 @@ cc_test(
...
@@ -27,7 +27,7 @@ cc_test(
deps
=
[
deps
=
[
":sentence_input_batch"
,
":sentence_input_batch"
,
"//dragnn/core/test:generic"
,
"//dragnn/core/test:generic"
,
"//syntaxnet:sentence_proto"
,
"//syntaxnet:sentence_proto
_cc
"
,
"//syntaxnet:test_main"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:test"
,
"@org_tensorflow//tensorflow/core:test"
,
],
],
...
...
research/syntaxnet/dragnn/mst/BUILD
0 → 100644
View file @
80178fc6
package
(
default_visibility
=
[
"//visibility:public"
])
cc_library
(
name
=
"disjoint_set_forest"
,
hdrs
=
[
"disjoint_set_forest.h"
],
deps
=
[
"@org_tensorflow//tensorflow/core:lib"
,
],
)
cc_test
(
name
=
"disjoint_set_forest_test"
,
size
=
"small"
,
srcs
=
[
"disjoint_set_forest_test.cc"
],
deps
=
[
":disjoint_set_forest"
,
"//syntaxnet:base"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"spanning_tree_iterator"
,
testonly
=
1
,
srcs
=
[
"spanning_tree_iterator.cc"
],
hdrs
=
[
"spanning_tree_iterator.h"
],
deps
=
[
"//syntaxnet:base"
,
],
)
cc_test
(
name
=
"spanning_tree_iterator_test"
,
size
=
"small"
,
srcs
=
[
"spanning_tree_iterator_test.cc"
],
deps
=
[
":spanning_tree_iterator"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_library
(
name
=
"mst_solver"
,
hdrs
=
[
"mst_solver.h"
],
deps
=
[
":disjoint_set_forest"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
)
cc_test
(
name
=
"mst_solver_test"
,
size
=
"small"
,
srcs
=
[
"mst_solver_test.cc"
],
deps
=
[
":mst_solver"
,
"//dragnn/core/test:generic"
,
"//syntaxnet:base"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
cc_test
(
name
=
"mst_solver_random_comparison_test"
,
size
=
"small"
,
timeout
=
"long"
,
srcs
=
[
"mst_solver_random_comparison_test.cc"
],
tags
=
[
"manual"
,
# exclude from :all, since this is expensive
],
deps
=
[
":mst_solver"
,
":spanning_tree_iterator"
,
"//syntaxnet:base"
,
"//syntaxnet:test_main"
,
"@org_tensorflow//tensorflow/core:lib"
,
"@org_tensorflow//tensorflow/core:test"
,
],
)
load
(
"@org_tensorflow//tensorflow:tensorflow.bzl"
,
"tf_gen_op_libs"
,
"tf_gen_op_wrapper_py"
,
)
tf_gen_op_libs
(
op_lib_names
=
[
"mst_ops"
],
)
# Don't use this library directly; instead use "dragnn/python:mst_ops".
tf_gen_op_wrapper_py
(
name
=
"mst_ops"
,
visibility
=
[
"//dragnn/python:__pkg__"
],
deps
=
[
":mst_ops_op_lib"
],
)
cc_library
(
name
=
"mst_ops_cc"
,
srcs
=
[
"ops/mst_op_kernels.cc"
,
"ops/mst_ops.cc"
,
],
deps
=
[
":mst_solver"
,
"//syntaxnet:base"
,
"@org_tensorflow//tensorflow/core:framework_headers_lib"
,
"@org_tensorflow//tensorflow/core:lib"
,
],
alwayslink
=
1
,
)
research/syntaxnet/dragnn/mst/README.md
0 → 100644
View file @
80178fc6
Package for solving max-spanning-tree (MST) problems. The code here is intended
for NLP applications, but attempts to remain agnostic to particular NLP tasks
(such as dependency parsing).
research/syntaxnet/dragnn/mst/disjoint_set_forest.h
0 → 100644
View file @
80178fc6
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_MST_DISJOINT_SET_FOREST_H_
#define DRAGNN_MST_DISJOINT_SET_FOREST_H_
#include <stddef.h>
#include <type_traits>
#include <vector>
#include "tensorflow/core/platform/logging.h"
namespace
syntaxnet
{
namespace
dragnn
{
// An implementation of the disjoint-set forest data structure. The universe of
// elements is the dense range of indices [0,n). Thread-compatible.
//
// By default, this uses the path compression and union by rank optimizations,
// achieving near-constant runtime on all operations. However, the user may
// disable the union by rank optimization, which allows the user to control how
// roots are selected when a union occurs. When union by rank is disabled, the
// runtime of all operations increases to O(log n) amortized.
//
// Template args:
// Index: An unsigned integral type wide enough to hold n.
// kUseUnionByRank: Whether to use the union by rank optimization.
template
<
class
Index
,
bool
kUseUnionByRank
=
true
>
class
DisjointSetForest
{
public:
static_assert
(
std
::
is_integral
<
Index
>::
value
,
"Index must be integral"
);
static_assert
(
!
std
::
is_signed
<
Index
>::
value
,
"Index must be unsigned"
);
using
IndexType
=
Index
;
// Creates an empty forest.
DisjointSetForest
()
=
default
;
// Initializes this to hold the elements [0,|size|), each initially in its own
// singleton set. Replaces existing state, if any.
void
Init
(
Index
size
);
// Returns the root of the set containing |element|, which uniquely identifies
// the set. Note that the root of a set may change as the set is merged with
// other sets; do not cache the return value of FindRoot(e) across calls to
// Union() or UnionOfRoots() that could merge the set containing e.
Index
FindRoot
(
Index
element
);
// For convenience, returns true if |element1| and |element2| are in the same
// set. When performing a large batch of queries it may be more efficient to
// cache the value of FindRoot(), modulo caveats regarding caching above.
bool
SameSet
(
Index
element1
,
Index
element2
);
// Merges the sets rooted at |root1| and |root2|, which must be the roots of
// their respective sets. Either |root1| or |root2| will be the root of the
// merged set. If |kUseUnionByRank| is true, then it is unspecified whether
// |root1| or |root2| will be the root; otherwise, |root2| will be the root.
void
UnionOfRoots
(
Index
root1
,
Index
root2
);
// As above, but for convenience finds the root of |element1| and |element2|.
void
Union
(
Index
element1
,
Index
element2
);
// The number of elements in this.
Index
size
()
const
{
return
size_
;
}
private:
// The number of elements in the universe underlying the sets.
Index
size_
=
0
;
// The parent of each element, where self-loops are roots.
std
::
vector
<
Index
>
parents_
;
// The rank of each element, for the union by rank optimization. Only used if
// |kUseUnionByRank| is true.
std
::
vector
<
Index
>
ranks_
;
};
// Implementation details below.
template
<
class
Index
,
bool
kUseUnionByRank
>
void
DisjointSetForest
<
Index
,
kUseUnionByRank
>::
Init
(
Index
size
)
{
size_
=
size
;
parents_
.
resize
(
size_
);
if
(
kUseUnionByRank
)
ranks_
.
resize
(
size_
);
// Create singleton sets.
for
(
Index
i
=
0
;
i
<
size_
;
++
i
)
{
parents_
[
i
]
=
i
;
if
(
kUseUnionByRank
)
ranks_
[
i
]
=
0
;
}
}
template
<
class
Index
,
bool
kUseUnionByRank
>
Index
DisjointSetForest
<
Index
,
kUseUnionByRank
>::
FindRoot
(
Index
element
)
{
DCHECK_LT
(
element
,
size
());
Index
*
const
__restrict
parents
=
parents_
.
data
();
// Walk up to the root of the |element|. Unroll the first two comparisons
// because path compression ensures most FindRoot() calls end there. In
// addition, if a root is found within the first two comparisons, then the
// path compression updates can be skipped.
Index
current
=
element
;
Index
parent
=
parents
[
current
];
if
(
current
==
parent
)
return
current
;
// |element| is a root
current
=
parent
;
parent
=
parents
[
current
];
if
(
current
==
parent
)
return
current
;
// |element| is the child of a root
do
{
// otherwise, continue upwards until root
current
=
parent
;
parent
=
parents
[
current
];
}
while
(
current
!=
parent
);
const
Index
root
=
current
;
// Apply path compression on the traversed nodes.
current
=
element
;
parent
=
parents
[
current
];
// not root, thanks to unrolling above
do
{
parents
[
current
]
=
root
;
current
=
parent
;
parent
=
parents
[
current
];
}
while
(
parent
!=
root
);
return
root
;
}
template
<
class
Index
,
bool
kUseUnionByRank
>
bool
DisjointSetForest
<
Index
,
kUseUnionByRank
>::
SameSet
(
Index
element1
,
Index
element2
)
{
return
FindRoot
(
element1
)
==
FindRoot
(
element2
);
}
template
<
class
Index
,
bool
kUseUnionByRank
>
void
DisjointSetForest
<
Index
,
kUseUnionByRank
>::
UnionOfRoots
(
Index
root1
,
Index
root2
)
{
DCHECK_LT
(
root1
,
size
());
DCHECK_LT
(
root2
,
size
());
DCHECK_EQ
(
root1
,
parents_
[
root1
]);
DCHECK_EQ
(
root2
,
parents_
[
root2
]);
if
(
root1
==
root2
)
return
;
// already merged
Index
*
const
__restrict
parents
=
parents_
.
data
();
if
(
kUseUnionByRank
)
{
// Attach the lesser-rank root to the higher-rank root.
Index
*
const
__restrict
ranks
=
ranks_
.
data
();
const
Index
rank1
=
ranks
[
root1
];
const
Index
rank2
=
ranks
[
root2
];
if
(
rank2
<
rank1
)
{
parents
[
root2
]
=
root1
;
}
else
if
(
rank1
<
rank2
)
{
parents
[
root1
]
=
root2
;
}
else
{
// Equal ranks; choose one arbitrarily and promote its rank.
parents
[
root1
]
=
root2
;
ranks
[
root2
]
=
rank2
+
1
;
}
}
else
{
// Always make |root2| the root of the merged set.
parents
[
root1
]
=
root2
;
}
}
template
<
class
Index
,
bool
kUseUnionByRank
>
void
DisjointSetForest
<
Index
,
kUseUnionByRank
>::
Union
(
Index
element1
,
Index
element2
)
{
UnionOfRoots
(
FindRoot
(
element1
),
FindRoot
(
element2
));
}
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_MST_DISJOINT_SET_FOREST_H_
research/syntaxnet/dragnn/mst/disjoint_set_forest_test.cc
0 → 100644
View file @
80178fc6
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/mst/disjoint_set_forest.h"
#include <stddef.h>
#include <set>
#include <utility>
#include <vector>
#include "syntaxnet/base.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
{
// Testing rig.
//
// Template args:
// Forest: An instantiation of the DisjointSetForest<> template.
template
<
class
Forest
>
class
DisjointSetForestTest
:
public
::
testing
::
Test
{
protected:
using
Index
=
typename
Forest
::
IndexType
;
// Expects that the |expected_sets| and |forest| match.
void
ExpectSets
(
const
std
::
set
<
std
::
set
<
Index
>>
&
expected_sets
,
Forest
*
forest
)
{
std
::
set
<
std
::
pair
<
Index
,
Index
>>
expected_pairs
;
for
(
const
auto
&
expected_set
:
expected_sets
)
{
for
(
auto
it
=
expected_set
.
begin
();
it
!=
expected_set
.
end
();
++
it
)
{
for
(
auto
jt
=
expected_set
.
begin
();
jt
!=
expected_set
.
end
();
++
jt
)
{
expected_pairs
.
emplace
(
*
it
,
*
jt
);
}
}
}
for
(
Index
lhs
=
0
;
lhs
<
forest
->
size
();
++
lhs
)
{
for
(
Index
rhs
=
0
;
rhs
<
forest
->
size
();
++
rhs
)
{
if
(
expected_pairs
.
find
({
lhs
,
rhs
})
!=
expected_pairs
.
end
())
{
EXPECT_EQ
(
forest
->
FindRoot
(
lhs
),
forest
->
FindRoot
(
rhs
));
EXPECT_TRUE
(
forest
->
SameSet
(
lhs
,
rhs
));
}
else
{
EXPECT_NE
(
forest
->
FindRoot
(
lhs
),
forest
->
FindRoot
(
rhs
));
EXPECT_FALSE
(
forest
->
SameSet
(
lhs
,
rhs
));
}
}
}
}
};
using
Forests
=
::
testing
::
Types
<
DisjointSetForest
<
uint8
,
false
>
,
DisjointSetForest
<
uint8
,
true
>
,
DisjointSetForest
<
uint16
,
false
>
,
DisjointSetForest
<
uint16
,
true
>
,
DisjointSetForest
<
uint32
,
false
>
,
DisjointSetForest
<
uint32
,
true
>
,
DisjointSetForest
<
uint64
,
false
>
,
DisjointSetForest
<
uint64
,
true
>>
;
TYPED_TEST_CASE
(
DisjointSetForestTest
,
Forests
);
TYPED_TEST
(
DisjointSetForestTest
,
DefaultEmpty
)
{
TypeParam
forest
;
EXPECT_EQ
(
0
,
forest
.
size
());
}
TYPED_TEST
(
DisjointSetForestTest
,
InitEmpty
)
{
TypeParam
forest
;
forest
.
Init
(
0
);
EXPECT_EQ
(
0
,
forest
.
size
());
}
TYPED_TEST
(
DisjointSetForestTest
,
Populated
)
{
TypeParam
forest
;
forest
.
Init
(
5
);
EXPECT_EQ
(
5
,
forest
.
size
());
this
->
ExpectSets
({{
0
},
{
1
},
{
2
},
{
3
},
{
4
}},
&
forest
);
forest
.
UnionOfRoots
(
1
,
2
);
this
->
ExpectSets
({{
0
},
{
1
,
2
},
{
3
},
{
4
}},
&
forest
);
forest
.
Union
(
1
,
2
);
this
->
ExpectSets
({{
0
},
{
1
,
2
},
{
3
},
{
4
}},
&
forest
);
forest
.
UnionOfRoots
(
0
,
4
);
this
->
ExpectSets
({{
0
,
4
},
{
1
,
2
},
{
3
}},
&
forest
);
forest
.
Union
(
3
,
4
);
this
->
ExpectSets
({{
0
,
3
,
4
},
{
1
,
2
}},
&
forest
);
forest
.
Union
(
0
,
3
);
this
->
ExpectSets
({{
0
,
3
,
4
},
{
1
,
2
}},
&
forest
);
forest
.
Union
(
2
,
0
);
this
->
ExpectSets
({{
0
,
1
,
2
,
3
,
4
}},
&
forest
);
forest
.
Union
(
1
,
3
);
this
->
ExpectSets
({{
0
,
1
,
2
,
3
,
4
}},
&
forest
);
}
// Testing rig for checking that when union by rank is disabled, the root of a
// merged set can be controlled.
class
DisjointSetForestNoUnionByRankTest
:
public
::
testing
::
Test
{
protected:
using
Forest
=
DisjointSetForest
<
uint32
,
false
>
;
// Expects that the roots of the |forest| match |expected_roots|.
void
ExpectRoots
(
const
std
::
vector
<
uint32
>
&
expected_roots
,
Forest
*
forest
)
{
ASSERT_EQ
(
expected_roots
.
size
(),
forest
->
size
());
for
(
uint32
i
=
0
;
i
<
forest
->
size
();
++
i
)
{
EXPECT_EQ
(
expected_roots
[
i
],
forest
->
FindRoot
(
i
));
}
}
};
TEST_F
(
DisjointSetForestNoUnionByRankTest
,
ManuallySpecifyRoot
)
{
Forest
forest
;
forest
.
Init
(
5
);
ExpectRoots
({
0
,
1
,
2
,
3
,
4
},
&
forest
);
forest
.
UnionOfRoots
(
0
,
1
);
// 1 is the root
ExpectRoots
({
1
,
1
,
2
,
3
,
4
},
&
forest
);
forest
.
Union
(
4
,
3
);
// 3 is the root
ExpectRoots
({
1
,
1
,
2
,
3
,
3
},
&
forest
);
forest
.
Union
(
0
,
2
);
// 2 is the root
ExpectRoots
({
2
,
2
,
2
,
3
,
3
},
&
forest
);
forest
.
Union
(
3
,
3
);
// no effect
ExpectRoots
({
2
,
2
,
2
,
3
,
3
},
&
forest
);
forest
.
Union
(
4
,
0
);
// 2 is the root
ExpectRoots
({
2
,
2
,
2
,
2
,
2
},
&
forest
);
}
}
// namespace
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/mst/mst_solver.h
0 → 100644
View file @
80178fc6
This diff is collapsed.
Click to expand it.
research/syntaxnet/dragnn/mst/mst_solver_random_comparison_test.cc
0 → 100644
View file @
80178fc6
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/mst/mst_solver.h"
#include <time.h>
#include <random>
#include <set>
#include <vector>
#include "dragnn/mst/spanning_tree_iterator.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
{
using
::
testing
::
Contains
;
// Returns the random seed, or 0 for a weak random seed.
int64
GetSeed
()
{
return
1
;
// use a deterministic seed
}
// Returns the number of trials to run for each random comparison.
int64
GetNumTrials
()
{
return
3
;
}
// Testing rig. Runs a comparison between a brute-force MST solver and the
// MstSolver<> on random digraphs. When the first test parameter is true,
// solves for forests instead of trees. The second test parameter defines the
// size of the test digraph.
class
MstSolverRandomComparisonTest
:
public
::
testing
::
TestWithParam
<::
testing
::
tuple
<
bool
,
uint32
>>
{
protected:
// Use integer scores so score comparisons are exact.
using
Solver
=
MstSolver
<
uint32
,
int32
>
;
// An array providing a source node for each node. Roots are self-loops.
using
SourceList
=
SpanningTreeIterator
::
SourceList
;
// A row-major n x n matrix whose i,j entry gives the score of the arc from i
// to j, and whose i,i entry gives the score of selecting i as a root.
using
ScoreMatrix
=
std
::
vector
<
int32
>
;
// Returns true if this should be a forest.
bool
forest
()
const
{
return
::
testing
::
get
<
0
>
(
GetParam
());
}
// Returns the number of nodes for digraphs.
uint32
num_nodes
()
const
{
return
::
testing
::
get
<
1
>
(
GetParam
());
}
// Returns the score of the arcs in |sources| based on the |scores|.
int32
ScoreArcs
(
const
ScoreMatrix
&
scores
,
const
SourceList
&
sources
)
const
{
CHECK_EQ
(
num_nodes
()
*
num_nodes
(),
scores
.
size
());
int32
score
=
0
;
for
(
uint32
target
=
0
;
target
<
num_nodes
();
++
target
)
{
const
uint32
source
=
sources
[
target
];
score
+=
scores
[
target
+
source
*
num_nodes
()];
}
return
score
;
}
// Returns the score of the maximum spanning tree (or forest, if the first
// test parameter is true) of the dense digraph defined by the |scores|, and
// sets |argmax_trees| to contain all maximal trees.
int32
RunBruteForceMstSolver
(
const
ScoreMatrix
&
scores
,
std
::
set
<
SourceList
>
*
argmax_trees
)
{
CHECK_EQ
(
num_nodes
()
*
num_nodes
(),
scores
.
size
());
int32
max_score
;
argmax_trees
->
clear
();
iterator_
.
ForEachTree
(
num_nodes
(),
[
&
](
const
SourceList
&
sources
)
{
const
int32
score
=
ScoreArcs
(
scores
,
sources
);
if
(
argmax_trees
->
empty
()
||
max_score
<
score
)
{
max_score
=
score
;
argmax_trees
->
clear
();
argmax_trees
->
insert
(
sources
);
}
else
if
(
max_score
==
score
)
{
argmax_trees
->
insert
(
sources
);
}
});
return
max_score
;
}
// As above, but uses the |solver_| and extracts only one |argmax_tree|.
int32
RunMstSolver
(
const
ScoreMatrix
&
scores
,
SourceList
*
argmax_tree
)
{
CHECK_EQ
(
num_nodes
()
*
num_nodes
(),
scores
.
size
());
TF_CHECK_OK
(
solver_
.
Init
(
forest
(),
num_nodes
()));
// Add all roots and arcs.
for
(
uint32
source
=
0
;
source
<
num_nodes
();
++
source
)
{
for
(
uint32
target
=
0
;
target
<
num_nodes
();
++
target
)
{
const
int32
score
=
scores
[
target
+
source
*
num_nodes
()];
if
(
source
==
target
)
{
solver_
.
AddRoot
(
target
,
score
);
}
else
{
solver_
.
AddArc
(
source
,
target
,
score
);
}
}
}
// Solve for the max spanning tree.
argmax_tree
->
resize
(
num_nodes
());
TF_CHECK_OK
(
solver_
.
Solve
(
argmax_tree
));
return
ScoreArcs
(
scores
,
*
argmax_tree
);
}
// Returns a random ScoreMatrix spanning num_nodes() nodes.
ScoreMatrix
RandomScores
()
{
ScoreMatrix
scores
(
num_nodes
()
*
num_nodes
());
for
(
int32
&
value
:
scores
)
value
=
static_cast
<
int32
>
(
prng_
()
%
201
)
-
100
;
return
scores
;
}
// Runs a comparison between MstSolver and BruteForceMst on random digraphs of
// num_nodes() nodes, for the specified number of trials.
void
RunComparison
()
{
// Seed the PRNG, possibly non-deterministically. Log the seed value so the
// test results can be reproduced, even when the seed is non-deterministic.
uint32
seed
=
GetSeed
();
if
(
seed
==
0
)
seed
=
time
(
nullptr
);
prng_
.
seed
(
seed
);
LOG
(
INFO
)
<<
"seed = "
<<
seed
;
const
int
num_trials
=
GetNumTrials
();
for
(
int
trial
=
0
;
trial
<
num_trials
;
++
trial
)
{
const
ScoreMatrix
scores
=
RandomScores
();
std
::
set
<
SourceList
>
expected_argmax_trees
;
const
int32
expected_max_score
=
RunBruteForceMstSolver
(
scores
,
&
expected_argmax_trees
);
SourceList
actual_argmax_tree
;
const
int32
actual_max_score
=
RunMstSolver
(
scores
,
&
actual_argmax_tree
);
// In case of ties, MstSolver will find a maximal spanning tree, but we
// don't know which one.
EXPECT_EQ
(
expected_max_score
,
actual_max_score
);
ASSERT_THAT
(
expected_argmax_trees
,
Contains
(
actual_argmax_tree
));
}
}
// Tree iterator for brute-force solver.
SpanningTreeIterator
iterator_
{
forest
()};
// MstSolver<> instance used by the test. Reused across all MST invocations
// to exercise reuse.
Solver
solver_
;
// Pseudo-random number generator.
std
::
mt19937
prng_
;
};
INSTANTIATE_TEST_CASE_P
(
AllowForest
,
MstSolverRandomComparisonTest
,
::
testing
::
Combine
(
::
testing
::
Bool
(),
::
testing
::
Range
<
uint32
>
(
1
,
9
)));
TEST_P
(
MstSolverRandomComparisonTest
,
Comparison
)
{
RunComparison
();
}
}
// namespace
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/mst/mst_solver_test.cc
0 → 100644
View file @
80178fc6
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/mst/mst_solver.h"
#include <limits>
#include <utility>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "syntaxnet/base.h"
#include <gmock/gmock.h>
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
{
using
::
testing
::
HasSubstr
;
// Testing rig.
//
// Template args:
// Solver: An instantiation of the MstSolver<> template.
template
<
class
Solver
>
class
MstSolverTest
:
public
::
testing
::
Test
{
protected:
using
Index
=
typename
Solver
::
IndexType
;
using
Score
=
typename
Solver
::
ScoreType
;
// Adds directed arcs for all |num_nodes| nodes to the |solver_| with the
// |score|.
void
AddAllArcs
(
Index
num_nodes
,
Score
score
)
{
for
(
Index
source
=
0
;
source
<
num_nodes
;
++
source
)
{
for
(
Index
target
=
0
;
target
<
num_nodes
;
++
target
)
{
if
(
source
==
target
)
continue
;
solver_
.
AddArc
(
source
,
target
,
score
);
}
}
}
// Adds root selections for all |num_nodes| nodes to the |solver_| with the
// |score|.
void
AddAllRoots
(
Index
num_nodes
,
Score
score
)
{
for
(
Index
root
=
0
;
root
<
num_nodes
;
++
root
)
{
solver_
.
AddRoot
(
root
,
score
);
}
}
// Runs the |solver_| using an argmax array of size |argmax_array_size| and
// expects it to fail with an error message that matches |error_substr|.
void
SolveAndExpectError
(
int
argmax_array_size
,
const
string
&
error_message_substr
)
{
std
::
vector
<
Index
>
argmax
(
argmax_array_size
);
EXPECT_THAT
(
solver_
.
Solve
(
&
argmax
),
test
::
IsErrorWithSubstr
(
error_message_substr
));
}
// As above, but expects success. Does not assert anything about the solution
// produced by the solver.
void
SolveAndExpectOk
(
int
argmax_array_size
)
{
std
::
vector
<
Index
>
argmax
(
argmax_array_size
);
TF_EXPECT_OK
(
solver_
.
Solve
(
&
argmax
));
}
// As above, but expects the solution to be |expected_argmax| and infers the
// argmax array size.
void
SolveAndExpectArgmax
(
const
std
::
vector
<
Index
>
&
expected_argmax
)
{
std
::
vector
<
Index
>
actual_argmax
(
expected_argmax
.
size
());
TF_ASSERT_OK
(
solver_
.
Solve
(
&
actual_argmax
));
EXPECT_EQ
(
expected_argmax
,
actual_argmax
);
}
// MstSolver<> instance used by the test. Reused across all MST problems in
// each test to exercise reuse.
Solver
solver_
;
};
using
Solvers
=
::
testing
::
Types
<
MstSolver
<
uint8
,
int16
>
,
MstSolver
<
uint16
,
int32
>
,
MstSolver
<
uint32
,
int64
>
,
MstSolver
<
uint16
,
float
>
,
MstSolver
<
uint32
,
double
>>
;
TYPED_TEST_CASE
(
MstSolverTest
,
Solvers
);
TYPED_TEST
(
MstSolverTest
,
FailIfNoNodes
)
{
for
(
const
bool
forest
:
{
false
,
true
})
{
EXPECT_THAT
(
this
->
solver_
.
Init
(
forest
,
0
),
test
::
IsErrorWithSubstr
(
"Non-positive number of nodes"
));
}
}
TYPED_TEST
(
MstSolverTest
,
FailIfTooManyNodes
)
{
// Set to a value that would overflow when doubled.
const
auto
kNumNodes
=
(
std
::
numeric_limits
<
typename
TypeParam
::
IndexType
>::
max
()
/
2
)
+
10
;
for
(
const
bool
forest
:
{
false
,
true
})
{
EXPECT_THAT
(
this
->
solver_
.
Init
(
forest
,
kNumNodes
),
test
::
IsErrorWithSubstr
(
"Too many nodes"
));
}
}
TYPED_TEST
(
MstSolverTest
,
InfeasibleIfNoRootsNoArcs
)
{
const
int
kNumNodes
=
10
;
for
(
const
bool
forest
:
{
false
,
true
})
{
TF_ASSERT_OK
(
this
->
solver_
.
Init
(
forest
,
kNumNodes
));
this
->
SolveAndExpectError
(
kNumNodes
,
"Infeasible digraph"
);
}
}
TYPED_TEST
(
MstSolverTest
,
InfeasibleIfNoRootsAllArcs
)
{
const
int
kNumNodes
=
10
;
for
(
const
bool
forest
:
{
false
,
true
})
{
TF_ASSERT_OK
(
this
->
solver_
.
Init
(
forest
,
kNumNodes
));
this
->
AddAllArcs
(
kNumNodes
,
0
);
this
->
SolveAndExpectError
(
kNumNodes
,
"Infeasible digraph"
);
}
}
TYPED_TEST
(
MstSolverTest
,
FeasibleForForestOnlyIfAllRootsNoArcs
)
{
const
int
kNumNodes
=
10
;
for
(
const
bool
forest
:
{
false
,
true
})
{
TF_ASSERT_OK
(
this
->
solver_
.
Init
(
forest
,
kNumNodes
));
this
->
AddAllRoots
(
kNumNodes
,
0
);
if
(
forest
)
{
this
->
SolveAndExpectOk
(
kNumNodes
);
// all roots is a valid forest
}
else
{
this
->
SolveAndExpectError
(
kNumNodes
,
"Infeasible digraph"
);
}
}
}
TYPED_TEST
(
MstSolverTest
,
FeasibleIfAllRootsAllArcs
)
{
const
int
kNumNodes
=
10
;
for
(
const
bool
forest
:
{
false
,
true
})
{
TF_ASSERT_OK
(
this
->
solver_
.
Init
(
forest
,
kNumNodes
));
this
->
AddAllRoots
(
kNumNodes
,
0
);
this
->
AddAllArcs
(
kNumNodes
,
0
);
this
->
SolveAndExpectOk
(
kNumNodes
);
}
}
TYPED_TEST
(
MstSolverTest
,
FailIfArgmaxArrayTooSmall
)
{
const
int
kNumNodes
=
10
;
for
(
const
bool
forest
:
{
false
,
true
})
{
TF_ASSERT_OK
(
this
->
solver_
.
Init
(
forest
,
kNumNodes
));
this
->
AddAllRoots
(
kNumNodes
,
0
);
this
->
AddAllArcs
(
kNumNodes
,
0
);
this
->
SolveAndExpectError
(
kNumNodes
-
1
,
// too small
"Argmax array too small"
);
}
}
TYPED_TEST
(
MstSolverTest
,
OkIfArgmaxArrayTooLarge
)
{
const
int
kNumNodes
=
10
;
for
(
const
bool
forest
:
{
false
,
true
})
{
TF_ASSERT_OK
(
this
->
solver_
.
Init
(
forest
,
kNumNodes
));
this
->
AddAllRoots
(
kNumNodes
,
0
);
this
->
AddAllArcs
(
kNumNodes
,
0
);
this
->
SolveAndExpectOk
(
kNumNodes
+
1
);
// too large
}
}
TYPED_TEST
(
MstSolverTest
,
SolveForAllRootsForestOnly
)
{
const
int
kNumNodes
=
10
;
const
bool
forest
=
true
;
TF_ASSERT_OK
(
this
->
solver_
.
Init
(
forest
,
kNumNodes
));
this
->
AddAllRoots
(
kNumNodes
,
1
);
// favor all root selections
this
->
AddAllArcs
(
kNumNodes
,
0
);
this
->
SolveAndExpectArgmax
({
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
});
}
TYPED_TEST
(
MstSolverTest
,
SolveForLeftToRightChain
)
{
const
int
kNumNodes
=
10
;
for
(
const
bool
forest
:
{
false
,
true
})
{
TF_ASSERT_OK
(
this
->
solver_
.
Init
(
forest
,
kNumNodes
));
this
->
AddAllRoots
(
kNumNodes
,
0
);
this
->
AddAllArcs
(
kNumNodes
,
0
);
for
(
int
target
=
1
;
target
<
kNumNodes
;
++
target
)
{
this
->
solver_
.
AddArc
(
target
-
1
,
target
,
1
);
// favor left-to-right chain
}
this
->
SolveAndExpectArgmax
({
0
,
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
});
}
}
TYPED_TEST
(
MstSolverTest
,
SolveForRightToLeftChain
)
{
const
int
kNumNodes
=
10
;
for
(
const
bool
forest
:
{
false
,
true
})
{
TF_ASSERT_OK
(
this
->
solver_
.
Init
(
forest
,
kNumNodes
));
this
->
AddAllRoots
(
kNumNodes
,
0
);
this
->
AddAllArcs
(
kNumNodes
,
0
);
for
(
int
source
=
1
;
source
<
kNumNodes
;
++
source
)
{
this
->
solver_
.
AddArc
(
source
,
source
-
1
,
1
);
// favor right-to-left chain
}
this
->
SolveAndExpectArgmax
({
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
9
});
}
}
TYPED_TEST
(
MstSolverTest
,
SolveForAllFromFirstTree
)
{
const
int
kNumNodes
=
10
;
for
(
const
bool
forest
:
{
false
,
true
})
{
TF_ASSERT_OK
(
this
->
solver_
.
Init
(
forest
,
kNumNodes
));
this
->
AddAllRoots
(
kNumNodes
,
0
);
this
->
AddAllArcs
(
kNumNodes
,
0
);
for
(
int
target
=
1
;
target
<
kNumNodes
;
++
target
)
{
this
->
solver_
.
AddArc
(
0
,
target
,
1
);
// favor first -> target
}
this
->
SolveAndExpectArgmax
({
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
});
}
}
TYPED_TEST
(
MstSolverTest
,
SolveForAllFromLastTree
)
{
const
int
kNumNodes
=
10
;
for
(
const
bool
forest
:
{
false
,
true
})
{
TF_ASSERT_OK
(
this
->
solver_
.
Init
(
forest
,
kNumNodes
));
this
->
AddAllRoots
(
kNumNodes
,
0
);
this
->
AddAllArcs
(
kNumNodes
,
0
);
for
(
int
target
=
0
;
target
+
1
<
kNumNodes
;
++
target
)
{
this
->
solver_
.
AddArc
(
9
,
target
,
1
);
// favor last -> target
}
this
->
SolveAndExpectArgmax
({
9
,
9
,
9
,
9
,
9
,
9
,
9
,
9
,
9
,
9
});
}
}
TYPED_TEST
(
MstSolverTest
,
SolveForBinaryTree
)
{
const
int
kNumNodes
=
15
;
for
(
const
bool
forest
:
{
false
,
true
})
{
TF_ASSERT_OK
(
this
->
solver_
.
Init
(
forest
,
kNumNodes
));
this
->
AddAllRoots
(
kNumNodes
,
0
);
this
->
AddAllArcs
(
kNumNodes
,
0
);
for
(
int
target
=
1
;
target
<
kNumNodes
;
++
target
)
{
this
->
solver_
.
AddArc
((
target
-
1
)
/
2
,
target
,
1
);
// like a binary heap
}
this
->
SolveAndExpectArgmax
({
0
,
0
,
0
,
1
,
1
,
2
,
2
,
3
,
3
,
4
,
4
,
5
,
5
,
6
,
6
});
}
}
}
// namespace
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/mst/ops/mst_op_kernels.cc
0 → 100644
View file @
80178fc6
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <cmath>
#include <limits>
#include <type_traits>
#include <vector>
#include "dragnn/mst/mst_solver.h"
#include "syntaxnet/base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
namespace
syntaxnet
{
namespace
dragnn
{
// Op kernel implementation that wraps the |MstSolver|.
template
<
class
Index
,
class
Score
>
class
MaximumSpanningTreeOpKernel
:
public
tensorflow
::
OpKernel
{
public:
explicit
MaximumSpanningTreeOpKernel
(
tensorflow
::
OpKernelConstruction
*
context
)
:
tensorflow
::
OpKernel
(
context
)
{
OP_REQUIRES_OK
(
context
,
context
->
GetAttr
(
"forest"
,
&
forest_
));
}
void
Compute
(
tensorflow
::
OpKernelContext
*
context
)
override
{
const
tensorflow
::
Tensor
&
num_nodes_tensor
=
context
->
input
(
0
);
const
tensorflow
::
Tensor
&
scores_tensor
=
context
->
input
(
1
);
// Check ranks.
OP_REQUIRES
(
context
,
num_nodes_tensor
.
dims
()
==
1
,
tensorflow
::
errors
::
InvalidArgument
(
"num_nodes must be a vector, got shape "
,
num_nodes_tensor
.
shape
().
DebugString
()));
OP_REQUIRES
(
context
,
scores_tensor
.
dims
()
==
3
,
tensorflow
::
errors
::
InvalidArgument
(
"scores must be rank 3, got shape "
,
scores_tensor
.
shape
().
DebugString
()));
// Batch size and input dimension (B and M in the op docstring).
const
int64
batch_size
=
scores_tensor
.
shape
().
dim_size
(
0
);
const
int64
input_dim
=
scores_tensor
.
shape
().
dim_size
(
1
);
// Check shapes.
const
tensorflow
::
TensorShape
shape_b
({
batch_size
});
const
tensorflow
::
TensorShape
shape_bxm
({
batch_size
,
input_dim
});
const
tensorflow
::
TensorShape
shape_bxmxm
(
{
batch_size
,
input_dim
,
input_dim
});
OP_REQUIRES
(
context
,
num_nodes_tensor
.
shape
()
==
shape_b
,
tensorflow
::
errors
::
InvalidArgument
(
"num_nodes misshapen: got "
,
num_nodes_tensor
.
shape
().
DebugString
(),
" but expected "
,
shape_b
.
DebugString
()));
OP_REQUIRES
(
context
,
scores_tensor
.
shape
()
==
shape_bxmxm
,
tensorflow
::
errors
::
InvalidArgument
(
"scores misshapen: got "
,
scores_tensor
.
shape
().
DebugString
(),
" but expected "
,
shape_bxmxm
.
DebugString
()));
// Create outputs.
tensorflow
::
Tensor
*
max_scores_tensor
=
nullptr
;
tensorflow
::
Tensor
*
argmax_sources_tensor
=
nullptr
;
OP_REQUIRES_OK
(
context
,
context
->
allocate_output
(
0
,
shape_b
,
&
max_scores_tensor
));
OP_REQUIRES_OK
(
context
,
context
->
allocate_output
(
1
,
shape_bxm
,
&
argmax_sources_tensor
));
// Acquire shaped and typed references.
const
BatchedSizes
num_nodes_b
=
num_nodes_tensor
.
vec
<
int32
>
();
const
BatchedScores
scores_bxmxm
=
scores_tensor
.
tensor
<
Score
,
3
>
();
BatchedMaxima
max_scores_b
=
max_scores_tensor
->
vec
<
Score
>
();
BatchedSources
argmax_sources_bxm
=
argmax_sources_tensor
->
matrix
<
int32
>
();
// Solve the batch of MST problems in parallel. Set a high cycles per unit
// to encourage finer sharding.
constexpr
int64
kCyclesPerUnit
=
1000
*
1000
*
1000
;
std
::
vector
<
tensorflow
::
Status
>
statuses
(
batch_size
);
context
->
device
()
->
tensorflow_cpu_worker_threads
()
->
workers
->
ParallelFor
(
batch_size
,
kCyclesPerUnit
,
[
&
](
int64
begin
,
int64
end
)
{
for
(
int64
problem
=
begin
;
problem
<
end
;
++
problem
)
{
statuses
[
problem
]
=
RunSolver
(
problem
,
num_nodes_b
,
scores_bxmxm
,
max_scores_b
,
argmax_sources_bxm
);
}
});
for
(
const
tensorflow
::
Status
&
status
:
statuses
)
{
OP_REQUIRES_OK
(
context
,
status
);
}
}
private:
using
BatchedSizes
=
typename
tensorflow
::
TTypes
<
int32
>::
ConstVec
;
using
BatchedScores
=
typename
tensorflow
::
TTypes
<
Score
,
3
>::
ConstTensor
;
using
BatchedMaxima
=
typename
tensorflow
::
TTypes
<
Score
>::
Vec
;
using
BatchedSources
=
typename
tensorflow
::
TTypes
<
int32
>::
Matrix
;
// Solves for the maximum spanning tree of the digraph defined by the values
// at index |problem| in |num_nodes_b| and |scores_bxmxm|. On success, sets
// the values at index |problem| in |max_scores_b| and |argmax_sources_bxm|.
// On error, returns non-OK.
tensorflow
::
Status
RunSolver
(
int
problem
,
BatchedSizes
num_nodes_b
,
BatchedScores
scores_bxmxm
,
BatchedMaxima
max_scores_b
,
BatchedSources
argmax_sources_bxm
)
const
{
// Check digraph size overflow.
const
int32
num_nodes
=
num_nodes_b
(
problem
);
const
int32
input_dim
=
argmax_sources_bxm
.
dimension
(
1
);
if
(
num_nodes
>
input_dim
)
{
return
tensorflow
::
errors
::
InvalidArgument
(
"number of nodes in digraph "
,
problem
,
" overflows input dimension: got "
,
num_nodes
,
" but expected <= "
,
input_dim
);
}
if
(
num_nodes
>=
std
::
numeric_limits
<
Index
>::
max
())
{
return
tensorflow
::
errors
::
InvalidArgument
(
"number of nodes in digraph "
,
problem
,
" overflows index type: got "
,
num_nodes
,
" but expected < "
,
std
::
numeric_limits
<
Index
>::
max
());
}
const
Index
num_nodes_index
=
static_cast
<
Index
>
(
num_nodes
);
MstSolver
<
Index
,
Score
>
solver
;
TF_RETURN_IF_ERROR
(
solver
.
Init
(
forest_
,
num_nodes_index
));
// Populate the solver with arcs and root selections. Note that non-finite
// scores are treated as nonexistent arcs or roots.
for
(
Index
target
=
0
;
target
<
num_nodes_index
;
++
target
)
{
for
(
Index
source
=
0
;
source
<
num_nodes_index
;
++
source
)
{
const
Score
score
=
scores_bxmxm
(
problem
,
target
,
source
);
if
(
!
std
::
isfinite
(
score
))
continue
;
if
(
source
==
target
)
{
// root
solver
.
AddRoot
(
target
,
score
);
}
else
{
// arc
solver
.
AddArc
(
source
,
target
,
score
);
}
}
}
std
::
vector
<
Index
>
argmax
(
num_nodes
);
TF_RETURN_IF_ERROR
(
solver
.
Solve
(
&
argmax
));
// Output the tree and accumulate its score.
Score
max_score
=
0
;
for
(
Index
target
=
0
;
target
<
num_nodes_index
;
++
target
)
{
const
Index
source
=
argmax
[
target
];
argmax_sources_bxm
(
problem
,
target
)
=
source
;
max_score
+=
scores_bxmxm
(
problem
,
target
,
source
);
}
max_scores_b
(
problem
)
=
max_score
;
// Pad the source list with -1.
for
(
int32
i
=
num_nodes
;
i
<
input_dim
;
++
i
)
{
argmax_sources_bxm
(
problem
,
i
)
=
-
1
;
}
return
tensorflow
::
Status
::
OK
();
}
private:
bool
forest_
=
false
;
};
// Use Index=uint16, which allows digraphs containing up to 32,767 nodes.
REGISTER_KERNEL_BUILDER
(
Name
(
"MaximumSpanningTree"
)
.
Device
(
tensorflow
::
DEVICE_CPU
)
.
TypeConstraint
<
int32
>
(
"T"
),
MaximumSpanningTreeOpKernel
<
uint16
,
int32
>
);
REGISTER_KERNEL_BUILDER
(
Name
(
"MaximumSpanningTree"
)
.
Device
(
tensorflow
::
DEVICE_CPU
)
.
TypeConstraint
<
float
>
(
"T"
),
MaximumSpanningTreeOpKernel
<
uint16
,
float
>
);
REGISTER_KERNEL_BUILDER
(
Name
(
"MaximumSpanningTree"
)
.
Device
(
tensorflow
::
DEVICE_CPU
)
.
TypeConstraint
<
double
>
(
"T"
),
MaximumSpanningTreeOpKernel
<
uint16
,
double
>
);
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/mst/ops/mst_ops.cc
0 → 100644
View file @
80178fc6
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace
syntaxnet
{
namespace
dragnn
{
REGISTER_OP
(
"MaximumSpanningTree"
)
.
Attr
(
"T: {int32, float, double}"
)
.
Attr
(
"forest: bool = false"
)
.
Input
(
"num_nodes: int32"
)
.
Input
(
"scores: T"
)
.
Output
(
"max_scores: T"
)
.
Output
(
"argmax_sources: int32"
)
.
SetShapeFn
([](
tensorflow
::
shape_inference
::
InferenceContext
*
context
)
{
tensorflow
::
shape_inference
::
ShapeHandle
num_nodes
;
tensorflow
::
shape_inference
::
ShapeHandle
scores
;
TF_RETURN_IF_ERROR
(
context
->
WithRank
(
context
->
input
(
0
),
1
,
&
num_nodes
));
TF_RETURN_IF_ERROR
(
context
->
WithRank
(
context
->
input
(
1
),
3
,
&
scores
));
// Extract dimensions while asserting that they match.
tensorflow
::
shape_inference
::
DimensionHandle
batch_size
;
// aka "B"
TF_RETURN_IF_ERROR
(
context
->
Merge
(
context
->
Dim
(
num_nodes
,
0
),
context
->
Dim
(
scores
,
0
),
&
batch_size
));
tensorflow
::
shape_inference
::
DimensionHandle
max_nodes
;
// aka "M"
TF_RETURN_IF_ERROR
(
context
->
Merge
(
context
->
Dim
(
scores
,
1
),
context
->
Dim
(
scores
,
2
),
&
max_nodes
));
context
->
set_output
(
0
,
context
->
Vector
(
batch_size
));
context
->
set_output
(
1
,
context
->
Matrix
(
batch_size
,
max_nodes
));
return
tensorflow
::
Status
::
OK
();
})
.
Doc
(
R"doc(
Finds the maximum directed spanning tree of a digraph.
Given a batch of digraphs with scored arcs and root selections, solves for the
maximum spanning tree of each digraph, where the score of a tree is defined as
the sum of the scores of the arcs and roots making up the tree.
Returns the score of the maximum spanning tree of each digraph, as well as the
arcs and roots in that tree. Each digraph in a batch may contain a different
number of nodes, so the sizes of the digraphs must be provided as an input.
Note that this operation is only differentiable w.r.t. its |scores| input and
its |max_scores| output.
forest: If true, solves for a maximum spanning forest instead of a maximum
spanning tree, where a spanning forest is a set of disjoint trees that
span the nodes of the digraph.
num_nodes: [B] vector where entry b is number of nodes in the b'th digraph.
scores: [B,M,M] tensor where entry b,t,s is the score of the arc from s to t in
the b'th digraph, if s!=t, or the score of selecting t as a root in the
b'th digraph, if s==t. Requires that M is >= num_nodes[b], for all b,
and ignores entries b,s,t where s or t is >= num_nodes[b]. Arcs or root
selections with non-finite score are treated as nonexistent.
max_scores: [B] vector where entry b is the score of the maximum spanning tree
of the b'th digraph.
argmax_sources: [B,M] matrix where entry b,t is the source of the arc inbound to
t in the maximum spanning tree of the b'th digraph, or t if t is
a root. Entries b,t where t is >= num_nodes[b] are set to -1.
)doc"
);
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/mst/spanning_tree_iterator.cc
0 → 100644
View file @
80178fc6
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/mst/spanning_tree_iterator.h"
namespace
syntaxnet
{
namespace
dragnn
{
SpanningTreeIterator
::
SpanningTreeIterator
(
bool
forest
)
:
forest_
(
forest
)
{}
bool
SpanningTreeIterator
::
HasCycle
(
const
SourceList
&
sources
)
{
// Flags for whether each node has already been searched.
searched_
.
assign
(
sources
.
size
(),
false
);
// Flags for whether the search is currently visiting each node.
visiting_
.
assign
(
sources
.
size
(),
false
);
// Search upwards from each node to find cycles.
for
(
uint32
initial_node
=
0
;
initial_node
<
sources
.
size
();
++
initial_node
)
{
// Search upwards to try to find a cycle.
uint32
current_node
=
initial_node
;
while
(
true
)
{
if
(
searched_
[
current_node
])
break
;
// already searched
if
(
visiting_
[
current_node
])
return
true
;
// revisiting implies cycle
visiting_
[
current_node
]
=
true
;
// mark as being currently visited
const
uint32
source_node
=
sources
[
current_node
];
if
(
source_node
==
current_node
)
break
;
// self-loops are roots
current_node
=
source_node
;
// advance upwards
}
// No cycle; search upwards again to update flags.
current_node
=
initial_node
;
while
(
true
)
{
if
(
searched_
[
current_node
])
break
;
// already searched
searched_
[
current_node
]
=
true
;
visiting_
[
current_node
]
=
false
;
const
uint32
source_node
=
sources
[
current_node
];
if
(
source_node
==
current_node
)
break
;
// self-loops are roots
current_node
=
source_node
;
// advance upwards
}
}
return
false
;
}
uint32
SpanningTreeIterator
::
NumRoots
(
const
SourceList
&
sources
)
{
uint32
num_roots
=
0
;
for
(
uint32
node
=
0
;
node
<
sources
.
size
();
++
node
)
{
num_roots
+=
(
node
==
sources
[
node
]);
}
return
num_roots
;
}
bool
SpanningTreeIterator
::
NextSourceList
(
SourceList
*
sources
)
{
const
uint32
num_nodes
=
sources
->
size
();
for
(
uint32
i
=
0
;
i
<
num_nodes
;
++
i
)
{
const
uint32
new_source
=
++
(
*
sources
)[
i
];
if
(
new_source
<
num_nodes
)
return
true
;
// absorbed in this digit
(
*
sources
)[
i
]
=
0
;
// overflowed this digit, carry to next digit
}
return
false
;
// overflowed the last digit
}
bool
SpanningTreeIterator
::
NextTree
(
SourceList
*
sources
)
{
// Iterate source lists, skipping non-trees.
while
(
NextSourceList
(
sources
))
{
// Check the number of roots.
const
uint32
num_roots
=
NumRoots
(
*
sources
);
if
(
forest_
)
{
if
(
num_roots
==
0
)
continue
;
}
else
{
if
(
num_roots
!=
1
)
continue
;
}
// Check for cycles.
if
(
HasCycle
(
*
sources
))
continue
;
// Acyclic and rooted, therefore tree.
return
true
;
}
return
false
;
}
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/mst/spanning_tree_iterator.h
0 → 100644
View file @
80178fc6
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef DRAGNN_MST_SPANNING_TREE_ITERATOR_H_
#define DRAGNN_MST_SPANNING_TREE_ITERATOR_H_
#include <vector>
#include "syntaxnet/base.h"
namespace
syntaxnet
{
namespace
dragnn
{
// A class that iterates over all possible spanning trees of a complete digraph.
// Thread-compatible. Useful for brute-force comparison tests.
//
// TODO(googleuser): Try using Prufer sequences, which are more efficient to
// enumerate as there are no non-trees to filter out.
class
SpanningTreeIterator
{
public:
// An array that provides the source of the inbound arc for each node. Roots
// are represented as self-loops.
using
SourceList
=
std
::
vector
<
uint32
>
;
// Creates a spanning tree iterator. If |forest| is true, then this iterates
// over forests instead of trees (i.e., multiple roots are allowed).
explicit
SpanningTreeIterator
(
bool
forest
);
// Applies the |functor| to all spanning trees (or forests, if |forest_| is
// true) of a complete digraph containing |num_nodes| nodes. Each tree is
// passed to the |functor| as a SourceList.
template
<
class
Functor
>
void
ForEachTree
(
uint32
num_nodes
,
Functor
functor
)
{
// Conveniently, the all-zero vector represents a valid tree.
SourceList
sources
(
num_nodes
,
0
);
do
{
functor
(
sources
);
}
while
(
NextTree
(
&
sources
));
}
private:
// Returns true if the |sources| contains a cycle.
bool
HasCycle
(
const
SourceList
&
sources
);
// Returns the number of roots in the |sources|.
static
uint32
NumRoots
(
const
SourceList
&
sources
);
// Advances |sources| to the next source list, or returns false if there are
// no more source lists.
static
bool
NextSourceList
(
SourceList
*
sources
);
// Advances |sources| to the next tree (or forest, if |forest_| is true), or
// returns false if there are no more trees.
bool
NextTree
(
SourceList
*
sources
);
// If true, iterate over spanning forests instead of spanning trees.
const
bool
forest_
;
// Workspaces used by the search in HasCycle().
std
::
vector
<
bool
>
searched_
;
std
::
vector
<
bool
>
visiting_
;
};
}
// namespace dragnn
}
// namespace syntaxnet
#endif // DRAGNN_MST_SPANNING_TREE_ITERATOR_H_
research/syntaxnet/dragnn/mst/spanning_tree_iterator_test.cc
0 → 100644
View file @
80178fc6
// Copyright 2018 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/mst/spanning_tree_iterator.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
namespace
syntaxnet
{
namespace
dragnn
{
namespace
{
// Testing rig. When the bool parameter is true, iterates over spanning forests
// instead of spanning trees.
class
SpanningTreeIteratorTest
:
public
::
testing
::
TestWithParam
<
bool
>
{
protected:
using
SourceList
=
SpanningTreeIterator
::
SourceList
;
// Returns |base|^|exponent|. Computes the value as an integer to avoid
// rounding issues.
static
int
Pow
(
int
base
,
int
exponent
)
{
double
real_product
=
1.0
;
int
product
=
1
;
for
(
int
i
=
0
;
i
<
exponent
;
++
i
)
{
product
*=
base
;
real_product
*=
base
;
}
CHECK_EQ
(
product
,
real_product
)
<<
"Overflow detected."
;
return
product
;
}
// Expects that the number of possible spanning trees for a complete digraph
// of |num_nodes| nodes is |expected_num_trees|.
void
ExpectNumTrees
(
int
num_nodes
,
int
expected_num_trees
)
{
int
actual_num_trees
=
0
;
iterator_
.
ForEachTree
(
num_nodes
,
[
&
](
const
SourceList
&
sources
)
{
++
actual_num_trees
;
});
LOG
(
INFO
)
<<
"num_nodes="
<<
num_nodes
<<
" expected_num_trees="
<<
expected_num_trees
<<
" actual_num_trees="
<<
actual_num_trees
;
EXPECT_EQ
(
expected_num_trees
,
actual_num_trees
);
}
// Expects that the set of possible spanning trees for a complete digraph of
// |num_nodes| nodes is |expected_trees|.
void
ExpectTrees
(
int
num_nodes
,
const
std
::
set
<
SourceList
>
&
expected_trees
)
{
std
::
set
<
SourceList
>
actual_trees
;
iterator_
.
ForEachTree
(
num_nodes
,
[
&
](
const
SourceList
&
sources
)
{
CHECK
(
actual_trees
.
insert
(
sources
).
second
);
});
EXPECT_EQ
(
expected_trees
,
actual_trees
);
}
// Instance for tests. Shared across assertions in a test to exercise reuse.
SpanningTreeIterator
iterator_
{
GetParam
()};
};
INSTANTIATE_TEST_CASE_P
(
AllowForest
,
SpanningTreeIteratorTest
,
::
testing
::
Bool
());
TEST_P
(
SpanningTreeIteratorTest
,
NumberOfTrees
)
{
// According to Cayley's formula, the number of undirected spanning trees on a
// complete graph of n nodes is n^{n-2}:
// https://en.wikipedia.org/wiki/Cayley%27s_formula
//
// To count the number of directed spanning trees, note that each undirected
// spanning tree gives rise to n directed spanning trees: choose one of the n
// nodes as the root, and then orient arcs outwards. Therefore, the number of
// directed spanning trees on a complete digraph of n nodes is n^{n-1}.
//
// To count the number of directed spanning forests, consider undirected
// spanning trees on a complete graph of n+1 nodes. Arbitrarily select one
// node as the artificial root, orient arcs outwards, and then delete the
// artificial root and its outbound arcs. The result is a directed spanning
// forest on n nodes. Therefore, the number of directed spanning forests on a
// complete digraph of n nodes is (n+1)^{n-1}.
for
(
int
num_nodes
=
1
;
num_nodes
<=
7
;
++
num_nodes
)
{
if
(
GetParam
())
{
// forest
ExpectNumTrees
(
num_nodes
,
Pow
(
num_nodes
+
1
,
num_nodes
-
1
));
}
else
{
// tree
ExpectNumTrees
(
num_nodes
,
Pow
(
num_nodes
,
num_nodes
-
1
));
}
}
}
TEST_P
(
SpanningTreeIteratorTest
,
OneNodeDigraph
)
{
ExpectTrees
(
1
,
{{
0
}});
}
TEST_P
(
SpanningTreeIteratorTest
,
TwoNodeDigraph
)
{
if
(
GetParam
())
{
// forest
ExpectTrees
(
2
,
{{
0
,
0
},
{
0
,
1
},
{
1
,
1
}});
// {0, 1} is two-root structure
}
else
{
// tree
ExpectTrees
(
2
,
{{
0
,
0
},
{
1
,
1
}});
}
}
TEST_P
(
SpanningTreeIteratorTest
,
ThreeNodeDigraph
)
{
if
(
GetParam
())
{
// forest
ExpectTrees
(
3
,
{{
0
,
0
,
0
},
{
0
,
0
,
1
},
{
0
,
0
,
2
},
// 2-root
{
0
,
1
,
0
},
// 2-root
{
0
,
1
,
1
},
// 2-root
{
0
,
1
,
2
},
// 3-root
{
0
,
2
,
0
},
{
0
,
2
,
2
},
// 2-root
{
1
,
1
,
0
},
{
1
,
1
,
1
},
{
1
,
1
,
2
},
// 2-root
{
1
,
2
,
2
},
{
2
,
0
,
2
},
{
2
,
1
,
1
},
{
2
,
1
,
2
},
// 2-root
{
2
,
2
,
2
}});
}
else
{
// tree
ExpectTrees
(
3
,
{{
0
,
0
,
0
},
{
0
,
0
,
1
},
{
0
,
2
,
0
},
{
1
,
1
,
0
},
{
1
,
1
,
1
},
{
1
,
2
,
2
},
{
2
,
0
,
2
},
{
2
,
1
,
1
},
{
2
,
2
,
2
}});
}
}
}
// namespace
}
// namespace dragnn
}
// namespace syntaxnet
research/syntaxnet/dragnn/protos/BUILD
View file @
80178fc6
...
@@ -2,48 +2,63 @@ package(default_visibility = ["//visibility:public"])
...
@@ -2,48 +2,63 @@ package(default_visibility = ["//visibility:public"])
load
(
load
(
"//syntaxnet:syntaxnet.bzl"
,
"//syntaxnet:syntaxnet.bzl"
,
"tf_proto_library"
,
"tf_proto_library
_cc
"
,
"tf_proto_library_py"
,
"tf_proto_library_py"
,
)
)
# Protos.
# Protos.
tf_proto_library
(
tf_proto_library
_cc
(
name
=
"data_proto"
,
name
=
"data_proto"
,
srcs
=
[
"data.proto"
],
srcs
=
[
"data.proto"
],
)
)
tf_proto_library
(
tf_proto_library
_cc
(
name
=
"trace_proto"
,
name
=
"trace_proto"
,
srcs
=
[
"trace.proto"
],
srcs
=
[
"trace.proto"
],
deps
=
[
protodeps
=
[
":data_proto"
],
":data_proto"
,
],
)
)
tf_proto_library
(
tf_proto_library_cc
(
name
=
"cell_trace_proto"
,
srcs
=
[
"cell_trace.proto"
],
protodeps
=
[
":trace_proto"
],
)
tf_proto_library_cc
(
name
=
"spec_proto"
,
name
=
"spec_proto"
,
srcs
=
[
"spec.proto"
],
srcs
=
[
"spec.proto"
],
)
)
tf_proto_library
(
tf_proto_library
_cc
(
name
=
"runtime_proto"
,
name
=
"runtime_proto"
,
srcs
=
[
"runtime.proto"
],
srcs
=
[
"runtime.proto"
],
deps
=
[
":spec_proto"
],
protodeps
=
[
":spec_proto"
],
)
tf_proto_library_cc
(
name
=
"export_proto"
,
srcs
=
[
"export.proto"
],
protodeps
=
[
":spec_proto"
],
)
)
tf_proto_library_py
(
tf_proto_library_py
(
name
=
"data_
py_
pb2"
,
name
=
"data_pb2"
,
srcs
=
[
"data.proto"
],
srcs
=
[
"data.proto"
],
)
)
tf_proto_library_py
(
tf_proto_library_py
(
name
=
"trace_
py_
pb2"
,
name
=
"trace_pb2"
,
srcs
=
[
"trace.proto"
],
srcs
=
[
"trace.proto"
],
deps
=
[
":data_
py_
pb2"
],
proto
deps
=
[
":data_pb2"
],
)
)
tf_proto_library_py
(
tf_proto_library_py
(
name
=
"spec_
py_
pb2"
,
name
=
"spec_pb2"
,
srcs
=
[
"spec.proto"
],
srcs
=
[
"spec.proto"
],
)
)
tf_proto_library_py
(
name
=
"export_pb2"
,
srcs
=
[
"export.proto"
],
)
research/syntaxnet/dragnn/protos/cell_trace.proto
0 → 100644
View file @
80178fc6
This diff is collapsed.
Click to expand it.
research/syntaxnet/dragnn/protos/data.proto
View file @
80178fc6
// DRAGNN data proto. See go/dragnn-design for more information.
// DRAGNN data proto.
syntax
=
"proto2"
;
syntax
=
"proto2"
;
...
...
research/syntaxnet/dragnn/protos/export.proto
0 → 100644
View file @
80178fc6
This diff is collapsed.
Click to expand it.
Prev
1
2
3
4
5
6
7
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment