ModelZoo / ResNet50_tensorflow / Commits / 80178fc6

Commit 80178fc6 (unverified)
Authored May 11, 2018 by Mark Omernick; committed via GitHub, May 11, 2018.

    Merge pull request #4153 from terryykoo/master

    Export @195097388.

Parents: a84e1ef9, edea2b67

Changes: 145
Showing 20 changed files, with 1716 additions and 386 deletions (+1716 / -386):

research/syntaxnet/dragnn/protos/runtime.proto                    +1    -1
research/syntaxnet/dragnn/protos/spec.proto                       +23   -3
research/syntaxnet/dragnn/protos/trace.proto                      +2    -0
research/syntaxnet/dragnn/python/BUILD                            +160  -26
research/syntaxnet/dragnn/python/biaffine_units.py                +36   -11
research/syntaxnet/dragnn/python/biaffine_units_test.py           +151  -0
research/syntaxnet/dragnn/python/bulk_component.py                +114  -35
research/syntaxnet/dragnn/python/bulk_component_test.py           +222  -165
research/syntaxnet/dragnn/python/component.py                     +216  -43
research/syntaxnet/dragnn/python/component_test.py                +153  -0
research/syntaxnet/dragnn/python/digraph_ops.py                   +132  -12
research/syntaxnet/dragnn/python/digraph_ops_test.py              +134  -23
research/syntaxnet/dragnn/python/dragnn_model_saver.py            +11   -7
research/syntaxnet/dragnn/python/dragnn_model_saver_lib.py        +3    -1
research/syntaxnet/dragnn/python/dragnn_model_saver_lib_test.py   +213  -14
research/syntaxnet/dragnn/python/file_diff_test.py                +48   -0
research/syntaxnet/dragnn/python/graph_builder.py                 +69   -22
research/syntaxnet/dragnn/python/graph_builder_test.py            +3    -12
research/syntaxnet/dragnn/python/lexicon_test.py                  +3    -11
research/syntaxnet/dragnn/python/load_mst_cc_impl.py              +22   -0
research/syntaxnet/dragnn/protos/runtime.proto

@@ -16,7 +16,7 @@ message MasterPerformanceSettings {
   // Maximum size of the free list in the SessionStatePool. NB: The default
   // value may occasionally change.
-  optional uint64 session_state_pool_max_free_states = 1 [default = 4];
+  optional uint64 session_state_pool_max_free_states = 1 [default = 16];
 }

 // As above, but for component-specific performance tuning settings.
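Note: the only functional change here is the larger default free-list size (4 to 16). Because proto2 optional fields report their [default = ...] value when unset, the new default is observable directly from the generated Python bindings; a minimal sketch, assuming the generated module is importable as dragnn.protos.runtime_pb2:

# Minimal sketch: unset proto2 optional fields report their declared default,
# so the bump from 4 to 16 is visible without setting anything.
from dragnn.protos import runtime_pb2

settings = runtime_pb2.MasterPerformanceSettings()
assert settings.session_state_pool_max_free_states == 16

# Callers that depended on the old behavior can still pin the old value.
settings.session_state_pool_max_free_states = 4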
research/syntaxnet/dragnn/protos/spec.proto

-// DRAGNN Configuration proto. See go/dragnn-design for more information.
+// DRAGNN Configuration proto.

 syntax = "proto2";

@@ -93,7 +94,7 @@ message Part {
 // are extracted, embedded, and then concatenated together as a group.

 // Specification for a feature channel that is a *fixed* function of the input.
-// NEXT_ID: 10
+// NEXT_ID: 12
 message FixedFeatureChannel {
   // Interpretable name for this feature channel. NN builders might depend on
   // this to determine how to hook different channels up internally.

@@ -129,6 +130,19 @@ message FixedFeatureChannel {
   // Vocab file, containing all vocabulary words one per line.
   optional Resource vocab = 8;

+  // Settings for feature ID dropout:
+  // If non-negative, enables feature ID dropout, and dropped feature IDs will
+  // be replaced with this ID.
+  optional int64 dropout_id = 10 [default = -1];
+
+  // Probability of keeping each of the |vocabulary_size| feature IDs. Only
+  // used if |dropout_id| is non-negative, and must not be empty in that case.
+  // If this has fewer than |vocabulary_size| values, then the final value is
+  // tiled onto the remaining IDs. For example, specifying a single value is
+  // equivalent to setting all IDs to that value.
+  repeated float dropout_keep_probability = 11 [packed = true];
 }

 // Specification for a feature channel that *links* to component

@@ -173,11 +187,17 @@ message TrainingGridSpec {
 }

 // A hyperparameter configuration for a training run.
-// NEXT ID: 22
+// NEXT ID: 23
 message GridPoint {
   // Global learning rate initialization point.
   optional double learning_rate = 1 [default = 0.1];

+  // Whether to use PBT (population-based training) to optimize the learning
+  // rate. Population-based training is not currently open-source, so this will
+  // just create a tf.assign op which external frameworks can use to adjust the
+  // learning rate.
+  optional bool pbt_optimize_learning_rate = 22 [default = false];
+
   // Momentum coefficient when using MomentumOptimizer.
   optional double momentum = 2 [default = 0.9];
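Note: the tiling rule for dropout_keep_probability can be made concrete. The sketch below illustrates only the documented semantics; it is not the DRAGNN implementation (this commit wires the real one through network_units.apply_feature_id_dropout in bulk_component.py below), and all names here are illustrative:

import tensorflow as tf

def expand_keep_probabilities(keep_probs, vocabulary_size):
  # E.g. [0.9, 0.5] with vocabulary_size=4 tiles to [0.9, 0.5, 0.5, 0.5];
  # a single value therefore applies to every feature ID.
  return keep_probs + [keep_probs[-1]] * (vocabulary_size - len(keep_probs))

def feature_id_dropout(ids, keep_probs, vocabulary_size, dropout_id):
  """Illustrative only: replaces dropped feature IDs with |dropout_id|."""
  probs = tf.constant(expand_keep_probabilities(keep_probs, vocabulary_size))
  # Keep each extracted ID with its per-ID probability.
  keep = tf.random_uniform(tf.shape(ids)) < tf.gather(probs, ids)
  # Dropped IDs are replaced, not removed.
  dropped = tf.fill(tf.shape(ids), tf.cast(dropout_id, ids.dtype))
  return tf.where(keep, ids, dropped)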
research/syntaxnet/dragnn/protos/trace.proto

@@ -53,6 +53,8 @@ message ComponentStepTrace {
   // Set to true once the step is finished. (This allows us to open a step after
   // each transition, without having to know if it will be used.)
   optional bool step_finished = 6 [default = false];
+
+  extensions 1000 to max;
 }

 // The traces for all steps for a single Component.
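Note: the new extensions range lets downstream packages attach their own data to a step trace without editing this file. A hedged sketch of how such a proto2 extension would be consumed from Python; the extension field my_debug_note and its module are hypothetical:

# Hypothetical extension, declared in some other .proto file as:
#   extend ComponentStepTrace { optional string my_debug_note = 1000; }
from dragnn.protos import trace_pb2
# from my.package import debug_extensions_pb2  # hypothetical module

step = trace_pb2.ComponentStepTrace()
step.step_finished = True
# step.Extensions[debug_extensions_pb2.my_debug_note] = 'useful context'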
research/syntaxnet/dragnn/python/BUILD

@@ -16,6 +16,17 @@ cc_binary(
     ],
 )

+cc_binary(
+    name = "mst_cc_impl.so",
+    linkopts = select({
+        "//conditions:default": ["-lm"],
+        "@org_tensorflow//tensorflow:darwin": [],
+    }),
+    linkshared = 1,
+    linkstatic = 1,
+    deps = ["//dragnn/mst:mst_ops_cc"],
+)
+
 filegroup(
     name = "testdata",
     data = glob(["testdata/**"]),

@@ -27,6 +38,12 @@ py_library(
     data = [":dragnn_cc_impl.so"],
 )

+py_library(
+    name = "load_mst_cc_impl_py",
+    srcs = ["load_mst_cc_impl.py"],
+    data = [":mst_cc_impl.so"],
+)
+
 py_library(
     name = "bulk_component",
     srcs = [

@@ -50,6 +67,8 @@ py_library(
         ":bulk_component",
         ":dragnn_ops",
         ":network_units",
+        ":runtime_support",
+        "//dragnn/protos:export_pb2_py",
         "//syntaxnet/util:check",
         "//syntaxnet/util:pyregistry",
         "@org_tensorflow//tensorflow:tensorflow_py",

@@ -85,9 +104,9 @@ py_library(
         ":graph_builder",
         ":load_dragnn_cc_impl_py",
         ":network_units",
-        "//dragnn/protos:spec_py_pb2",
+        "//dragnn/protos:spec_pb2_py",
         "//syntaxnet:load_parser_ops_py",
-        "//syntaxnet:sentence_py_pb2",
+        "//syntaxnet:sentence_pb2_py",
         "@org_tensorflow//tensorflow:tensorflow_py",
         "@org_tensorflow//tensorflow/core:protos_all_py",
     ],

@@ -99,7 +118,9 @@ py_test(
     data = [":testdata"],
     deps = [
         ":dragnn_model_saver_lib",
-        "//dragnn/protos:spec_py_pb2",
+        "//dragnn/protos:export_pb2_py",
+        "//dragnn/protos:spec_pb2_py",
+        "//syntaxnet:test_flags",
         "@org_tensorflow//tensorflow:tensorflow_py",
     ],
 )

@@ -110,7 +131,9 @@ py_binary(
     deps = [
         ":dragnn_model_saver_lib",
         ":spec_builder",
-        "//dragnn/protos:spec_py_pb2",
+        "//dragnn/protos:spec_pb2_py",
+        "@absl_py//absl:app",
+        "@absl_py//absl/flags",
         "@org_tensorflow//tensorflow:tensorflow_py",
         "@org_tensorflow//tensorflow/core:protos_all_py",
     ],

@@ -127,7 +150,7 @@ py_library(
         ":network_units",
         ":transformer_units",
         ":wrapped_units",
-        "//dragnn/protos:spec_py_pb2",
+        "//dragnn/protos:spec_pb2_py",
         "//syntaxnet/util:check",
         "@org_tensorflow//tensorflow:tensorflow_py",
         "@org_tensorflow//tensorflow/core:protos_all_py",

@@ -159,7 +182,7 @@ py_test(
     srcs = ["render_parse_tree_graphviz_test.py"],
     deps = [
         ":render_parse_tree_graphviz",
-        "//syntaxnet:sentence_py_pb2",
+        "//syntaxnet:sentence_pb2_py",
         "@org_tensorflow//tensorflow:tensorflow_py",
     ],
 )

@@ -168,7 +191,7 @@ py_library(
     name = "render_spec_with_graphviz",
     srcs = ["render_spec_with_graphviz.py"],
     deps = [
-        "//dragnn/protos:spec_py_pb2",
+        "//dragnn/protos:spec_pb2_py",
     ],
 )

@@ -197,7 +220,7 @@ py_binary(
         "//dragnn/viz:viz-min-js-gz",
     ],
     deps = [
-        "//dragnn/protos:trace_py_pb2",
+        "//dragnn/protos:trace_pb2_py",
     ],
 )

@@ -206,8 +229,8 @@ py_test(
     srcs = ["visualization_test.py"],
     deps = [
         ":visualization",
-        "//dragnn/protos:spec_py_pb2",
-        "//dragnn/protos:trace_py_pb2",
+        "//dragnn/protos:spec_pb2_py",
+        "//dragnn/protos:trace_pb2_py",
         "@org_tensorflow//tensorflow:tensorflow_py",
     ],
 )

@@ -225,6 +248,18 @@ py_library(
 # Tests

+py_test(
+    name = "component_test",
+    srcs = [
+        "component_test.py",
+    ],
+    deps = [
+        ":components",
+        "//dragnn/protos:spec_pb2_py",
+        "@org_tensorflow//tensorflow:tensorflow_py",
+    ],
+)
+
 py_test(
     name = "bulk_component_test",
     srcs = [

@@ -235,9 +270,9 @@ py_test(
         ":components",
         ":dragnn_ops",
         ":network_units",
-        "//dragnn/protos:spec_py_pb2",
+        "//dragnn/protos:spec_pb2_py",
         "//syntaxnet:load_parser_ops_py",
-        "//syntaxnet:sentence_py_pb2",
+        "//syntaxnet:sentence_pb2_py",
         "@org_tensorflow//tensorflow:tensorflow_py",
         "@org_tensorflow//tensorflow/core:protos_all_py",
     ],

@@ -270,10 +305,11 @@ py_test(
     deps = [
         ":dragnn_ops",
         ":graph_builder",
-        "//dragnn/protos:spec_py_pb2",
-        "//dragnn/protos:trace_py_pb2",
+        "//dragnn/protos:spec_pb2_py",
+        "//dragnn/protos:trace_pb2_py",
         "//syntaxnet:load_parser_ops_py",
-        "//syntaxnet:sentence_py_pb2",
+        "//syntaxnet:sentence_pb2_py",
+        "//syntaxnet:test_flags",
         "@org_tensorflow//tensorflow:tensorflow_py",
         "@org_tensorflow//tensorflow/core:protos_all_py",
     ],

@@ -287,7 +323,7 @@ py_test(
         ":network_units",
         "//dragnn/core:dragnn_bulk_ops",
         "//dragnn/core:dragnn_ops",
-        "//dragnn/protos:spec_py_pb2",
+        "//dragnn/protos:spec_pb2_py",
         "//syntaxnet:load_parser_ops_py",
         "@org_tensorflow//tensorflow:tensorflow_py",
         "@org_tensorflow//tensorflow/core:protos_all_py",

@@ -303,7 +339,8 @@ py_test(
         ":sentence_io",
         "//syntaxnet:load_parser_ops_py",
         "//syntaxnet:parser_ops",
-        "//syntaxnet:sentence_py_pb2",
+        "//syntaxnet:sentence_pb2_py",
+        "//syntaxnet:test_flags",
         "@org_tensorflow//tensorflow:tensorflow_py",
         "@org_tensorflow//tensorflow/core:protos_all_py",
     ],

@@ -313,21 +350,31 @@ py_library(
     name = "trainer_lib",
     srcs = ["trainer_lib.py"],
     deps = [
-        "//dragnn/protos:spec_py_pb2",
+        "//dragnn/protos:spec_pb2_py",
         "//syntaxnet:parser_ops",
-        "//syntaxnet:sentence_py_pb2",
-        "//syntaxnet:task_spec_py_pb2",
+        "//syntaxnet:sentence_pb2_py",
+        "//syntaxnet:task_spec_pb2_py",
+        "//syntaxnet/util:check",
         "@org_tensorflow//tensorflow:tensorflow_py",
         "@org_tensorflow//tensorflow/core:protos_all_py",
     ],
 )

+py_test(
+    name = "trainer_lib_test",
+    srcs = ["trainer_lib_test.py"],
+    deps = [
+        ":trainer_lib",
+        "@org_tensorflow//tensorflow:tensorflow_py",
+    ],
+)
+
 py_library(
     name = "lexicon",
     srcs = ["lexicon.py"],
     deps = [
         "//syntaxnet:parser_ops",
-        "//syntaxnet:task_spec_py_pb2",
+        "//syntaxnet:task_spec_pb2_py",
         "@org_tensorflow//tensorflow:tensorflow_py",
     ],
 )

@@ -340,6 +387,7 @@ py_test(
         "//syntaxnet:load_parser_ops_py",
         "//syntaxnet:parser_ops",
         "//syntaxnet:parser_trainer",
+        "//syntaxnet:test_flags",
         "@org_tensorflow//tensorflow:tensorflow_py",
     ],
 )

@@ -348,7 +396,7 @@ py_library(
     name = "evaluation",
     srcs = ["evaluation.py"],
     deps = [
-        "//syntaxnet:sentence_py_pb2",
+        "//syntaxnet:sentence_pb2_py",
         "//syntaxnet/util:check",
         "@org_tensorflow//tensorflow:tensorflow_py",
     ],

@@ -359,7 +407,7 @@ py_test(
     srcs = ["evaluation_test.py"],
     deps = [
         ":evaluation",
-        "//syntaxnet:sentence_py_pb2",
+        "//syntaxnet:sentence_pb2_py",
         "@org_tensorflow//tensorflow:tensorflow_py",
     ],
 )

@@ -369,7 +417,7 @@ py_library(
     srcs = ["spec_builder.py"],
     deps = [
         ":lexicon",
-        "//dragnn/protos:spec_py_pb2",
+        "//dragnn/protos:spec_pb2_py",
         "//syntaxnet:parser_ops",
         "//syntaxnet/util:check",
         "@org_tensorflow//tensorflow:tensorflow_py",

@@ -381,7 +429,7 @@ py_test(
     srcs = ["spec_builder_test.py"],
     deps = [
         ":spec_builder",
-        "//dragnn/protos:spec_py_pb2",
+        "//dragnn/protos:spec_pb2_py",
         "//syntaxnet:load_parser_ops_py",
         "//syntaxnet:parser_ops",
         "//syntaxnet:parser_trainer",

@@ -418,6 +466,17 @@ py_library(
     ],
 )

+py_test(
+    name = "biaffine_units_test",
+    srcs = ["biaffine_units_test.py"],
+    deps = [
+        ":biaffine_units",
+        ":network_units",
+        "//dragnn/protos:spec_pb2_py",
+        "@org_tensorflow//tensorflow:tensorflow_py",
+    ],
+)
+
 py_library(
     name = "transformer_units",
     srcs = ["transformer_units.py"],

@@ -437,10 +496,85 @@ py_test(
         ":transformer_units",
         "//dragnn/core:dragnn_bulk_ops",
         "//dragnn/core:dragnn_ops",
-        "//dragnn/protos:spec_py_pb2",
+        "//dragnn/protos:spec_pb2_py",
         "//syntaxnet:load_parser_ops_py",
         "@org_tensorflow//tensorflow:tensorflow_py",
         "@org_tensorflow//tensorflow/core:protos_all_py",
     ],
 )

+py_library(
+    name = "runtime_support",
+    srcs = ["runtime_support.py"],
+    deps = [
+        ":network_units",
+        "//syntaxnet/util:check",
+        "@org_tensorflow//tensorflow:tensorflow_py",
+    ],
+)
+
+py_test(
+    name = "runtime_support_test",
+    srcs = ["runtime_support_test.py"],
+    deps = [
+        ":network_units",
+        ":runtime_support",
+        "//dragnn/protos:export_pb2_py",
+        "//dragnn/protos:spec_pb2_py",
+        "@org_tensorflow//tensorflow:tensorflow_py",
+    ],
+)
+
+py_library(
+    name = "file_diff_test",
+    srcs = ["file_diff_test.py"],
+    deps = [
+        "@absl_py//absl/flags",
+        "@org_tensorflow//tensorflow:tensorflow_py",
+    ],
+)
+
+py_library(
+    name = "mst_ops",
+    srcs = ["mst_ops.py"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":digraph_ops",
+        ":load_mst_cc_impl_py",
+        "//dragnn/mst:mst_ops",
+        "//syntaxnet/util:check",
+        "@org_tensorflow//tensorflow:tensorflow_py",
+    ],
+)
+
+py_test(
+    name = "mst_ops_test",
+    srcs = ["mst_ops_test.py"],
+    deps = [
+        ":mst_ops",
+        "@org_tensorflow//tensorflow:tensorflow_py",
+    ],
+)
+
+py_library(
+    name = "mst_units",
+    srcs = ["mst_units.py"],
+    deps = [
+        ":mst_ops",
+        ":network_units",
+        "//syntaxnet/util:check",
+        "@org_tensorflow//tensorflow:tensorflow_py",
+    ],
+)
+
+py_test(
+    name = "mst_units_test",
+    size = "small",
+    srcs = ["mst_units_test.py"],
+    deps = [
+        ":mst_units",
+        ":network_units",
+        "//dragnn/protos:spec_pb2_py",
+        "@org_tensorflow//tensorflow:tensorflow_py",
+    ],
+)
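Note: the new mst_cc_impl.so / load_mst_cc_impl_py pair follows the existing dragnn_cc_impl.so convention: a shared object of custom ops plus a small Python module that loads it at import time. load_mst_cc_impl.py (+22/-0) is not rendered in this view; loaders of this kind conventionally reduce to a tf.load_op_library call, as in this sketch (the exact file contents are an assumption):

# Sketch of the conventional op-library loader pattern, not the actual
# contents of load_mst_cc_impl.py.
import os.path
import tensorflow as tf

tf.load_op_library(
    os.path.join(tf.resource_loader.get_data_files_path(), 'mst_cc_impl.so'))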
research/syntaxnet/dragnn/python/biaffine_units.py

@@ -79,24 +79,44 @@ class BiaffineDigraphNetwork(network_units.NetworkUnitInterface):
     self._source_dim = self._linked_feature_dims['sources']
     self._target_dim = self._linked_feature_dims['targets']

+    # TODO(googleuser): Make parameter initialization configurable.
     self._weights = []
-    self._weights.append(tf.get_variable(
-        'weights_arc', [self._source_dim, self._target_dim], tf.float32,
-        tf.random_normal_initializer(stddev=1e-4)))
-    self._weights.append(tf.get_variable(
-        'weights_source', [self._source_dim], tf.float32,
-        tf.random_normal_initializer(stddev=1e-4)))
-    self._weights.append(tf.get_variable(
-        'root', [self._source_dim], tf.float32,
-        tf.random_normal_initializer(stddev=1e-4)))
+    self._weights.append(
+        tf.get_variable('weights_arc', [self._source_dim, self._target_dim],
+                        tf.float32, tf.orthogonal_initializer()))
+    self._weights.append(
+        tf.get_variable('weights_source', [self._source_dim], tf.float32,
+                        tf.zeros_initializer()))
+    self._weights.append(
+        tf.get_variable('root', [self._source_dim], tf.float32,
+                        tf.zeros_initializer()))

     self._params.extend(self._weights)
     self._regularized_weights.extend(self._weights)

+    # Add runtime hooks for pre-computed weights.
+    self._derived_params.append(self._get_root_weights)
+    self._derived_params.append(self._get_root_bias)
+
     # Negative Layer.dim indicates that the dimension is dynamic.
     self._layers.append(network_units.Layer(component, 'adjacency', -1))

+  def _get_root_weights(self):
+    """Pre-computes the product of the root embedding and arc weights."""
+    weights_arc = self._component.get_variable('weights_arc')
+    root = self._component.get_variable('root')
+    name = self._component.name + '/root_weights'
+    with tf.name_scope(None):
+      return tf.matmul(tf.expand_dims(root, 0), weights_arc, name=name)
+
+  def _get_root_bias(self):
+    """Pre-computes the product of the root embedding and source weights."""
+    weights_source = self._component.get_variable('weights_source')
+    root = self._component.get_variable('root')
+    name = self._component.name + '/root_bias'
+    with tf.name_scope(None):
+      return tf.matmul(
+          tf.expand_dims(root, 0), tf.expand_dims(weights_source, 1),
+          name=name)
+
   def create(self,
              fixed_embeddings,
              linked_embeddings,

@@ -133,12 +153,17 @@ class BiaffineDigraphNetwork(network_units.NetworkUnitInterface):
     sources_bxnxn = digraph_ops.ArcSourcePotentialsFromTokens(
         source_tokens_bxnxs, weights_source)
     roots_bxn = digraph_ops.RootPotentialsFromTokens(
-        root, target_tokens_bxnxt, weights_arc)
+        root, target_tokens_bxnxt, weights_arc, weights_source)

     # Combine them into a single matrix with the roots on the diagonal.
     adjacency_bxnxn = digraph_ops.CombineArcAndRootPotentials(
         arcs_bxnxn + sources_bxnxn, roots_bxn)

+    # The adjacency matrix currently has sources on rows and targets on columns,
+    # but we want targets on rows so that maximizing within a row corresponds to
+    # selecting sources for a given target.
+    adjacency_bxnxn = tf.matrix_transpose(adjacency_bxnxn)
+
     return [tf.reshape(adjacency_bxnxn, [-1, num_tokens])]
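Note: two changes above are easy to miss. First, the derived_params hooks precompute root potentials for the runtime (root paired with weights_arc gives shape [1, target_dim]; root paired with weights_source gives shape [1, 1], per the new test below). Second, the new transpose re-orients the adjacency matrix so that rows index targets and a row-wise argmax selects a head (source) for each target token. A small shape-level sketch of that row convention (shapes assumed; b = batch, n = tokens):

import tensorflow as tf

# Before the transpose, adjacency[b, s, t] scores source s for target t.
b, n = 4, 5
adjacency_bxnxn = tf.random_normal([b, n, n])

# After the transpose, rows index targets, so maximizing within a row
# selects the best source (head) for each target token.
adjacency_bxnxn = tf.matrix_transpose(adjacency_bxnxn)
heads_bxn = tf.argmax(adjacency_bxnxn, axis=2)  # shape [b, n]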
research/syntaxnet/dragnn/python/biaffine_units_test.py (new file, 0 → 100644)

# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for biaffine_units."""

import tensorflow as tf

from google.protobuf import text_format
from dragnn.protos import spec_pb2
from dragnn.python import biaffine_units
from dragnn.python import network_units

_BATCH_SIZE = 11
_NUM_TOKENS = 22
_TOKEN_DIM = 33


class MockNetwork(object):

  def __init__(self):
    pass

  def get_layer_size(self, unused_name):
    return _TOKEN_DIM


class MockComponent(object):

  def __init__(self, master, component_spec):
    self.master = master
    self.spec = component_spec
    self.name = component_spec.name
    self.network = MockNetwork()
    self.beam_size = 1
    self.num_actions = 45
    self._attrs = {}

  def attr(self, name):
    return self._attrs[name]

  def get_variable(self, name):
    return tf.get_variable(name)


class MockMaster(object):

  def __init__(self):
    self.spec = spec_pb2.MasterSpec()
    self.hyperparams = spec_pb2.GridPoint()
    self.lookup_component = {
        'previous': MockComponent(self, spec_pb2.ComponentSpec())
    }


def _make_biaffine_spec():
  """Returns a ComponentSpec that the BiaffineDigraphNetwork works on."""
  component_spec = spec_pb2.ComponentSpec()
  text_format.Parse("""
      name: "test_component"
      backend { registered_name: "TestComponent" }
      linked_feature {
        name: "sources"
        fml: "input.focus"
        source_translator: "identity"
        source_component: "previous"
        source_layer: "sources"
        size: 1
        embedding_dim: -1
      }
      linked_feature {
        name: "targets"
        fml: "input.focus"
        source_translator: "identity"
        source_component: "previous"
        source_layer: "targets"
        size: 1
        embedding_dim: -1
      }
      network_unit {
        registered_name: "biaffine_units.BiaffineDigraphNetwork"
      }
      """, component_spec)
  return component_spec


class BiaffineDigraphNetworkTest(tf.test.TestCase):

  def setUp(self):
    # Clear the graph and all existing variables. Otherwise, variables created
    # in different tests may collide with each other.
    tf.reset_default_graph()

  def testCanCreate(self):
    """Tests that create() works on a good spec."""
    with tf.Graph().as_default(), self.test_session():
      master = MockMaster()
      component = MockComponent(master, _make_biaffine_spec())
      with tf.variable_scope(component.name, reuse=None):
        component.network = biaffine_units.BiaffineDigraphNetwork(component)
      with tf.variable_scope(component.name, reuse=True):
        sources = network_units.NamedTensor(
            tf.zeros([_BATCH_SIZE * _NUM_TOKENS, _TOKEN_DIM]), 'sources')
        targets = network_units.NamedTensor(
            tf.zeros([_BATCH_SIZE * _NUM_TOKENS, _TOKEN_DIM]), 'targets')

        # No assertions on the result, just don't crash.
        component.network.create(
            fixed_embeddings=[],
            linked_embeddings=[sources, targets],
            context_tensor_arrays=None,
            attention_tensor=None,
            during_training=True,
            stride=_BATCH_SIZE)

  def testDerivedParametersForRuntime(self):
    """Test generation of derived parameters for the runtime."""
    with tf.Graph().as_default(), self.test_session():
      master = MockMaster()
      component = MockComponent(master, _make_biaffine_spec())
      with tf.variable_scope(component.name, reuse=None):
        component.network = biaffine_units.BiaffineDigraphNetwork(component)
      with tf.variable_scope(component.name, reuse=True):
        self.assertEqual(len(component.network.derived_params), 2)
        root_weights = component.network.derived_params[0]()
        root_bias = component.network.derived_params[1]()

        # Only check shape; values are random due to initialization.
        self.assertAllEqual(root_weights.shape.as_list(), [1, _TOKEN_DIM])
        self.assertAllEqual(root_bias.shape.as_list(), [1, 1])


if __name__ == '__main__':
  tf.test.main()
research/syntaxnet/dragnn/python/bulk_component.py

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
 """Component builders for non-recurrent networks in DRAGNN."""

@@ -51,10 +50,8 @@ def fetch_linked_embedding(comp, network_states, feature_spec):
              feature_spec.name)
   source = comp.master.lookup_component[feature_spec.source_component]
-
   return network_units.NamedTensor(
       network_states[source.name].activations[
-          feature_spec.source_layer].bulk_tensor,
-      feature_spec.name)
+          feature_spec.source_layer].bulk_tensor, feature_spec.name)


 def _validate_embedded_fixed_features(comp):

@@ -63,17 +60,20 @@ def _validate_embedded_fixed_features(comp):
     check.Gt(feature.embedding_dim, 0,
              'Embeddings requested for non-embedded feature: %s' % feature)
     if feature.is_constant:
-      check.IsTrue(feature.HasField('pretrained_embedding_matrix'),
-                   'Constant embeddings must be pretrained: %s' % feature)
+      check.IsTrue(
+          feature.HasField('pretrained_embedding_matrix'),
+          'Constant embeddings must be pretrained: %s' % feature)


-def fetch_differentiable_fixed_embeddings(comp, state, stride):
+def fetch_differentiable_fixed_embeddings(comp, state, stride,
+                                          during_training):
   """Looks up fixed features with separate, differentiable, embedding lookup.

   Args:
     comp: Component whose fixed features we wish to look up.
     state: live MasterState object for the component.
     stride: Tensor containing current batch * beam size.
+    during_training: True if this is being called from a training code path.
+      This controls, e.g., the use of feature ID dropout.

   Returns:
     state handle: updated state handle to be used after this call

@@ -93,6 +93,11 @@ def fetch_differentiable_fixed_embeddings(comp, state, stride, during_training):
                                   'differentiable')
     tf.logging.info('[%s] Adding %s fixed feature "%s"', comp.name,
                     differentiable_or_constant, feature_spec.name)
+
+    if during_training and feature_spec.dropout_id >= 0:
+      ids[channel], weights[channel] = network_units.apply_feature_id_dropout(
+          ids[channel], weights[channel], feature_spec)
+
     size = stride * num_steps * feature_spec.size
     fixed_embedding = network_units.embedding_lookup(
         comp.get_variable(network_units.fixed_embeddings_name(channel)),

@@ -105,16 +110,22 @@ def fetch_differentiable_fixed_embeddings(comp, state, stride, during_training):
   return state.handle, fixed_embeddings


-def fetch_fast_fixed_embeddings(comp, state):
+def fetch_fast_fixed_embeddings(comp,
+                                state,
+                                pad_to_batch=None,
+                                pad_to_steps=None):
   """Looks up fixed features with fast, non-differentiable, op.

   Since BulkFixedEmbeddings is non-differentiable with respect to the
   embeddings, the idea is to call this function only when the graph is
-  not being used for training.
+  not being used for training. If the function is being called with fixed step
+  and batch sizes, it will use the most efficient possible extractor.

   Args:
     comp: Component whose fixed features we wish to look up.
     state: live MasterState object for the component.
+    pad_to_batch: Optional; the number of batch elements to pad to.
+    pad_to_steps: Optional; the number of steps to pad to.

   Returns:
     state handle: updated state handle to be used after this call

@@ -126,19 +137,50 @@ def fetch_fast_fixed_embeddings(comp, state):
     return state.handle, []
   tf.logging.info('[%s] Adding %d fast fixed features', comp.name,
                   num_channels)

-  state.handle, bulk_embeddings, _ = dragnn_ops.bulk_fixed_embeddings(
-      state.handle, [
-          comp.get_variable(network_units.fixed_embeddings_name(c))
-          for c in range(num_channels)
-      ],
-      component=comp.name)
+  features = [
+      comp.get_variable(network_units.fixed_embeddings_name(c))
+      for c in range(num_channels)
+  ]
+  if pad_to_batch is not None and pad_to_steps is not None:
+    # If we have fixed padding numbers, we can use 'bulk_embed_fixed_features',
+    # which is the fastest embedding extractor.
+    state.handle, bulk_embeddings, _ = dragnn_ops.bulk_embed_fixed_features(
+        state.handle,
+        features,
+        component=comp.name,
+        pad_to_batch=pad_to_batch,
+        pad_to_steps=pad_to_steps)
+  else:
+    state.handle, bulk_embeddings, _ = dragnn_ops.bulk_fixed_embeddings(
+        state.handle, features, component=comp.name)

-  bulk_embeddings = network_units.NamedTensor(bulk_embeddings,
-                                              'bulk-%s-fixed-features' %
-                                              comp.name)
+  bulk_embeddings = network_units.NamedTensor(
+      bulk_embeddings, 'bulk-%s-fixed-features' % comp.name)
   return state.handle, [bulk_embeddings]


+def fetch_dense_ragged_embeddings(comp, state):
+  """Gets embeddings in RaggedTensor format."""
+  _validate_embedded_fixed_features(comp)
+  num_channels = len(comp.spec.fixed_feature)
+  if not num_channels:
+    return state.handle, []
+
+  tf.logging.info('[%s] Adding %d fast fixed features', comp.name,
+                  num_channels)
+  features = [
+      comp.get_variable(network_units.fixed_embeddings_name(c))
+      for c in range(num_channels)
+  ]
+  state.handle, data, offsets = dragnn_ops.bulk_embed_dense_fixed_features(
+      state.handle, features, component=comp.name)
+
+  data = network_units.NamedTensor(data, 'dense-%s-data' % comp.name)
+  offsets = network_units.NamedTensor(offsets, 'dense-%s-offsets' % comp.name)
+  return state.handle, [data, offsets]
+
+
 def extract_fixed_feature_ids(comp, state, stride):
   """Extracts fixed feature IDs.

@@ -194,8 +236,10 @@ def update_network_states(comp, tensors, network_states, stride):
   with tf.name_scope(comp.name + '/stored_act'):
     for index, network_tensor in enumerate(tensors):
       network_state.activations[comp.network.layers[index].name] = (
-          network_units.StoredActivations(tensor=network_tensor, stride=stride,
-                                          dim=comp.network.layers[index].dim))
+          network_units.StoredActivations(
+              tensor=network_tensor,
+              stride=stride,
+              dim=comp.network.layers[index].dim))


 def build_cross_entropy_loss(logits, gold):

@@ -205,7 +249,7 @@ def build_cross_entropy_loss(logits, gold):
   Args:
     logits: float Tensor of scores.
-    gold: int Tensor of one-hot labels.
+    gold: int Tensor of gold label ids.

   Returns:
     cost, correct, total: the total cost, the total number of correctly

@@ -251,9 +295,10 @@ class BulkFeatureExtractorComponentBuilder(component.ComponentBuilderBase):
     """
     logging.info('Building component: %s', self.spec.name)
     stride = state.current_batch_size * self.training_beam_size
+    self.network.pre_create(stride)
     with tf.variable_scope(self.name, reuse=True):
       state.handle, fixed_embeddings = fetch_differentiable_fixed_embeddings(
-          self, state, stride)
+          self, state, stride, True)
       linked_embeddings = [
           fetch_linked_embedding(self, network_states, spec)

@@ -307,14 +352,29 @@ class BulkFeatureExtractorComponentBuilder(component.ComponentBuilderBase):
       stride = state.current_batch_size * self.training_beam_size
     else:
       stride = state.current_batch_size * self.inference_beam_size
+    self.network.pre_create(stride)
     with tf.variable_scope(self.name, reuse=True):
       if during_training:
         state.handle, fixed_embeddings = fetch_differentiable_fixed_embeddings(
-            self, state, stride)
+            self, state, stride, during_training)
       else:
-        state.handle, fixed_embeddings = fetch_fast_fixed_embeddings(self,
-                                                                     state)
+        if 'use_densors' in self.spec.network_unit.parameters:
+          state.handle, fixed_embeddings = fetch_dense_ragged_embeddings(
+              self, state)
+        else:
+          if ('padded_batch_size' in self.spec.network_unit.parameters and
+              'padded_sentence_length' in self.spec.network_unit.parameters):
+            state.handle, fixed_embeddings = fetch_fast_fixed_embeddings(
+                self,
+                state,
+                pad_to_batch=-1,
+                pad_to_steps=int(self.spec.network_unit
+                                 .parameters['padded_sentence_length']))
+          else:
+            state.handle, fixed_embeddings = fetch_fast_fixed_embeddings(
+                self, state)
       linked_embeddings = [
           fetch_linked_embedding(self, network_states, spec)

@@ -331,6 +391,7 @@ class BulkFeatureExtractorComponentBuilder(component.ComponentBuilderBase):
           stride=stride)
       update_network_states(self, tensors, network_states, stride)

+    self._add_runtime_hooks()
     return state.handle

@@ -367,7 +428,9 @@ class BulkFeatureIdExtractorComponentBuilder(component.ComponentBuilderBase):
   def build_greedy_inference(self, state, network_states,
                              during_training=False):
     """See base class."""
-    return self._extract_feature_ids(state, network_states, during_training)
+    handle = self._extract_feature_ids(state, network_states, during_training)
+    self._add_runtime_hooks()
+    return handle

   def _extract_feature_ids(self, state, network_states, during_training):
     """Extracts feature IDs and advances a batch using the oracle path.

@@ -387,6 +450,7 @@ class BulkFeatureIdExtractorComponentBuilder(component.ComponentBuilderBase):
      stride = state.current_batch_size * self.training_beam_size
     else:
       stride = state.current_batch_size * self.inference_beam_size
+    self.network.pre_create(stride)
     with tf.variable_scope(self.name, reuse=True):
       state.handle, ids = extract_fixed_feature_ids(self, state, stride)

@@ -438,17 +502,21 @@ class BulkAnnotatorComponentBuilder(component.ComponentBuilderBase):
     ]
     stride = state.current_batch_size * self.training_beam_size
+    self.network.pre_create(stride)
     with tf.variable_scope(self.name, reuse=True):
       network_tensors = self.network.create([], linked_embeddings, None, None,
                                             True, stride)
       update_network_states(self, network_tensors, network_states, stride)
-      logits = self.network.get_logits(network_tensors)
       state.handle, gold = dragnn_ops.bulk_advance_from_oracle(
           state.handle, component=self.name)
-      cost, correct, total = build_cross_entropy_loss(logits, gold)
+      cost, correct, total = self.network.compute_bulk_loss(
+          stride, network_tensors, gold)
+      if cost is None:
+        # The network does not have a custom bulk loss; default to softmax.
+        logits = self.network.get_logits(network_tensors)
+        cost, correct, total = build_cross_entropy_loss(logits, gold)
       cost = self.add_regularizer(cost)
       return state.handle, cost, correct, total

@@ -483,13 +551,24 @@ class BulkAnnotatorComponentBuilder(component.ComponentBuilderBase):
       stride = state.current_batch_size * self.training_beam_size
     else:
       stride = state.current_batch_size * self.inference_beam_size
+    self.network.pre_create(stride)
     with tf.variable_scope(self.name, reuse=True):
       network_tensors = self.network.create(
           [], linked_embeddings, None, None,
           during_training, stride)
       update_network_states(self, network_tensors, network_states, stride)
-      logits = self.network.get_logits(network_tensors)
-      return dragnn_ops.bulk_advance_from_prediction(
-          state.handle, logits, component=self.name)
+      logits = self.network.get_bulk_predictions(stride, network_tensors)
+      if logits is None:
+        # The network does not produce custom bulk predictions; default to logits.
+        logits = self.network.get_logits(network_tensors)
+        logits = tf.cond(self.locally_normalize,
+                         lambda: tf.nn.log_softmax(logits), lambda: logits)
+        if self._output_as_probabilities:
+          logits = tf.nn.softmax(logits)
+      handle = dragnn_ops.bulk_advance_from_prediction(
+          state.handle, logits, component=self.name)
+    self._add_runtime_hooks()
+    return handle
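Note: the inference path now dispatches on string-valued network_unit parameters. A spec opts into the padded fast extractor by supplying both keys, e.g. (a hedged sketch of the plumbing the new branch implies; the values shown are illustrative):

from dragnn.protos import spec_pb2

component_spec = spec_pb2.ComponentSpec()
component_spec.network_unit.parameters['padded_batch_size'] = '-1'
component_spec.network_unit.parameters['padded_sentence_length'] = '128'
# With both keys present, build_greedy_inference calls
# fetch_fast_fixed_embeddings(..., pad_to_batch=-1, pad_to_steps=128),
# which selects the faster bulk_embed_fixed_features extractor.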
research/syntaxnet/dragnn/python/bulk_component_test.py
View file @
80178fc6
...
@@ -41,8 +41,6 @@ from dragnn.python import dragnn_ops
...
@@ -41,8 +41,6 @@ from dragnn.python import dragnn_ops
from
dragnn.python
import
network_units
from
dragnn.python
import
network_units
from
syntaxnet
import
sentence_pb2
from
syntaxnet
import
sentence_pb2
FLAGS
=
tf
.
app
.
flags
.
FLAGS
class
MockNetworkUnit
(
object
):
class
MockNetworkUnit
(
object
):
...
@@ -63,6 +61,7 @@ class MockMaster(object):
...
@@ -63,6 +61,7 @@ class MockMaster(object):
self
.
spec
=
spec_pb2
.
MasterSpec
()
self
.
spec
=
spec_pb2
.
MasterSpec
()
self
.
hyperparams
=
spec_pb2
.
GridPoint
()
self
.
hyperparams
=
spec_pb2
.
GridPoint
()
self
.
lookup_component
=
{
'mock'
:
MockComponent
()}
self
.
lookup_component
=
{
'mock'
:
MockComponent
()}
self
.
build_runtime_graph
=
False
def
_create_fake_corpus
():
def
_create_fake_corpus
():
...
@@ -84,9 +83,12 @@ def _create_fake_corpus():
...
@@ -84,9 +83,12 @@ def _create_fake_corpus():
class
BulkComponentTest
(
test_util
.
TensorFlowTestCase
):
class
BulkComponentTest
(
test_util
.
TensorFlowTestCase
):
def
setUp
(
self
):
def
setUp
(
self
):
# Clear the graph and all existing variables. Otherwise, variables created
# in different tests may collide with each other.
tf
.
reset_default_graph
()
self
.
master
=
MockMaster
()
self
.
master
=
MockMaster
()
self
.
master_state
=
component
.
MasterState
(
self
.
master_state
=
component
.
MasterState
(
handle
=
'handle'
,
current_batch_size
=
2
)
handle
=
tf
.
constant
([
'foo'
,
'bar'
])
,
current_batch_size
=
2
)
self
.
network_states
=
{
self
.
network_states
=
{
'mock'
:
component
.
NetworkState
(),
'mock'
:
component
.
NetworkState
(),
'test'
:
component
.
NetworkState
(),
'test'
:
component
.
NetworkState
(),
...
@@ -107,22 +109,21 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
...
@@ -107,22 +109,21 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
"""
,
component_spec
)
"""
,
component_spec
)
# For feature extraction:
# For feature extraction:
with
tf
.
Graph
().
as_default
():
comp
=
bulk_component
.
BulkFeatureExtractorComponentBuilder
(
comp
=
bulk_component
.
BulkFeatureExtractorComponentBuilder
(
self
.
master
,
component_spec
)
self
.
master
,
component_spec
)
# Expect feature extraction to generate a error due to the "history"
# Expect feature extraction to generate a error due to the "history"
# translator.
# translator.
with
self
.
assertRaises
(
NotImplementedError
):
with
self
.
assertRaises
(
NotImplementedError
):
comp
.
build_greedy_training
(
self
.
master_state
,
self
.
network_states
)
comp
.
build_greedy_training
(
self
.
master_state
,
self
.
network_states
)
# As well as annotation:
# As well as annotation:
with
tf
.
Graph
().
as_default
()
:
self
.
setUp
()
comp
=
bulk_component
.
BulkAnnotatorComponentBuilder
(
comp
=
bulk_component
.
BulkAnnotatorComponentBuilder
(
self
.
master
,
self
.
master
,
component_spec
)
component_spec
)
with
self
.
assertRaises
(
NotImplementedError
):
with
self
.
assertRaises
(
NotImplementedError
):
comp
.
build_greedy_training
(
self
.
master_state
,
self
.
network_states
)
comp
.
build_greedy_training
(
self
.
master_state
,
self
.
network_states
)
def
testFailsOnRecurrentLinkedFeature
(
self
):
def
testFailsOnRecurrentLinkedFeature
(
self
):
component_spec
=
spec_pb2
.
ComponentSpec
()
component_spec
=
spec_pb2
.
ComponentSpec
()
...
@@ -143,22 +144,21 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
...
@@ -143,22 +144,21 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
"""
,
component_spec
)
"""
,
component_spec
)
# For feature extraction:
# For feature extraction:
with
tf
.
Graph
().
as_default
():
comp
=
bulk_component
.
BulkFeatureExtractorComponentBuilder
(
comp
=
bulk_component
.
BulkFeatureExtractorComponentBuilder
(
self
.
master
,
component_spec
)
self
.
master
,
component_spec
)
# Expect feature extraction to generate a error due to the "history"
# Expect feature extraction to generate a error due to the "history"
# translator.
# translator.
with
self
.
assertRaises
(
RuntimeError
):
with
self
.
assertRaises
(
RuntimeError
):
comp
.
build_greedy_training
(
self
.
master_state
,
self
.
network_states
)
comp
.
build_greedy_training
(
self
.
master_state
,
self
.
network_states
)
# As well as annotation:
# As well as annotation:
with
tf
.
Graph
().
as_default
()
:
self
.
setUp
()
comp
=
bulk_component
.
BulkAnnotatorComponentBuilder
(
comp
=
bulk_component
.
BulkAnnotatorComponentBuilder
(
self
.
master
,
self
.
master
,
component_spec
)
component_spec
)
with
self
.
assertRaises
(
RuntimeError
):
with
self
.
assertRaises
(
RuntimeError
):
comp
.
build_greedy_training
(
self
.
master_state
,
self
.
network_states
)
comp
.
build_greedy_training
(
self
.
master_state
,
self
.
network_states
)
def
testConstantFixedFeatureFailsIfNotPretrained
(
self
):
def
testConstantFixedFeatureFailsIfNotPretrained
(
self
):
component_spec
=
spec_pb2
.
ComponentSpec
()
component_spec
=
spec_pb2
.
ComponentSpec
()
...
@@ -175,21 +175,20 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
...
@@ -175,21 +175,20 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
registered_name: "bulk_component.BulkFeatureExtractorComponentBuilder"
registered_name: "bulk_component.BulkFeatureExtractorComponentBuilder"
}
}
"""
,
component_spec
)
"""
,
component_spec
)
with
tf
.
Graph
().
as_default
():
comp
=
bulk_component
.
BulkFeatureExtractorComponentBuilder
(
comp
=
bulk_component
.
BulkFeatureExtractorComponentBuilder
(
self
.
master
,
component_spec
)
self
.
master
,
component_spec
)
with
self
.
assertRaisesRegexp
(
ValueError
,
with
self
.
assertRaisesRegexp
(
ValueError
,
'Constant embeddings must be pretrained'
):
'Constant embeddings must be pretrained'
):
comp
.
build_greedy_training
(
self
.
master_state
,
self
.
network_states
)
comp
.
build_greedy_training
(
self
.
master_state
,
self
.
network_states
)
with
self
.
assertRaisesRegexp
(
ValueError
,
with
self
.
assertRaisesRegexp
(
ValueError
,
'Constant embeddings must be pretrained'
):
'Constant embeddings must be pretrained'
):
comp
.
build_greedy_inference
(
comp
.
build_greedy_inference
(
self
.
master_state
,
self
.
network_states
,
during_training
=
True
)
self
.
master_state
,
self
.
network_states
,
during_training
=
True
)
with
self
.
assertRaisesRegexp
(
ValueError
,
with
self
.
assertRaisesRegexp
(
ValueError
,
'Constant embeddings must be pretrained'
):
'Constant embeddings must be pretrained'
):
comp
.
build_greedy_inference
(
comp
.
build_greedy_inference
(
self
.
master_state
,
self
.
network_states
,
during_training
=
False
)
self
.
master_state
,
self
.
network_states
,
during_training
=
False
)
def
testNormalFixedFeaturesAreDifferentiable
(
self
):
def
testNormalFixedFeaturesAreDifferentiable
(
self
):
component_spec
=
spec_pb2
.
ComponentSpec
()
component_spec
=
spec_pb2
.
ComponentSpec
()
...
@@ -207,25 +206,24 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
...
@@ -207,25 +206,24 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
registered_name: "bulk_component.BulkFeatureExtractorComponentBuilder"
registered_name: "bulk_component.BulkFeatureExtractorComponentBuilder"
}
}
"""
,
component_spec
)
"""
,
component_spec
)
with
tf
.
Graph
().
as_default
():
comp
=
bulk_component
.
BulkFeatureExtractorComponentBuilder
(
comp
=
bulk_component
.
BulkFeatureExtractorComponentBuilder
(
self
.
master
,
component_spec
)
self
.
master
,
component_spec
)
# Get embedding matrix variables.
# Get embedding matrix variables.
with
tf
.
variable_scope
(
comp
.
name
,
reuse
=
True
):
with
tf
.
variable_scope
(
comp
.
name
,
reuse
=
True
):
fixed_embedding_matrix
=
tf
.
get_variable
(
fixed_embedding_matrix
=
tf
.
get_variable
(
network_units
.
fixed_embeddings_name
(
0
))
network_units
.
fixed_embeddings_name
(
0
))
# Get output layer.
# Get output layer.
comp
.
build_greedy_training
(
self
.
master_state
,
self
.
network_states
)
comp
.
build_greedy_training
(
self
.
master_state
,
self
.
network_states
)
activations
=
self
.
network_states
[
comp
.
name
].
activations
activations
=
self
.
network_states
[
comp
.
name
].
activations
outputs
=
activations
[
comp
.
network
.
layers
[
0
].
name
].
bulk_tensor
outputs
=
activations
[
comp
.
network
.
layers
[
0
].
name
].
bulk_tensor
# Compute the gradient of the output layer w.r.t. the embedding matrix.
# Compute the gradient of the output layer w.r.t. the embedding matrix.
# This should be well-defined for in the normal case.
# This should be well-defined for in the normal case.
gradients
=
tf
.
gradients
(
outputs
,
fixed_embedding_matrix
)
gradients
=
tf
.
gradients
(
outputs
,
fixed_embedding_matrix
)
self
.
assertEqual
(
len
(
gradients
),
1
)
self
.
assertEqual
(
len
(
gradients
),
1
)
self
.
assertFalse
(
gradients
[
0
]
is
None
)
self
.
assertFalse
(
gradients
[
0
]
is
None
)
def
testConstantFixedFeaturesAreNotDifferentiableButOthersAre
(
self
):
def
testConstantFixedFeaturesAreNotDifferentiableButOthersAre
(
self
):
component_spec
=
spec_pb2
.
ComponentSpec
()
component_spec
=
spec_pb2
.
ComponentSpec
()
...
@@ -249,31 +247,30 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
...
@@ -249,31 +247,30 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
registered_name: "bulk_component.BulkFeatureExtractorComponentBuilder"
registered_name: "bulk_component.BulkFeatureExtractorComponentBuilder"
}
}
"""
,
component_spec
)
"""
,
component_spec
)
with
tf
.
Graph
().
as_default
():
comp
=
bulk_component
.
BulkFeatureExtractorComponentBuilder
(
comp
=
bulk_component
.
BulkFeatureExtractorComponentBuilder
(
self
.
master
,
component_spec
)
self
.
master
,
component_spec
)
# Get embedding matrix variables.
# Get embedding matrix variables.
with
tf
.
variable_scope
(
comp
.
name
,
reuse
=
True
):
with
tf
.
variable_scope
(
comp
.
name
,
reuse
=
True
):
constant_embedding_matrix
=
tf
.
get_variable
(
constant_embedding_matrix
=
tf
.
get_variable
(
network_units
.
fixed_embeddings_name
(
0
))
network_units
.
fixed_embeddings_name
(
0
))
trainable_embedding_matrix
=
tf
.
get_variable
(
trainable_embedding_matrix
=
tf
.
get_variable
(
network_units
.
fixed_embeddings_name
(
1
))
network_units
.
fixed_embeddings_name
(
1
))
# Get output layer.
# Get output layer.
comp
.
build_greedy_training
(
self
.
master_state
,
self
.
network_states
)
comp
.
build_greedy_training
(
self
.
master_state
,
self
.
network_states
)
activations
=
self
.
network_states
[
comp
.
name
].
activations
activations
=
self
.
network_states
[
comp
.
name
].
activations
outputs
=
activations
[
comp
.
network
.
layers
[
0
].
name
].
bulk_tensor
outputs
=
activations
[
comp
.
network
.
layers
[
0
].
name
].
bulk_tensor
# The constant embeddings are non-differentiable.
# The constant embeddings are non-differentiable.
constant_gradients
=
tf
.
gradients
(
outputs
,
constant_embedding_matrix
)
constant_gradients
=
tf
.
gradients
(
outputs
,
constant_embedding_matrix
)
self
.
assertEqual
(
len
(
constant_gradients
),
1
)
self
.
assertEqual
(
len
(
constant_gradients
),
1
)
self
.
assertTrue
(
constant_gradients
[
0
]
is
None
)
self
.
assertTrue
(
constant_gradients
[
0
]
is
None
)
# The trainable embeddings are differentiable.
# The trainable embeddings are differentiable.
trainable_gradients
=
tf
.
gradients
(
outputs
,
trainable_embedding_matrix
)
trainable_gradients
=
tf
.
gradients
(
outputs
,
trainable_embedding_matrix
)
self
.
assertEqual
(
len
(
trainable_gradients
),
1
)
self
.
assertEqual
(
len
(
trainable_gradients
),
1
)
self
.
assertFalse
(
trainable_gradients
[
0
]
is
None
)
self
.
assertFalse
(
trainable_gradients
[
0
]
is
None
)
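[Editor's note] The None gradient asserted above comes from tf.stop_gradient(), which the component builders apply to constant fixed features (see the fixed_feature_lookup hunk in component.py below). A minimal standalone sketch of that behavior, not part of this commit:

import tensorflow as tf

# tf.stop_gradient() severs the backward path, so tf.gradients() reports None
# for the severed variable, exactly as the test above expects.
v = tf.get_variable('v', [3], initializer=tf.ones_initializer())
blocked = tf.reduce_sum(tf.stop_gradient(v))
flowing = tf.reduce_sum(v)
print(tf.gradients(blocked, v))  # [None]
print(tf.gradients(flowing, v))  # [<gradient tensor>]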
  def testFailsOnFixedFeature(self):
    component_spec = spec_pb2.ComponentSpec()
    ...
@@ -306,15 +303,14 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
        name: "fixed" embedding_dim: -1 size: 1
      }
    """, component_spec)
    with tf.Graph().as_default():
      comp = bulk_component.BulkFeatureIdExtractorComponentBuilder(
          self.master, component_spec)
      # Should not raise errors.
      self.network_states[component_spec.name] = component.NetworkState()
      comp.build_greedy_training(self.master_state, self.network_states)

      self.network_states[component_spec.name] = component.NetworkState()
      comp.build_greedy_inference(self.master_state, self.network_states)
  def testBulkFeatureIdExtractorFailsOnLinkedFeature(self):
    component_spec = spec_pb2.ComponentSpec()
    ...
@@ -332,10 +328,9 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
        source_component: "mock"
      }
    """, component_spec)
    with tf.Graph().as_default():
      with self.assertRaises(ValueError):
        unused_comp = bulk_component.BulkFeatureIdExtractorComponentBuilder(
            self.master, component_spec)
  def testBulkFeatureIdExtractorOkWithMultipleFixedFeatures(self):
    component_spec = spec_pb2.ComponentSpec()
    ...
@@ -354,15 +349,14 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
        name: "fixed3" embedding_dim: -1 size: 1
      }
    """, component_spec)
    with tf.Graph().as_default():
      comp = bulk_component.BulkFeatureIdExtractorComponentBuilder(
          self.master, component_spec)
      # Should not raise errors.
      self.network_states[component_spec.name] = component.NetworkState()
      comp.build_greedy_training(self.master_state, self.network_states)

      self.network_states[component_spec.name] = component.NetworkState()
      comp.build_greedy_inference(self.master_state, self.network_states)
  def testBulkFeatureIdExtractorFailsOnEmbeddedFixedFeature(self):
    component_spec = spec_pb2.ComponentSpec()
    ...
@@ -375,10 +369,9 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
        name: "fixed" embedding_dim: 2 size: 1
      }
    """, component_spec)
    with tf.Graph().as_default():
      with self.assertRaises(ValueError):
        unused_comp = bulk_component.BulkFeatureIdExtractorComponentBuilder(
            self.master, component_spec)
  def testBulkFeatureIdExtractorExtractFocusWithOffset(self):
    path = os.path.join(tf.test.get_temp_dir(), 'label-map')
    ...
@@ -420,67 +413,131 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
      }
    """ % path, master_spec)
    with tf.Graph().as_default():
      corpus = _create_fake_corpus()
      corpus = tf.constant(corpus, shape=[len(corpus)])
      handle = dragnn_ops.get_session(
          container='test',
          master_spec=master_spec.SerializeToString(),
          grid_point='')
      handle = dragnn_ops.attach_data_reader(handle, corpus)
      handle = dragnn_ops.init_component_data(
          handle, beam_size=1, component='test')
      batch_size = dragnn_ops.batch_size(handle, component='test')
      master_state = component.MasterState(handle, batch_size)

      extractor = bulk_component.BulkFeatureIdExtractorComponentBuilder(
          self.master, master_spec.component[0])
      network_state = component.NetworkState()
      self.network_states['test'] = network_state
-      handle = extractor.build_greedy_inference(master_state, self.network_states)
      handle = extractor.build_greedy_inference(master_state,
                                                self.network_states)
      focus1 = network_state.activations['focus1'].bulk_tensor
      focus2 = network_state.activations['focus2'].bulk_tensor
      focus3 = network_state.activations['focus3'].bulk_tensor

      with self.test_session() as sess:
        focus1, focus2, focus3 = sess.run([focus1, focus2, focus3])
        tf.logging.info('focus1=\n%s', focus1)
        tf.logging.info('focus2=\n%s', focus2)
        tf.logging.info('focus3=\n%s', focus3)

-        self.assertAllEqual(focus1, [[0], [-1], [-1], [-1],
-                                     [0], [1], [-1], [-1],
-                                     [0], [1], [2], [-1],
-                                     [0], [1], [2], [3]])
        self.assertAllEqual(
            focus1,
            [[0], [-1], [-1], [-1],
             [0], [1], [-1], [-1],
             [0], [1], [2], [-1],
             [0], [1], [2], [3]])  # pyformat: disable

-        self.assertAllEqual(focus2, [[-1], [-1], [-1], [-1],
-                                     [1], [-1], [-1], [-1],
-                                     [1], [2], [-1], [-1],
-                                     [1], [2], [3], [-1]])
        self.assertAllEqual(
            focus2,
            [[-1], [-1], [-1], [-1],
             [1], [-1], [-1], [-1],
             [1], [2], [-1], [-1],
             [1], [2], [3], [-1]])  # pyformat: disable

-        self.assertAllEqual(focus3, [[-1], [-1], [-1], [-1],
-                                     [-1], [-1], [-1], [-1],
-                                     [2], [-1], [-1], [-1],
-                                     [2], [3], [-1], [-1]])
        self.assertAllEqual(
            focus3,
            [[-1], [-1], [-1], [-1],
             [-1], [-1], [-1], [-1],
             [2], [-1], [-1], [-1],
             [2], [3], [-1], [-1]])  # pyformat: disable
  def testBuildLossFailsOnNoExamples(self):
    with tf.Graph().as_default():
      logits = tf.constant([[0.5], [-0.5], [0.5], [-0.5]])
      gold = tf.constant([-1, -1, -1, -1])
      result = bulk_component.build_cross_entropy_loss(logits, gold)

      # Expect loss computation to generate a runtime error due to the gold
      # tensor containing no valid examples.
      with self.test_session() as sess:
        with self.assertRaises(tf.errors.InvalidArgumentError):
          sess.run(result)
  def testPreCreateCalledBeforeCreate(self):
    component_spec = spec_pb2.ComponentSpec()
    text_format.Parse("""
        name: "test"
        network_unit {
          registered_name: "IdentityNetwork"
        }
        """, component_spec)

    class AssertPreCreateBeforeCreateNetwork(network_units.NetworkUnitInterface):
      """Mock that asserts that .pre_create() is called before .create()."""

      def __init__(self, comp, test_fixture):
        super(AssertPreCreateBeforeCreateNetwork, self).__init__(comp)
        self._test_fixture = test_fixture
        self._pre_create_called = False

      def get_logits(self, network_tensors):
        return tf.zeros([2, 1], dtype=tf.float32)

      def pre_create(self, *unused_args):
        self._pre_create_called = True

      def create(self, *unused_args, **unused_kwargs):
        self._test_fixture.assertTrue(self._pre_create_called)
        return []

    builder = bulk_component.BulkFeatureExtractorComponentBuilder(
        self.master, component_spec)
    builder.network = AssertPreCreateBeforeCreateNetwork(builder, self)
    builder.build_greedy_training(
        component.MasterState(['foo', 'bar'], 2), self.network_states)

    self.setUp()
    builder = bulk_component.BulkFeatureExtractorComponentBuilder(
        self.master, component_spec)
    builder.network = AssertPreCreateBeforeCreateNetwork(builder, self)
    builder.build_greedy_inference(
        component.MasterState(['foo', 'bar'], 2), self.network_states)

    self.setUp()
    builder = bulk_component.BulkFeatureIdExtractorComponentBuilder(
        self.master, component_spec)
    builder.network = AssertPreCreateBeforeCreateNetwork(builder, self)
    builder.build_greedy_training(
        component.MasterState(['foo', 'bar'], 2), self.network_states)

    self.setUp()
    builder = bulk_component.BulkFeatureIdExtractorComponentBuilder(
        self.master, component_spec)
    builder.network = AssertPreCreateBeforeCreateNetwork(builder, self)
    builder.build_greedy_inference(
        component.MasterState(['foo', 'bar'], 2), self.network_states)

    self.setUp()
    builder = bulk_component.BulkAnnotatorComponentBuilder(
        self.master, component_spec)
    builder.network = AssertPreCreateBeforeCreateNetwork(builder, self)
    builder.build_greedy_training(
        component.MasterState(['foo', 'bar'], 2), self.network_states)

    self.setUp()
    builder = bulk_component.BulkAnnotatorComponentBuilder(
        self.master, component_spec)
    builder.network = AssertPreCreateBeforeCreateNetwork(builder, self)
    builder.build_greedy_inference(
        component.MasterState(['foo', 'bar'], 2), self.network_states)
if __name__ == '__main__':
  googletest.main()
research/syntaxnet/dragnn/python/component.py
View file @ 80178fc6
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Builds a DRAGNN graph for local training."""
from abc import ABCMeta
...
@@ -21,12 +20,79 @@ from abc import abstractmethod
import tensorflow as tf
from tensorflow.python.platform import tf_logging as logging

from dragnn.protos import export_pb2
from dragnn.python import dragnn_ops
from dragnn.python import network_units
from dragnn.python import runtime_support
from syntaxnet.util import check
from syntaxnet.util import registry
def build_softmax_cross_entropy_loss(logits, gold):
  """Builds softmax cross entropy loss."""
  # A gold label > -1 determines that the sentence is still
  # in a valid state. Otherwise, the sentence has ended.
  #
  # We add only the valid sentences to the loss, in the following way:
  #   1. We compute 'valid_ix', the indices in gold that contain
  #      valid oracle actions.
  #   2. We compute the cost function by comparing logits and gold
  #      only for the valid indices.
  valid = tf.greater(gold, -1)
  valid_ix = tf.reshape(tf.where(valid), [-1])
  valid_gold = tf.gather(gold, valid_ix)
  valid_logits = tf.gather(logits, valid_ix)
  cost = tf.reduce_sum(
      tf.nn.sparse_softmax_cross_entropy_with_logits(
          labels=tf.cast(valid_gold, tf.int64),
          logits=valid_logits,
          name='sparse_softmax_cross_entropy_with_logits'))
  correct = tf.reduce_sum(
      tf.to_int32(tf.nn.in_top_k(valid_logits, valid_gold, 1)))
  total = tf.size(valid_gold)
  return cost, correct, total, valid_logits, valid_gold


def build_sigmoid_cross_entropy_loss(logits, gold, indices, probs):
  """Builds sigmoid cross entropy loss."""
  # Filter out entries where gold <= -1, which are batch padding entries.
  valid = tf.greater(gold, -1)
  valid_ix = tf.reshape(tf.where(valid), [-1])
  valid_gold = tf.gather(gold, valid_ix)
  valid_indices = tf.gather(indices, valid_ix)
  valid_probs = tf.gather(probs, valid_ix)

  # NB: tf.gather_nd() raises an error on CPU for out-of-bounds indices. That's
  # why we need to filter out the gold=-1 batch padding above.
  valid_pairs = tf.stack([valid_indices, valid_gold], axis=1)
  valid_logits = tf.gather_nd(logits, valid_pairs)

  cost = tf.reduce_sum(
      tf.nn.sigmoid_cross_entropy_with_logits(
          labels=valid_probs,
          logits=valid_logits,
          name='sigmoid_cross_entropy_with_logits'))
  gold_bool = valid_probs > 0.5
  predicted_bool = valid_logits > 0.0
  total = tf.size(gold_bool)
  with tf.control_dependencies([
      tf.assert_equal(
          total, tf.size(predicted_bool), name='num_predicted_gold_mismatch')
  ]):
    agreement_bool = tf.logical_not(tf.logical_xor(gold_bool, predicted_bool))
  correct = tf.reduce_sum(tf.to_int32(agreement_bool))
  cost.set_shape([])
  correct.set_shape([])
  total.set_shape([])
  return cost, correct, total, gold
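[Editor's note] As a quick illustration of the gold > -1 masking convention in build_softmax_cross_entropy_loss(), here is a minimal sketch (not part of the commit) with one padded entry; the numbers mirror testSoftmaxCrossEntropyLoss in the new component_test.py below:

import tensorflow as tf
from dragnn.python import component

logits = tf.constant([[0.0, 2.0, -1.0],
                      [-5.0, 1.0, -1.0],  # gold == -1, so this row is dropped
                      [3.0, 1.0, -2.0]])
gold = tf.constant([1, -1, 1])
cost, correct, total, valid_logits, valid_gold = (
    component.build_softmax_cross_entropy_loss(logits, gold))
with tf.Session() as sess:
  print(sess.run([cost, correct, total]))  # ~2.3027, 1, 2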
class NetworkState(object):
  """Simple utility to manage the state of a DRAGNN network.
...
@@ -69,6 +135,13 @@ class ComponentBuilderBase(object):
  As part of the specification, ComponentBuilder will wrap an underlying
  NetworkUnit which generates the actual network layout.

  Attributes:
    master: dragnn.MasterBuilder that owns this component.
    num_actions: Number of actions in the transition system.
    name: Name of this component.
    spec: dragnn.ComponentSpec that configures this component.
    moving_average: True if moving-average parameters should be used.
  """
  __metaclass__ = ABCMeta  # required for @abstractmethod
...
@@ -96,16 +169,23 @@ class ComponentBuilderBase(object):
    # Extract component attributes before make_network(), so the network unit
    # can access them.
    self._attrs = {}
-    if attr_defaults:
-      self._attrs = network_units.get_attrs_with_defaults(
-          self.spec.component_builder.parameters, attr_defaults)
    global_attr_defaults = {
        'locally_normalize': False,
        'output_as_probabilities': False
    }
    if attr_defaults:
      global_attr_defaults.update(attr_defaults)
    self._attrs = network_units.get_attrs_with_defaults(
        self.spec.component_builder.parameters, global_attr_defaults)

    do_local_norm = self._attrs['locally_normalize']
    self._output_as_probabilities = self._attrs['output_as_probabilities']

    with tf.variable_scope(self.name):
      self.training_beam_size = tf.constant(
          self.spec.training_beam_size, name='TrainingBeamSize')
      self.inference_beam_size = tf.constant(
          self.spec.inference_beam_size, name='InferenceBeamSize')
-      self.locally_normalize = tf.constant(False, name='LocallyNormalize')
      self.locally_normalize = tf.constant(do_local_norm, name='LocallyNormalize')
      self._step = tf.get_variable(
          'step', [], initializer=tf.zeros_initializer(), dtype=tf.int32)
      self._total = tf.get_variable(
...
@@ -120,6 +200,9 @@ class ComponentBuilderBase(object):
          decay=self.master.hyperparams.average_weight, num_updates=self._step)
      self.avg_ops = [self.moving_average.apply(self.network.params)]

    # Used to export the cell; see add_cell_input() and add_cell_output().
    self._cell_subgraph_spec = export_pb2.CellSubgraphSpec()

  def make_network(self, network_unit):
"""Makes a NetworkUnitInterface object based on the network_unit spec.
...
@@ -276,7 +359,7 @@ class ComponentBuilderBase(object):
...
@@ -276,7 +359,7 @@ class ComponentBuilderBase(object):
Returns:
Returns:
tf.Variable object corresponding to original or averaged version.
tf.Variable object corresponding to original or averaged version.
"""
"""
if
var_params
:
if
var_params
is
not
None
:
var_name
=
var_params
.
name
var_name
=
var_params
.
name
else
:
else
:
check
.
NotNone
(
var_name
,
'specify at least one of var_name or var_params'
)
check
.
NotNone
(
var_name
,
'specify at least one of var_name or var_params'
)
...
@@ -341,6 +424,79 @@ class ComponentBuilderBase(object):
    """Returns the value of the component attribute with the |name|."""
    return self._attrs[name]

  def has_attr(self, name):
    """Checks whether a component attribute with the given |name| exists.

    Arguments:
      name: attribute name

    Returns:
      'true' if the name exists and 'false' otherwise.
    """
    return name in self._attrs

  def _add_runtime_hooks(self):
    """Adds "hook" nodes to the graph for use by the runtime, if enabled.

    Does nothing if master.build_runtime_graph is False. Subclasses should call
    this at the end of build_*_inference(). For details on the runtime hooks,
    see runtime_support.py.
    """
    if self.master.build_runtime_graph:
      with tf.variable_scope(self.name, reuse=True):
        runtime_support.add_hooks(self, self._cell_subgraph_spec)
    self._cell_subgraph_spec = None  # prevent further exports

  def add_cell_input(self, dtype, shape, name, input_type='TYPE_FEATURE'):
    """Adds an input to the current CellSubgraphSpec.

    Creates a tf.placeholder() with the given |dtype| and |shape|, adds it as a
    cell input with the |name| and |input_type|, and returns the placeholder to
    be used in place of the actual input tensor.

    Args:
      dtype: DType of the cell input.
      shape: Static shape of the cell input.
      name: Logical name of the cell input.
      input_type: Name of the appropriate CellSubgraphSpec.Input.Type enum.

    Returns:
      A tensor to use in place of the actual input tensor.

    Raises:
      TypeError: If the |shape| is the wrong type.
      RuntimeError: If the cell has already been exported.
    """
    if not (isinstance(shape, list) and
            all(isinstance(dim, int) for dim in shape)):
      raise TypeError('shape must be a list of int')
    if not self._cell_subgraph_spec:
      raise RuntimeError('already exported a CellSubgraphSpec')
    with tf.name_scope(None):
      tensor = tf.placeholder(
          dtype, shape, name='{}/INPUT/{}'.format(self.name, name))
    self._cell_subgraph_spec.input.add(
        name=name,
        tensor=tensor.name,
        type=export_pb2.CellSubgraphSpec.Input.Type.Value(input_type))
    return tensor

  def add_cell_output(self, tensor, name):
    """Adds an output to the current CellSubgraphSpec.

    Args:
      tensor: Tensor to add as a cell output.
      name: Logical name of the cell output.

    Raises:
      RuntimeError: If the cell has already been exported.
    """
    if not self._cell_subgraph_spec:
      raise RuntimeError('already exported a CellSubgraphSpec')
    self._cell_subgraph_spec.output.add(name=name, tensor=tensor.name)
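[Editor's note] For reference, a minimal standalone sketch (not from this commit) of the bookkeeping these two helpers perform: record a placeholder as a cell input and a computed tensor as a cell output in a CellSubgraphSpec. The 'demo' names are illustrative assumptions:

import tensorflow as tf
from dragnn.protos import export_pb2

spec = export_pb2.CellSubgraphSpec()
# What add_cell_input() records: a placeholder standing in for the real input.
x = tf.placeholder(tf.float32, [1, 64], name='demo/INPUT/fixed')
spec.input.add(
    name='fixed',
    tensor=x.name,
    type=export_pb2.CellSubgraphSpec.Input.Type.Value('TYPE_FEATURE'))
# What add_cell_output() records: the logical name of a result tensor.
y = tf.tanh(x, name='demo/hidden')
spec.output.add(name='hidden', tensor=y.name)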

def update_tensor_arrays(network_tensors, arrays):
  """Updates a list of tensor arrays from the network's output tensors.
...
@@ -370,6 +526,18 @@ class DynamicComponentBuilder(ComponentBuilderBase):
  so fixed and linked features can be recurrent.
  """
  def __init__(self, master, component_spec):
    """Initializes the DynamicComponentBuilder from specifications.

    Args:
      master: dragnn.MasterBuilder object.
      component_spec: dragnn.ComponentSpec proto to be built.
    """
    super(DynamicComponentBuilder, self).__init__(
        master,
        component_spec,
        attr_defaults={'loss_function': 'softmax_cross_entropy'})

  def build_greedy_training(self, state, network_states):
    """Builds a training loop for this component.
...
@@ -392,9 +560,10 @@ class DynamicComponentBuilder(ComponentBuilderBase):
      # Add 0 to training_beam_size to disable eager static evaluation.
      # This is possible because tensorflow's constant_value does not
      # propagate arithmetic operations.
-      with tf.control_dependencies([
-          tf.assert_equal(self.training_beam_size + 0, 1)]):
      with tf.control_dependencies(
          [tf.assert_equal(self.training_beam_size + 0, 1)]):
        stride = state.current_batch_size * self.training_beam_size
        self.network.pre_create(stride)

      cost = tf.constant(0.)
      correct = tf.constant(0)
...
@@ -416,40 +585,35 @@ class DynamicComponentBuilder(ComponentBuilderBase):
        # Every layer is written to a TensorArray, so that it can be backprop'd.
        next_arrays = update_tensor_arrays(network_tensors, arrays)
        loss_function = self.attr('loss_function')
        with tf.control_dependencies([x.flow for x in next_arrays]):
          with tf.name_scope('compute_loss'):
-            # A gold label > -1 determines that the sentence is still
-            # in a valid state. Otherwise, the sentence has ended.
-            #
-            # We add only the valid sentences to the loss, in the following way:
-            #   1. We compute 'valid_ix', the indices in gold that contain
-            #      valid oracle actions.
-            #   2. We compute the cost function by comparing logits and gold
-            #      only for the valid indices.
-            gold = dragnn_ops.emit_oracle_labels(handle, component=self.name)
-            gold.set_shape([None])
-            valid = tf.greater(gold, -1)
-            valid_ix = tf.reshape(tf.where(valid), [-1])
-            gold = tf.gather(gold, valid_ix)
            logits = self.network.get_logits(network_tensors)
-            logits = tf.gather(logits, valid_ix)
-            cost += tf.reduce_sum(
-                tf.nn.sparse_softmax_cross_entropy_with_logits(
-                    labels=tf.cast(gold, tf.int64), logits=logits))
            if loss_function == 'softmax_cross_entropy':
              gold = dragnn_ops.emit_oracle_labels(handle, component=self.name)
              new_cost, new_correct, new_total, valid_logits, valid_gold = (
                  build_softmax_cross_entropy_loss(logits, gold))

              if (self.eligible_for_self_norm and
                  self.master.hyperparams.self_norm_alpha > 0):
-                log_z = tf.reduce_logsumexp(logits, [1])
                log_z = tf.reduce_logsumexp(valid_logits, [1])
-                cost += (self.master.hyperparams.self_norm_alpha *
-                         tf.nn.l2_loss(log_z))
                new_cost += (self.master.hyperparams.self_norm_alpha *
                             tf.nn.l2_loss(log_z))
-            correct += tf.reduce_sum(
-                tf.to_int32(tf.nn.in_top_k(logits, gold, 1)))
-            total += tf.size(gold)
            elif loss_function == 'sigmoid_cross_entropy':
              indices, gold, probs = (
                  dragnn_ops.emit_oracle_labels_and_probabilities(
                      handle, component=self.name))
              new_cost, new_correct, new_total, valid_gold = (
                  build_sigmoid_cross_entropy_loss(logits, gold, indices,
                                                   probs))
            else:
              raise RuntimeError("Unknown loss function '%s'" % loss_function)

            cost += new_cost
            correct += new_correct
            total += new_total

-        with tf.control_dependencies([cost, correct, total, gold]):
        with tf.control_dependencies([cost, correct, total, valid_gold]):
          handle = dragnn_ops.advance_from_oracle(handle, component=self.name)
        return [handle, cost, correct, total] + next_arrays
...
@@ -480,6 +644,7 @@ class DynamicComponentBuilder(ComponentBuilderBase):
      # Normalize the objective by the total # of steps taken.
      # Note: Total could be zero for a number of reasons, including:
      # * Oracle labels not being emitted.
      # * All oracle labels for a batch are unknown (-1).
      # * No steps being taken if component is terminal at the start of a batch.
      with tf.control_dependencies([tf.assert_greater(total, 0)]):
        cost /= tf.to_float(total)
@@ -511,6 +676,7 @@ class DynamicComponentBuilder(ComponentBuilderBase):
...
@@ -511,6 +676,7 @@ class DynamicComponentBuilder(ComponentBuilderBase):
stride
=
state
.
current_batch_size
*
self
.
training_beam_size
stride
=
state
.
current_batch_size
*
self
.
training_beam_size
else
:
else
:
stride
=
state
.
current_batch_size
*
self
.
inference_beam_size
stride
=
state
.
current_batch_size
*
self
.
inference_beam_size
self
.
network
.
pre_create
(
stride
)
def
cond
(
handle
,
*
_
):
def
cond
(
handle
,
*
_
):
all_final
=
dragnn_ops
.
emit_all_final
(
handle
,
component
=
self
.
name
)
all_final
=
dragnn_ops
.
emit_all_final
(
handle
,
component
=
self
.
name
)
...
@@ -559,6 +725,7 @@ class DynamicComponentBuilder(ComponentBuilderBase):
      for index, layer in enumerate(self.network.layers):
        network_state.activations[layer.name] = network_units.StoredActivations(
            array=arrays[index])
    self._add_runtime_hooks()
    with tf.control_dependencies([x.flow for x in arrays]):
      return tf.identity(state.handle)
...
@@ -587,7 +754,7 @@ class DynamicComponentBuilder(ComponentBuilderBase):
    fixed_embeddings = []
    for channel_id, feature_spec in enumerate(self.spec.fixed_feature):
-      fixed_embedding = network_units.fixed_feature_lookup(
-          self, state, channel_id, stride)
      fixed_embedding = network_units.fixed_feature_lookup(
          self, state, channel_id, stride, during_training)
      if feature_spec.is_constant:
        fixed_embedding.tensor = tf.stop_gradient(fixed_embedding.tensor)
      fixed_embeddings.append(fixed_embedding)
...
@@ -633,6 +800,12 @@ class DynamicComponentBuilder(ComponentBuilderBase):
    else:
      attention_tensor = None

-    return self.network.create(fixed_embeddings, linked_embeddings,
-                               context_tensor_arrays, attention_tensor,
-                               during_training)
    tensors = self.network.create(fixed_embeddings, linked_embeddings,
                                  context_tensor_arrays, attention_tensor,
                                  during_training)

    if self.master.build_runtime_graph:
      for index, layer in enumerate(self.network.layers):
        self.add_cell_output(tensors[index], layer.name)

    return tensors
research/syntaxnet/dragnn/python/component_test.py
0 → 100644
View file @ 80178fc6
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for component.py.
"""

import tensorflow as tf

from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest

from google.protobuf import text_format
from dragnn.protos import spec_pb2
from dragnn.python import component


class MockNetworkUnit(object):

  def get_layer_size(self, unused_layer_name):
    return 64


class MockComponent(object):

  def __init__(self):
    self.name = 'mock'
    self.network = MockNetworkUnit()


class MockMaster(object):

  def __init__(self):
    self.spec = spec_pb2.MasterSpec()
    self.hyperparams = spec_pb2.GridPoint()
    self.lookup_component = {'mock': MockComponent()}
    self.build_runtime_graph = False


class ComponentTest(test_util.TensorFlowTestCase):

  def setUp(self):
    # Clear the graph and all existing variables.  Otherwise, variables created
    # in different tests may collide with each other.
    tf.reset_default_graph()

    self.master = MockMaster()
    self.master_state = component.MasterState(
        handle=tf.constant(['foo', 'bar']), current_batch_size=2)
    self.network_states = {
        'mock': component.NetworkState(),
        'test': component.NetworkState(),
    }

  def testSoftmaxCrossEntropyLoss(self):
    logits = tf.constant([[0.0, 2.0, -1.0],
                          [-5.0, 1.0, -1.0],
                          [3.0, 1.0, -2.0]])  # pyformat: disable
    gold_labels = tf.constant([1, -1, 1])
    cost, correct, total, logits, gold_labels = (
        component.build_softmax_cross_entropy_loss(logits, gold_labels))
    with self.test_session() as sess:
      cost, correct, total, logits, gold_labels = (
          sess.run([cost, correct, total, logits, gold_labels]))
      # Cost = -2 + ln(1 + exp(2) + exp(-1))
      #        -1 + ln(exp(3) + exp(1) + exp(-2))
      self.assertAlmostEqual(cost, 2.3027, 4)
      self.assertEqual(correct, 1)
      self.assertEqual(total, 2)

      # Entries corresponding to gold labels equal to -1 are skipped.
      self.assertAllEqual(logits, [[0.0, 2.0, -1.0], [3.0, 1.0, -2.0]])
      self.assertAllEqual(gold_labels, [1, 1])

  def testSigmoidCrossEntropyLoss(self):
    indices = tf.constant([0, 0, 1])
    gold_labels = tf.constant([0, 1, 2])
    probs = tf.constant([0.6, 0.7, 0.2])
    logits = tf.constant([[0.9, -0.3, 0.1], [-0.5, 0.4, 2.0]])
    cost, correct, total, gold_labels = (
        component.build_sigmoid_cross_entropy_loss(logits, gold_labels,
                                                   indices, probs))
    with self.test_session() as sess:
      cost, correct, total, gold_labels = (
          sess.run([cost, correct, total, gold_labels]))
      # The cost corresponding to the three entries is, respectively,
      # 0.7012, 0.7644, and 1.7269.  Each of them is computed using the formula
      # -prob_i * log(sigmoid(logit_i)) - (1-prob_i) * log(1-sigmoid(logit_i))
      self.assertAlmostEqual(cost, 3.1924, 4)
      self.assertEqual(correct, 1)
      self.assertEqual(total, 3)
      self.assertAllEqual(gold_labels, [0, 1, 2])

  def testGraphConstruction(self):
    component_spec = spec_pb2.ComponentSpec()
    text_format.Parse("""
        name: "test"
        network_unit {
          registered_name: "IdentityNetwork"
        }
        fixed_feature {
          name: "fixed" embedding_dim: 32 size: 1
        }
        component_builder {
          registered_name: "component.DynamicComponentBuilder"
        }
        """, component_spec)
    comp = component.DynamicComponentBuilder(self.master, component_spec)
    comp.build_greedy_training(self.master_state, self.network_states)

  def testGraphConstructionWithSigmoidLoss(self):
    component_spec = spec_pb2.ComponentSpec()
    text_format.Parse("""
        name: "test"
        network_unit {
          registered_name: "IdentityNetwork"
        }
        fixed_feature {
          name: "fixed" embedding_dim: 32 size: 1
        }
        component_builder {
          registered_name: "component.DynamicComponentBuilder"
          parameters {
            key: "loss_function"
            value: "sigmoid_cross_entropy"
          }
        }
        """, component_spec)
    comp = component.DynamicComponentBuilder(self.master, component_spec)
    comp.build_greedy_training(self.master_state, self.network_states)

    # Check that the loss op is present.
    op_names = [op.name for op in tf.get_default_graph().get_operations()]
    self.assertTrue('train_test/compute_loss/'
                    'sigmoid_cross_entropy_with_logits' in op_names)


if __name__ == '__main__':
  googletest.main()
research/syntaxnet/dragnn/python/digraph_ops.py
View file @ 80178fc6
...
@@ -15,6 +15,10 @@
"""TensorFlow ops for directed graphs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from syntaxnet.util import check
...
@@ -150,7 +154,7 @@ def ArcSourcePotentialsFromTokens(tokens, weights):
  return sources_bxnxn
-def RootPotentialsFromTokens(root, tokens, weights):
def RootPotentialsFromTokens(root, tokens, weights_arc, weights_source):
  r"""Returns root selection potentials computed from tokens and weights.
  For each batch of token activations, computes a scalar potential for each root
...
@@ -162,7 +166,8 @@ def RootPotentialsFromTokens(root, tokens, weights):
  Args:
    root: [S] vector of activations for the artificial root token.
    tokens: [B,N,T] tensor of batched activations for root tokens.
-    weights: [S,T] matrix of weights.
    weights_arc: [S,T] matrix of weights.
    weights_source: [S] vector of weights.

  B,N may be statically-unknown, but S,T must be statically-known. The dtype
  of all arguments must be compatible.
...
@@ -174,25 +179,30 @@ def RootPotentialsFromTokens(root, tokens, weights):
  # All arguments must have statically-known rank.
  check.Eq(root.get_shape().ndims, 1, 'root must be a vector')
  check.Eq(tokens.get_shape().ndims, 3, 'tokens must be rank 3')
-  check.Eq(weights.get_shape().ndims, 2, 'weights must be a matrix')
  check.Eq(weights_arc.get_shape().ndims, 2, 'weights_arc must be a matrix')
  check.Eq(weights_source.get_shape().ndims, 1,
           'weights_source must be a vector')

  # All activation dimensions must be statically-known.
-  num_source_activations = weights.get_shape().as_list()[0]
-  num_target_activations = weights.get_shape().as_list()[1]
  num_source_activations = weights_arc.get_shape().as_list()[0]
  num_target_activations = weights_arc.get_shape().as_list()[1]
  check.NotNone(num_source_activations, 'unknown source activation dimension')
  check.NotNone(num_target_activations, 'unknown target activation dimension')
  check.Eq(root.get_shape().as_list()[0], num_source_activations,
-           'dimension mismatch between weights and root')
           'dimension mismatch between weights_arc and root')
  check.Eq(tokens.get_shape().as_list()[2], num_target_activations,
-           'dimension mismatch between weights and tokens')
           'dimension mismatch between weights_arc and tokens')
  check.Eq(weights_source.get_shape().as_list()[0], num_source_activations,
           'dimension mismatch between weights_arc and weights_source')

  # All arguments must share the same type.
-  check.Same([weights.dtype.base_dtype,
-              root.dtype.base_dtype,
-              tokens.dtype.base_dtype],
-             'dtype mismatch')
  check.Same([
      root.dtype.base_dtype, weights_arc.dtype.base_dtype,
      weights_source.dtype.base_dtype, tokens.dtype.base_dtype
  ], 'dtype mismatch')

  root_1xs = tf.expand_dims(root, 0)
  weights_source_sx1 = tf.expand_dims(weights_source, 1)

  tokens_shape = tf.shape(tokens)
  batch_size = tokens_shape[0]
...
@@ -200,9 +210,12 @@ def RootPotentialsFromTokens(root, tokens, weights):
  # Flatten out the batch dimension so we can use a couple big matmuls.
  tokens_bnxt = tf.reshape(tokens, [-1, num_target_activations])
-  weights_targets_bnxs = tf.matmul(tokens_bnxt, weights, transpose_b=True)
  weights_targets_bnxs = tf.matmul(tokens_bnxt, weights_arc, transpose_b=True)
  roots_1xbn = tf.matmul(root_1xs, weights_targets_bnxs, transpose_b=True)

  # Add in the score for selecting the root as a source.
  roots_1xbn += tf.matmul(root_1xs, weights_source_sx1)

  # Restore the batch dimension in the output.
  roots_bxn = tf.reshape(roots_1xbn, [batch_size, num_tokens])
  return roots_bxn
...
@@ -354,3 +367,110 @@ def LabelPotentialsFromTokenPairs(sources, targets, weights):
...
@@ -354,3 +367,110 @@ def LabelPotentialsFromTokenPairs(sources, targets, weights):
transpose_b
=
True
)
transpose_b
=
True
)
labels_bxnxl
=
tf
.
squeeze
(
labels_bxnxlx1
,
[
3
])
labels_bxnxl
=
tf
.
squeeze
(
labels_bxnxlx1
,
[
3
])
return
labels_bxnxl
return
labels_bxnxl
def
ValidArcAndTokenMasks
(
lengths
,
max_length
,
dtype
=
tf
.
float32
):
r
"""Returns 0/1 masks for valid arcs and tokens.
Args:
lengths: [B] vector of input sequence lengths.
max_length: Scalar maximum input sequence length, aka M.
dtype: Data type for output mask.
Returns:
[B,M,M] tensor A with 0/1 indicators of valid arcs. Specifically,
A_{b,t,s} = t,s < lengths[b] ? 1 : 0
[B,M] matrix T with 0/1 indicators of valid tokens. Specifically,
T_{b,t} = t < lengths[b] ? 1 : 0
"""
lengths_bx1
=
tf
.
expand_dims
(
lengths
,
1
)
sequence_m
=
tf
.
range
(
tf
.
cast
(
max_length
,
lengths
.
dtype
.
base_dtype
))
sequence_1xm
=
tf
.
expand_dims
(
sequence_m
,
0
)
# Create vectors of 0/1 indicators for valid tokens. Note that the comparison
# operator will broadcast from [1,M] and [B,1] to [B,M].
valid_token_bxm
=
tf
.
cast
(
sequence_1xm
<
lengths_bx1
,
dtype
)
# Compute matrices of 0/1 indicators for valid arcs as the outer product of
# the valid token indicator vector with itself.
valid_arc_bxmxm
=
tf
.
matmul
(
tf
.
expand_dims
(
valid_token_bxm
,
2
),
tf
.
expand_dims
(
valid_token_bxm
,
1
))
return
valid_arc_bxmxm
,
valid_token_bxm
def
LaplacianMatrix
(
lengths
,
arcs
,
forest
=
False
):
r
"""Returns the (root-augmented) Laplacian matrix for a batch of digraphs.
Args:
lengths: [B] vector of input sequence lengths.
arcs: [B,M,M] tensor of arc potentials where entry b,t,s is the potential of
the arc from s to t in the b'th digraph, while b,t,t is the potential of t
as a root. Entries b,t,s where t or s >= lengths[b] are ignored.
forest: Whether to produce a Laplacian for trees or forests.
Returns:
[B,M,M] tensor L with the Laplacian of each digraph, padded with an identity
matrix. More concretely, the padding entries (t or s >= lengths[b]) are:
L_{b,t,t} = 1.0
L_{b,t,s} = 0.0
Note that this "identity matrix padding" ensures that the determinant of
each padded matrix equals the determinant of the unpadded matrix. The
non-padding entries (t,s < lengths[b]) depend on whether the Laplacian is
constructed for trees or forests. For trees:
L_{b,t,0} = arcs[b,t,t]
L_{b,t,t} = \sum_{s < lengths[b], t != s} arcs[b,t,s]
L_{b,t,s} = -arcs[b,t,s]
For forests:
L_{b,t,t} = \sum_{s < lengths[b]} arcs[b,t,s]
L_{b,t,s} = -arcs[b,t,s]
See http://www.aclweb.org/anthology/D/D07/D07-1015.pdf for details, though
note that our matrices are transposed from their notation.
"""
check
.
Eq
(
arcs
.
get_shape
().
ndims
,
3
,
'arcs must be rank 3'
)
dtype
=
arcs
.
dtype
.
base_dtype
arcs_shape
=
tf
.
shape
(
arcs
)
batch_size
=
arcs_shape
[
0
]
max_length
=
arcs_shape
[
1
]
with
tf
.
control_dependencies
([
tf
.
assert_equal
(
max_length
,
arcs_shape
[
2
])]):
valid_arc_bxmxm
,
valid_token_bxm
=
ValidArcAndTokenMasks
(
lengths
,
max_length
,
dtype
=
dtype
)
invalid_token_bxm
=
tf
.
constant
(
1
,
dtype
=
dtype
)
-
valid_token_bxm
# Zero out all invalid arcs, to avoid polluting bulk summations.
arcs_bxmxm
=
arcs
*
valid_arc_bxmxm
zeros_bxm
=
tf
.
zeros
([
batch_size
,
max_length
],
dtype
)
if
not
forest
:
# For trees, extract the root potentials and exclude them from the sums
# computed below.
roots_bxm
=
tf
.
matrix_diag_part
(
arcs_bxmxm
)
# only defined for trees
arcs_bxmxm
=
tf
.
matrix_set_diag
(
arcs_bxmxm
,
zeros_bxm
)
# Sum inbound arc potentials for each target token. These sums will form
# the diagonal of the Laplacian matrix. Note that these sums are zero for
# invalid tokens, since their arc potentials were masked out above.
sums_bxm
=
tf
.
reduce_sum
(
arcs_bxmxm
,
2
)
if
forest
:
# For forests, zero out the root potentials after computing the sums above
# so we don't cancel them out when we subtract the arc potentials.
arcs_bxmxm
=
tf
.
matrix_set_diag
(
arcs_bxmxm
,
zeros_bxm
)
# The diagonal of the result is the combination of the arc sums, which are
# non-zero only on valid tokens, and the invalid token indicators, which are
# non-zero only on invalid tokens. Note that the latter form the diagonal
# of the identity matrix padding.
diagonal_bxm
=
sums_bxm
+
invalid_token_bxm
# Combine sums and negative arc potentials. Note that the off-diagonal
# padding entries will be zero thanks to the arc mask.
laplacian_bxmxm
=
tf
.
matrix_diag
(
diagonal_bxm
)
-
arcs_bxmxm
if
not
forest
:
# For trees, replace the first column with the root potentials.
roots_bxmx1
=
tf
.
expand_dims
(
roots_bxm
,
2
)
laplacian_bxmxm
=
tf
.
concat
([
roots_bxmx1
,
laplacian_bxmxm
[:,
:,
1
:]],
2
)
return
laplacian_bxmxm
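[Editor's note] One intended use of this Laplacian, per the Koo et al. (2007) paper cited in the docstring, is the Matrix-Tree theorem: the determinant of the root-augmented Laplacian is the partition function over rooted spanning trees. A minimal sketch (not part of this commit):

import tensorflow as tf
from dragnn.python import digraph_ops

# One two-token digraph; diagonal entries are root potentials.
arcs = tf.constant([[[1.0, 2.0],
                     [3.0, 4.0]]], tf.float32)
lengths = tf.constant([2], tf.int64)
laplacian = digraph_ops.LaplacianMatrix(lengths, arcs)
with tf.Session() as sess:
  # det = 1*3 + 4*2 = 11: the root-at-0 tree (1*3) plus the root-at-1 tree (4*2).
  print(sess.run(tf.matrix_determinant(laplacian)))  # [11.]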
research/syntaxnet/dragnn/python/digraph_ops_test.py
View file @ 80178fc6
...
@@ -31,16 +31,18 @@ class DigraphOpsTest(tf.test.TestCase):
                                    [3, 4]],
                                   [[3, 4],
                                    [2, 3],
-                                   [1, 2]]], tf.float32)
                                    [1, 2]]], tf.float32)  # pyformat: disable
      target_tokens = tf.constant([[[4, 5, 6],
                                    [5, 6, 7],
                                    [6, 7, 8]],
                                   [[6, 7, 8],
                                    [5, 6, 7],
-                                   [4, 5, 6]]], tf.float32)
                                    [4, 5, 6]]], tf.float32)  # pyformat: disable
      weights = tf.constant([[2, 3, 5],
-                             [7, 11, 13]], tf.float32)
                             [7, 11, 13]], tf.float32)  # pyformat: disable

      arcs = digraph_ops.ArcPotentialsFromTokens(source_tokens, target_tokens,
                                                 weights)
...
@@ -54,7 +56,7 @@ class DigraphOpsTest(tf.test.TestCase):
                                 [803, 957, 1111]],
                                [[1111, 957, 803],  # reflected through the center
                                 [815, 702, 589],
-                                [519, 447, 375]]])
                                 [519, 447, 375]]])  # pyformat: disable

  def testArcSourcePotentialsFromTokens(self):
    with self.test_session():
...
@@ -63,7 +65,7 @@ class DigraphOpsTest(tf.test.TestCase):
                            [6, 7, 8]],
                           [[6, 7, 8],
                            [5, 6, 7],
-                           [4, 5, 6]]], tf.float32)
                            [4, 5, 6]]], tf.float32)  # pyformat: disable
      weights = tf.constant([2, 3, 5], tf.float32)

      arcs = digraph_ops.ArcSourcePotentialsFromTokens(tokens, weights)
...
@@ -73,7 +75,7 @@ class DigraphOpsTest(tf.test.TestCase):
                         [73, 73, 73]],
                        [[73, 73, 73],
                         [63, 63, 63],
-                        [53, 53, 53]]])
                         [53, 53, 53]]])  # pyformat: disable

  def testRootPotentialsFromTokens(self):
    with self.test_session():
...
@@ -83,15 +85,17 @@ class DigraphOpsTest(tf.test.TestCase):
                            [6, 7, 8]],
                           [[6, 7, 8],
                            [5, 6, 7],
-                           [4, 5, 6]]], tf.float32)
                            [4, 5, 6]]], tf.float32)  # pyformat: disable
-      weights = tf.constant([[2, 3, 5],
-                             [7, 11, 13]], tf.float32)
      weights_arc = tf.constant([[2, 3, 5],
                                 [7, 11, 13]], tf.float32)  # pyformat: disable
      weights_source = tf.constant([11, 10], tf.float32)

-      roots = digraph_ops.RootPotentialsFromTokens(root, tokens, weights)
      roots = digraph_ops.RootPotentialsFromTokens(root, tokens, weights_arc,
                                                   weights_source)

-      self.assertAllEqual(roots.eval(), [[375, 447, 519],
-                                         [519, 447, 375]])
      self.assertAllEqual(roots.eval(), [[406, 478, 550],
                                         [550, 478, 406]])  # pyformat: disable
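[Editor's note] A quick consistency check on the updated expectations: every new entry exceeds its old counterpart by exactly 31 (406 - 375 = 478 - 447 = 550 - 519 = 31), which is the scalar root-as-source score that RootPotentialsFromTokens() now adds uniformly via tf.matmul(root_1xs, weights_source_sx1).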

  def testCombineArcAndRootPotentials(self):
    with self.test_session():
...
@@ -100,9 +104,9 @@ class DigraphOpsTest(tf.test.TestCase):
                           [3, 4, 5]],
                          [[3, 4, 5],
                           [2, 3, 4],
-                          [1, 2, 3]]], tf.float32)
                           [1, 2, 3]]], tf.float32)  # pyformat: disable
-      roots = tf.constant([[6, 7, 8],
-                           [8, 7, 6]], tf.float32)
      roots = tf.constant([[6, 7, 8],
                           [8, 7, 6]], tf.float32)  # pyformat: disable

      potentials = digraph_ops.CombineArcAndRootPotentials(arcs, roots)
...
@@ -111,7 +115,7 @@ class DigraphOpsTest(tf.test.TestCase):
                         [3, 4, 8]],
                        [[8, 4, 5],
                         [2, 7, 4],
-                        [1, 2, 6]]])
                         [1, 2, 6]]])  # pyformat: disable

  def testLabelPotentialsFromTokens(self):
    with self.test_session():
...
@@ -120,12 +124,12 @@ class DigraphOpsTest(tf.test.TestCase):
                           [5, 6]],
                          [[6, 5],
                           [4, 3],
-                          [2, 1]]], tf.float32)
                           [2, 1]]], tf.float32)  # pyformat: disable
-      weights = tf.constant([[2, 3],
-                             [5, 7],
-                             [11, 13]], tf.float32)
      weights = tf.constant([[2, 3],
                             [5, 7],
                             [11, 13]], tf.float32)  # pyformat: disable

      labels = digraph_ops.LabelPotentialsFromTokens(tokens, weights)
...
@@ -136,7 +140,7 @@ class DigraphOpsTest(tf.test.TestCase):
                         [28, 67, 133]],
                        [[27, 65, 131],
                         [17, 41, 83],
-                        [7, 17, 35]]])
                         [7, 17, 35]]])  # pyformat: disable

  def testLabelPotentialsFromTokenPairs(self):
    with self.test_session():
...
@@ -145,13 +149,13 @@ class DigraphOpsTest(tf.test.TestCase):
                           [5, 6]],
                          [[6, 5],
                           [4, 3],
-                          [2, 1]]], tf.float32)
                           [2, 1]]], tf.float32)  # pyformat: disable
      targets = tf.constant([[[3, 4],
                              [5, 6],
                              [7, 8]],
                             [[8, 7],
                              [6, 5],
-                             [4, 3]]], tf.float32)
                              [4, 3]]], tf.float32)  # pyformat: disable
      weights = tf.constant([[[2, 3],
...
@@ -159,7 +163,7 @@ class DigraphOpsTest(tf.test.TestCase):
                             [[11, 13],
                              [17, 19]],
                             [[23, 29],
-                             [31, 37]]], tf.float32)
                              [31, 37]]], tf.float32)  # pyformat: disable

      labels = digraph_ops.LabelPotentialsFromTokenPairs(sources, targets,
                                                         weights)
...
@@ -171,7 +175,114 @@ class DigraphOpsTest(tf.test.TestCase):
                         [736, 2531, 5043]],
                        [[667, 2419, 4857],
                         [303, 1115, 2245],
-                        [75, 291, 593]]])
                         [75, 291, 593]]])  # pyformat: disable

  def testValidArcAndTokenMasks(self):
    with self.test_session():
      lengths = tf.constant([1, 2, 3], tf.int64)
      max_length = 4
      valid_arcs, valid_tokens = digraph_ops.ValidArcAndTokenMasks(
          lengths, max_length)

      self.assertAllEqual(valid_arcs.eval(),
                          [[[1, 0, 0, 0],
                            [0, 0, 0, 0],
                            [0, 0, 0, 0],
                            [0, 0, 0, 0]],
                           [[1, 1, 0, 0],
                            [1, 1, 0, 0],
                            [0, 0, 0, 0],
                            [0, 0, 0, 0]],
                           [[1, 1, 1, 0],
                            [1, 1, 1, 0],
                            [1, 1, 1, 0],
                            [0, 0, 0, 0]]])  # pyformat: disable
      self.assertAllEqual(valid_tokens.eval(),
                          [[1, 0, 0, 0],
                           [1, 1, 0, 0],
                           [1, 1, 1, 0]])  # pyformat: disable

  def testLaplacianMatrixTree(self):
    with self.test_session():
      pad = 12345.6
      arcs = tf.constant([[[2, pad, pad, pad],
                           [pad, pad, pad, pad],
                           [pad, pad, pad, pad],
                           [pad, pad, pad, pad]],
                          [[2, 3, pad, pad],
                           [5, 7, pad, pad],
                           [pad, pad, pad, pad],
                           [pad, pad, pad, pad]],
                          [[2, 3, 5, pad],
                           [7, 11, 13, pad],
                           [17, 19, 23, pad],
                           [pad, pad, pad, pad]],
                          [[2, 3, 5, 7],
                           [11, 13, 17, 19],
                           [23, 29, 31, 37],
                           [41, 43, 47, 53]]],
                         tf.float32)  # pyformat: disable
      lengths = tf.constant([1, 2, 3, 4], tf.int64)

      laplacian = digraph_ops.LaplacianMatrix(lengths, arcs)

      self.assertAllEqual(laplacian.eval(),
                          [[[2, 0, 0, 0],
                            [0, 1, 0, 0],
                            [0, 0, 1, 0],
                            [0, 0, 0, 1]],
                           [[2, -3, 0, 0],
                            [7, 5, 0, 0],
                            [0, 0, 1, 0],
                            [0, 0, 0, 1]],
                           [[2, -3, -5, 0],
                            [11, 20, -13, 0],
                            [23, -19, 36, 0],
                            [0, 0, 0, 1]],
                           [[2, -3, -5, -7],
                            [13, 47, -17, -19],
                            [31, -29, 89, -37],
                            [53, -43, -47, 131]]])  # pyformat: disable

  def testLaplacianMatrixForest(self):
    with self.test_session():
      pad = 12345.6
      arcs = tf.constant([[[2, pad, pad, pad],
                           [pad, pad, pad, pad],
                           [pad, pad, pad, pad],
                           [pad, pad, pad, pad]],
                          [[2, 3, pad, pad],
                           [5, 7, pad, pad],
                           [pad, pad, pad, pad],
                           [pad, pad, pad, pad]],
                          [[2, 3, 5, pad],
                           [7, 11, 13, pad],
                           [17, 19, 23, pad],
                           [pad, pad, pad, pad]],
                          [[2, 3, 5, 7],
                           [11, 13, 17, 19],
                           [23, 29, 31, 37],
                           [41, 43, 47, 53]]],
                         tf.float32)  # pyformat: disable
      lengths = tf.constant([1, 2, 3, 4], tf.int64)

      laplacian = digraph_ops.LaplacianMatrix(lengths, arcs, forest=True)

      self.assertAllEqual(laplacian.eval(),
                          [[[2, 0, 0, 0],
                            [0, 1, 0, 0],
                            [0, 0, 1, 0],
                            [0, 0, 0, 1]],
                           [[5, -3, 0, 0],
                            [-5, 12, 0, 0],
                            [0, 0, 1, 0],
                            [0, 0, 0, 1]],
                           [[10, -3, -5, 0],
                            [-7, 31, -13, 0],
                            [-17, -19, 59, 0],
                            [0, 0, 0, 1]],
                           [[17, -3, -5, -7],
                            [-11, 60, -17, -19],
                            [-23, -29, 120, -37],
                            [-41, -43, -47, 184]]])  # pyformat: disable
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
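The two Laplacian tests pin down the construction the op is expected to perform: off-diagonal entries are negated arc potentials, diagonal entries accumulate each token's head-selection potentials, and the tree variant overwrites column 0 with the root potentials (the Koo et al. 2007 matrix-tree construction). A minimal NumPy sketch, assuming `arcs[t][s]` scores token `s` as the head of token `t` and `arcs[t][t]` scores `t` as a root; it reproduces the unpadded expected values above but is an illustration, not the op's implementation:

import numpy as np

def laplacian_matrix(arcs, forest=False):
  """Sketch of the Laplacian construction exercised by the tests above."""
  n = arcs.shape[0]
  # Off-diagonal entries are negated arc potentials.
  lap = -arcs * (1.0 - np.eye(n))
  # Diagonal entries sum each token's head-selection potentials; the forest
  # variant also counts the root potential (the diagonal of |arcs|).
  np.fill_diagonal(lap, arcs.sum(axis=1) - (0.0 if forest else np.diag(arcs)))
  if not forest:
    # Tree variant: column 0 is overwritten with the root potentials.
    lap[:, 0] = np.diag(arcs)
  return lap

arcs = np.array([[2., 3., 5.], [7., 11., 13.], [17., 19., 23.]])
print(laplacian_matrix(arcs))               # [[2,-3,-5],[11,20,-13],[23,-19,36]]
print(laplacian_matrix(arcs, forest=True))  # [[10,-3,-5],[-7,31,-13],[-17,-19,59]]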
research/syntaxnet/dragnn/python/dragnn_model_saver.py
...
@@ -25,13 +25,14 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from absl import app
+from absl import flags
 import tensorflow as tf

 from google.protobuf import text_format
 from dragnn.protos import spec_pb2
 from dragnn.python import dragnn_model_saver_lib as saver_lib

-flags = tf.app.flags
 FLAGS = flags.FLAGS

 flags.DEFINE_string('master_spec', None, 'Path to task context with '
...
@@ -40,10 +41,12 @@ flags.DEFINE_string('params_path', None, 'Path to trained model parameters.')
 flags.DEFINE_string('export_path', '', 'Output path for exported servo model.')
 flags.DEFINE_bool('export_moving_averages', False,
                   'Whether to export the moving average parameters.')
+flags.DEFINE_bool('build_runtime_graph', False,
+                  'Whether to build a graph for use by the runtime.')


 def export(master_spec_path, params_path, export_path,
-           export_moving_averages):
+           export_moving_averages, build_runtime_graph):
   """Restores a model and exports it in SavedModel form.

   This method loads a graph specified by the spec at master_spec_path and the
...
@@ -55,6 +58,7 @@ def export(master_spec_path, params_path, export_path,
     params_path: Path to the parameters file to export.
     export_path: Path to export the SavedModel to.
     export_moving_averages: Whether to export the moving average parameters.
+    build_runtime_graph: Whether to build a graph for use by the runtime.
   """
   graph = tf.Graph()
...
@@ -70,16 +74,16 @@ def export(master_spec_path, params_path, export_path,
   short_to_original = saver_lib.shorten_resource_paths(master_spec)
   saver_lib.export_master_spec(master_spec, graph)
   saver_lib.export_to_graph(master_spec, params_path, stripped_path, graph,
-                            export_moving_averages)
+                            export_moving_averages, build_runtime_graph)
   saver_lib.export_assets(master_spec, short_to_original, stripped_path)


 def main(unused_argv):
   # Run the exporter.
   export(FLAGS.master_spec, FLAGS.params_path,
-         FLAGS.export_path, FLAGS.export_moving_averages)
+         FLAGS.export_path, FLAGS.export_moving_averages,
+         FLAGS.build_runtime_graph)
   tf.logging.info('Export complete.')


 if __name__ == '__main__':
-  tf.app.run()
+  app.run(main)
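With the new flag in place, the exporter can also be driven programmatically; a minimal sketch under the updated signature (all paths are hypothetical placeholders):

from dragnn.python import dragnn_model_saver

# Exports moving-average parameters and the extra runtime graph in one pass.
dragnn_model_saver.export(
    master_spec_path='/path/to/master-spec',  # hypothetical path
    params_path='/path/to/checkpoint',        # hypothetical path
    export_path='/path/to/saved_model',       # hypothetical path
    export_moving_averages=True,
    build_runtime_graph=True)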
research/syntaxnet/dragnn/python/dragnn_model_saver_lib.py
...
@@ -164,6 +164,7 @@ def export_to_graph(master_spec,
                     export_path,
                     external_graph,
                     export_moving_averages,
+                    build_runtime_graph,
                     signature_name='model'):
   """Restores a model and exports it in SavedModel form.
...
@@ -177,6 +178,7 @@ def export_to_graph(master_spec,
     export_path: Path to export the SavedModel to.
     external_graph: A tf.Graph() object to build the graph inside.
     export_moving_averages: Whether to export the moving average parameters.
+    build_runtime_graph: Whether to build a graph for use by the runtime.
     signature_name: Name of the signature to insert.
   """
   tf.logging.info(
...
@@ -189,7 +191,7 @@ def export_to_graph(master_spec,
   hyperparam_config.use_moving_average = export_moving_averages
   builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
   post_restore_hook = builder.build_post_restore_hook()
-  annotation = builder.add_annotation()
+  annotation = builder.add_annotation(build_runtime_graph=build_runtime_graph)
   builder.add_saver()

   # Resets session.
...
research/syntaxnet/dragnn/python/dragnn_model_saver_lib_test.py
...
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 """Test for dragnn.python.dragnn_model_saver_lib."""
 from __future__ import absolute_import
...
@@ -26,24 +25,30 @@ import tensorflow as tf
 from google.protobuf import text_format
 from tensorflow.python.framework import test_util
 from tensorflow.python.platform import googletest
+from dragnn.protos import export_pb2
 from dragnn.protos import spec_pb2
 from dragnn.python import dragnn_model_saver_lib
+from syntaxnet import sentence_pb2
+from syntaxnet import test_flags

-FLAGS = tf.app.flags.FLAGS
-
-
-def setUpModule():
-  if not hasattr(FLAGS, 'test_srcdir'):
-    FLAGS.test_srcdir = ''
-  if not hasattr(FLAGS, 'test_tmpdir'):
-    FLAGS.test_tmpdir = tf.test.get_temp_dir()
+_DUMMY_TEST_SENTENCE = """
+token {
+  word: "sentence" start: 0 end: 7 break_level: NO_BREAK
+}
+token {
+  word: "0" start: 9 end: 9 break_level: SPACE_BREAK
+}
+token {
+  word: "." start: 10 end: 10 break_level: NO_BREAK
+}
+"""


 class DragnnModelSaverLibTest(test_util.TensorFlowTestCase):

   def LoadSpec(self, spec_path):
     master_spec = spec_pb2.MasterSpec()
-    root_dir = os.path.join(FLAGS.test_srcdir,
+    root_dir = os.path.join(test_flags.source_root(),
                             'dragnn/python')
     with open(os.path.join(root_dir, 'testdata', spec_path), 'r') as fin:
       text_format.Parse(fin.read().replace('TOPDIR', root_dir), master_spec)
...
@@ -52,7 +57,7 @@ class DragnnModelSaverLibTest(test_util.TensorFlowTestCase):
   def CreateLocalSpec(self, spec_path):
     master_spec = self.LoadSpec(spec_path)
     master_spec_name = os.path.basename(spec_path)
-    outfile = os.path.join(FLAGS.test_tmpdir, master_spec_name)
+    outfile = os.path.join(test_flags.temp_dir(), master_spec_name)
     fout = open(outfile, 'w')
     fout.write(text_format.MessageToString(master_spec))
     return outfile
...
@@ -80,16 +85,50 @@ class DragnnModelSaverLibTest(test_util.TensorFlowTestCase):
     # Return a set of all unique paths.
     return set(path_list)

+  def GetHookNodeNames(self, master_spec):
+    """Returns hook node names to use in tests.
+
+    Args:
+      master_spec: MasterSpec proto from which to infer hook node names.
+
+    Returns:
+      Tuple of (averaged hook node name, non-averaged hook node name, cell
+      subgraph hook node name).
+
+    Raises:
+      ValueError: If hook nodes cannot be inferred from the |master_spec|.
+    """
+    # Find an op name we can use for testing runtime hooks. Assume that at
+    # least one component has a fixed feature (else what is the model doing?).
+    component_name = None
+    for component_spec in master_spec.component:
+      if component_spec.fixed_feature:
+        component_name = component_spec.name
+        break
+    if not component_name:
+      raise ValueError('Cannot infer hook node names')
+
+    non_averaged_hook_name = '{}/fixed_embedding_matrix_0/trimmed'.format(
+        component_name)
+    averaged_hook_name = '{}/ExponentialMovingAverage'.format(
+        non_averaged_hook_name)
+    cell_subgraph_hook_name = '{}/EXPORT/CellSubgraphSpec'.format(
+        component_name)
+    return averaged_hook_name, non_averaged_hook_name, cell_subgraph_hook_name
+
   def testModelExport(self):
     # Get the master spec and params for this graph.
     master_spec = self.LoadSpec('ud-hungarian.master-spec')
     params_path = os.path.join(
-        FLAGS.test_srcdir, 'dragnn/python/testdata'
+        test_flags.source_root(), 'dragnn/python/testdata'
         '/ud-hungarian.params')

     # Export the graph via SavedModel. (Here, we maintain a handle to the graph
     # for comparison, but that's usually not necessary.)
-    export_path = os.path.join(FLAGS.test_tmpdir, 'export')
+    export_path = os.path.join(test_flags.temp_dir(), 'export')
     dragnn_model_saver_lib.clean_output_paths(export_path)

     saver_graph = tf.Graph()

     shortened_to_original = dragnn_model_saver_lib.shorten_resource_paths(
...
@@ -102,7 +141,8 @@ class DragnnModelSaverLibTest(test_util.TensorFlowTestCase):
         params_path,
         export_path,
         saver_graph,
-        export_moving_averages=False)
+        export_moving_averages=False,
+        build_runtime_graph=False)

     # Export the assets as well.
     dragnn_model_saver_lib.export_assets(master_spec, shortened_to_original,
...
@@ -126,6 +166,165 @@ class DragnnModelSaverLibTest(test_util.TensorFlowTestCase):
       tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING],
                                  export_path)

+      averaged_hook_name, non_averaged_hook_name, _ = self.GetHookNodeNames(
+          master_spec)
+
+      # Check that the averaged runtime hook node does not exist.
+      with self.assertRaises(KeyError):
+        restored_graph.get_operation_by_name(averaged_hook_name)
+
+      # Check that the non-averaged version also does not exist.
+      with self.assertRaises(KeyError):
+        restored_graph.get_operation_by_name(non_averaged_hook_name)
+
+  def testModelExportWithAveragesAndHooks(self):
+    # Get the master spec and params for this graph.
+    master_spec = self.LoadSpec('ud-hungarian.master-spec')
+    params_path = os.path.join(
+        test_flags.source_root(), 'dragnn/python/testdata'
+        '/ud-hungarian.params')
+
+    # Export the graph via SavedModel. (Here, we maintain a handle to the graph
+    # for comparison, but that's usually not necessary.) Note that the export
+    # path must not already exist.
+    export_path = os.path.join(test_flags.temp_dir(), 'export2')
+    dragnn_model_saver_lib.clean_output_paths(export_path)
+
+    saver_graph = tf.Graph()
+
+    shortened_to_original = dragnn_model_saver_lib.shorten_resource_paths(
+        master_spec)
+
+    dragnn_model_saver_lib.export_master_spec(master_spec, saver_graph)
+
+    dragnn_model_saver_lib.export_to_graph(
+        master_spec,
+        params_path,
+        export_path,
+        saver_graph,
+        export_moving_averages=True,
+        build_runtime_graph=True)
+
+    # Export the assets as well.
+    dragnn_model_saver_lib.export_assets(master_spec, shortened_to_original,
+                                         export_path)
+
+    # Validate that the assets are all in the exported directory.
+    path_set = self.ValidateAssetExistence(master_spec, export_path)
+
+    # This master-spec has 4 unique assets. If there are more, we have not
+    # uniquified the assets properly.
+    self.assertEqual(len(path_set), 4)
+
+    # Restore the graph from the checkpoint into a new Graph object.
+    restored_graph = tf.Graph()
+    restoration_config = tf.ConfigProto(
+        log_device_placement=False,
+        intra_op_parallelism_threads=10,
+        inter_op_parallelism_threads=10)
+
+    with tf.Session(graph=restored_graph, config=restoration_config) as sess:
+      tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING],
+                                 export_path)
+
+      averaged_hook_name, non_averaged_hook_name, cell_subgraph_hook_name = (
+          self.GetHookNodeNames(master_spec))
+
+      # Check that an averaged runtime hook node exists.
+      restored_graph.get_operation_by_name(averaged_hook_name)
+
+      # Check that the non-averaged version does not exist.
+      with self.assertRaises(KeyError):
+        restored_graph.get_operation_by_name(non_averaged_hook_name)
+
+      # Load the cell subgraph.
+      cell_subgraph_bytes = restored_graph.get_tensor_by_name(
+          cell_subgraph_hook_name + ':0')
+      cell_subgraph_bytes = cell_subgraph_bytes.eval(
+          feed_dict={'annotation/ComputeSession/InputBatch:0': []})
+      cell_subgraph_spec = export_pb2.CellSubgraphSpec()
+      cell_subgraph_spec.ParseFromString(cell_subgraph_bytes)
+      tf.logging.info('cell_subgraph_spec = %s', cell_subgraph_spec)
+
+      # Sanity check inputs.
+      for cell_input in cell_subgraph_spec.input:
+        self.assertGreater(len(cell_input.name), 0)
+        self.assertGreater(len(cell_input.tensor), 0)
+        self.assertNotEqual(cell_input.type,
+                            export_pb2.CellSubgraphSpec.Input.TYPE_UNKNOWN)
+        restored_graph.get_tensor_by_name(cell_input.tensor)  # shouldn't raise
+
+      # Sanity check outputs.
+      for cell_output in cell_subgraph_spec.output:
+        self.assertGreater(len(cell_output.name), 0)
+        self.assertGreater(len(cell_output.tensor), 0)
+        restored_graph.get_tensor_by_name(cell_output.tensor)  # shouldn't raise
+
+      # GetHookNames() finds a component with a fixed feature, so at least the
+      # first feature ID should exist.
+      self.assertTrue(
+          any(cell_input.name == 'fixed_channel_0_index_0_ids'
+              for cell_input in cell_subgraph_spec.input))
+
+      # Most dynamic components produce a logits layer.
+      self.assertTrue(
+          any(cell_output.name == 'logits'
+              for cell_output in cell_subgraph_spec.output))
+
+  def testModelExportProducesRunnableModel(self):
+    # Get the master spec and params for this graph.
+    master_spec = self.LoadSpec('ud-hungarian.master-spec')
+    params_path = os.path.join(
+        test_flags.source_root(), 'dragnn/python/testdata'
+        '/ud-hungarian.params')
+
+    # Export the graph via SavedModel. (Here, we maintain a handle to the graph
+    # for comparison, but that's usually not necessary.)
+    export_path = os.path.join(test_flags.temp_dir(), 'export')
+    dragnn_model_saver_lib.clean_output_paths(export_path)
+
+    saver_graph = tf.Graph()
+
+    shortened_to_original = dragnn_model_saver_lib.shorten_resource_paths(
+        master_spec)
+
+    dragnn_model_saver_lib.export_master_spec(master_spec, saver_graph)
+
+    dragnn_model_saver_lib.export_to_graph(
+        master_spec,
+        params_path,
+        export_path,
+        saver_graph,
+        export_moving_averages=False,
+        build_runtime_graph=False)
+
+    # Export the assets as well.
+    dragnn_model_saver_lib.export_assets(master_spec, shortened_to_original,
+                                         export_path)
+
+    # Restore the graph from the checkpoint into a new Graph object.
+    restored_graph = tf.Graph()
+    restoration_config = tf.ConfigProto(
+        log_device_placement=False,
+        intra_op_parallelism_threads=10,
+        inter_op_parallelism_threads=10)
+
+    with tf.Session(graph=restored_graph, config=restoration_config) as sess:
+      tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING],
+                                 export_path)
+
+      test_doc = sentence_pb2.Sentence()
+      text_format.Parse(_DUMMY_TEST_SENTENCE, test_doc)
+      test_reader_string = test_doc.SerializeToString()
+      test_inputs = [test_reader_string]
+
+      tf_out = sess.run(
+          'annotation/annotations:0',
+          feed_dict={'annotation/ComputeSession/InputBatch:0': test_inputs})
+
+      # We don't care about accuracy, only that the run sessions don't crash.
+      del tf_out


 if __name__ == '__main__':
   googletest.main()
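Outside of a test, the same hooks can be read back from any model exported with build_runtime_graph=True. A minimal sketch; the component name 'parser' and the export path are hypothetical placeholders, and the empty-feed pattern follows the test above:

import tensorflow as tf
from dragnn.protos import export_pb2

graph = tf.Graph()
with tf.Session(graph=graph) as sess:
  tf.saved_model.loader.load(
      sess, [tf.saved_model.tag_constants.SERVING], '/path/to/saved_model')
  # The hook tensor is a constant, but the session still expects the feed.
  spec_bytes = sess.run(
      'parser/EXPORT/CellSubgraphSpec:0',
      feed_dict={'annotation/ComputeSession/InputBatch:0': []})
  spec = export_pb2.CellSubgraphSpec()
  spec.ParseFromString(spec_bytes)
  print(spec)  # lists the runtime cell's named inputs and outputs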
research/syntaxnet/dragnn/python/file_diff_test.py
new file mode 100644
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Diff test that checks that two files are identical."""

from absl import flags
import tensorflow as tf

FLAGS = flags.FLAGS

flags.DEFINE_string('actual_file', None, 'File to test.')
flags.DEFINE_string('expected_file', None, 'File with expected contents.')


class DiffTest(tf.test.TestCase):

  def testEqualFiles(self):
    content_actual = None
    content_expected = None
    try:
      with open(FLAGS.actual_file) as actual:
        content_actual = actual.read()
    except IOError as e:
      self.fail("Error opening '%s': %s" % (FLAGS.actual_file, e.strerror))

    try:
      with open(FLAGS.expected_file) as expected:
        content_expected = expected.read()
    except IOError as e:
      self.fail("Error opening '%s': %s" % (FLAGS.expected_file, e.strerror))

    self.assertTrue(content_actual == content_expected)


if __name__ == '__main__':
  tf.test.main()
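The test is driven entirely by its two flags, e.g. `python file_diff_test.py --actual_file=/tmp/generated.txt --expected_file=testdata/golden.txt`, where both file paths are hypothetical placeholders for a generated output and its golden copy.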
research/syntaxnet/dragnn/python/graph_builder.py
...
@@ -28,7 +28,7 @@ from syntaxnet.util import check
 try:
   tf.NotDifferentiable('ExtractFixedFeatures')
-except KeyError, e:
+except KeyError as e:
   logging.info(str(e))
...
@@ -179,6 +179,8 @@ class MasterBuilder(object):
     optimizer: handle to the tf.train Optimizer object used to train this model.
     master_vars: dictionary of globally shared tf.Variable objects (e.g.
       the global training step and learning rate.)
+    read_from_avg: Whether to use averaged params instead of normal params.
+    build_runtime_graph: Whether to build a graph for use by the runtime.
   """

   def __init__(self, master_spec, hyperparam_config=None, pool_scope='shared'):
...
@@ -197,14 +199,15 @@ class MasterBuilder(object):
       ValueError: if a component is not found in the registry.
     """
     self.spec = master_spec
-    self.hyperparams = (spec_pb2.GridPoint()
-                        if hyperparam_config is None else hyperparam_config)
+    self.hyperparams = (
+        spec_pb2.GridPoint()
+        if hyperparam_config is None else hyperparam_config)
     _validate_grid_point(self.hyperparams)
     self.pool_scope = pool_scope

     # Set the graph-level random seed before creating the Components so the ops
     # they create will use this seed.
-    tf.set_random_seed(hyperparam_config.seed)
+    tf.set_random_seed(self.hyperparams.seed)

     # Construct all utility class and variables for each Component.
     self.components = []
...
@@ -219,19 +222,37 @@ class MasterBuilder(object):
       self.lookup_component[comp.name] = comp
       self.components.append(comp)

-    # Add global step variable.
     self.master_vars = {}
     with tf.variable_scope('master', reuse=False):
+      # Add global step variable.
       self.master_vars['step'] = tf.get_variable(
           'step', [], initializer=tf.zeros_initializer(), dtype=tf.int32)
-      self.master_vars['learning_rate'] = _create_learning_rate(
-          self.hyperparams, self.master_vars['step'])
+
+      # Add learning rate. If the learning rate is optimized externally, then
+      # just create an assign op.
+      if self.hyperparams.pbt_optimize_learning_rate:
+        self.master_vars['learning_rate'] = tf.get_variable(
+            'learning_rate',
+            initializer=tf.constant(
+                self.hyperparams.learning_rate, dtype=tf.float32))
+        lr_assign_input = tf.placeholder(tf.float32, [],
+                                         'pbt/assign/learning_rate/Value')
+        tf.assign(
+            self.master_vars['learning_rate'],
+            value=lr_assign_input,
+            name='pbt/assign/learning_rate')
+      else:
+        self.master_vars['learning_rate'] = _create_learning_rate(
+            self.hyperparams, self.master_vars['step'])

     # Construct optimizer.
     self.optimizer = _create_optimizer(self.hyperparams,
                                        self.master_vars['learning_rate'],
                                        self.master_vars['step'])

+    self.read_from_avg = False
+    self.build_runtime_graph = False
+
   @property
   def component_names(self):
     return tuple(c.name for c in self.components)
...
@@ -366,8 +387,9 @@ class MasterBuilder(object):
       max_index = len(self.components)
     else:
       if not 0 < max_index <= len(self.components):
-        raise IndexError('Invalid max_index {} for components {}; handle {}'.
-                         format(max_index, self.component_names, handle.name))
+        raise IndexError(
+            'Invalid max_index {} for components {}; handle {}'.format(
+                max_index, self.component_names, handle.name))

     # By default, we train every component supervised.
     if not component_weights:
...
@@ -375,6 +397,11 @@ class MasterBuilder(object):
     if not unroll_using_oracle:
       unroll_using_oracle = [True] * max_index

+    if not max_index <= len(unroll_using_oracle):
+      raise IndexError(('Invalid max_index {} for unroll_using_oracle {}; '
+                        'handle {}').format(max_index, unroll_using_oracle,
+                                            handle.name))
+
     component_weights = component_weights[:max_index]
     total_weight = (float)(sum(component_weights))
     component_weights = [w / total_weight for w in component_weights]
...
@@ -408,10 +435,10 @@ class MasterBuilder(object):
       args = (master_state, network_states)
       if unroll_using_oracle[component_index]:
-        handle, component_cost, component_correct, component_total = (tf.cond(
-            comp.training_beam_size > 1,
-            lambda: comp.build_structured_training(*args),
-            lambda: comp.build_greedy_training(*args)))
+        handle, component_cost, component_correct, component_total = (
+            tf.cond(comp.training_beam_size > 1,
+                    lambda: comp.build_structured_training(*args),
+                    lambda: comp.build_greedy_training(*args)))
       else:
         handle = comp.build_greedy_inference(*args, during_training=True)
...
@@ -445,6 +472,7 @@ class MasterBuilder(object):
     # 1. compute the gradients,
     # 2. add an optimizer to update the parameters using the gradients,
     # 3. make the ComputeSession handle depend on the optimizer.
+    gradient_norm = tf.constant(0.)
     if compute_gradients:
       logging.info('Creating train op with %d variables:\n\t%s',
                    len(params_to_train),
...
@@ -452,8 +480,11 @@ class MasterBuilder(object):
       grads_and_vars = self.optimizer.compute_gradients(
           cost, var_list=params_to_train)
-      clipped_gradients = [(self._clip_gradients(g), v)
-                           for g, v in grads_and_vars]
+      clipped_gradients = [
+          (self._clip_gradients(g), v) for g, v in grads_and_vars
+      ]
+      gradient_norm = tf.global_norm(list(zip(*clipped_gradients))[0])
+
       minimize_op = self.optimizer.apply_gradients(
           clipped_gradients, global_step=self.master_vars['step'])
...
@@ -474,6 +505,7 @@ class MasterBuilder(object):
     # Returns named access to common outputs.
     outputs = {
         'cost': cost,
+        'gradient_norm': gradient_norm,
         'batch': effective_batch,
         'metrics': metrics,
     }
...
@@ -520,7 +552,10 @@ class MasterBuilder(object):
     with tf.control_dependencies(control_ops):
       return tf.no_op(name='post_restore_hook_master')

-  def build_inference(self, handle, use_moving_average=False):
+  def build_inference(self,
+                      handle,
+                      use_moving_average=False,
+                      build_runtime_graph=False):
     """Builds an inference pipeline.

     This always uses the whole pipeline.
...
@@ -530,25 +565,30 @@ class MasterBuilder(object):
       use_moving_average: Whether or not to read from the moving
         average variables instead of the true parameters. Note: it is not
         possible to make gradient updates when this is True.
+      build_runtime_graph: Whether to build a graph for use by the runtime.

     Returns:
       handle: Handle after annotation.
     """
     self.read_from_avg = use_moving_average
+    self.build_runtime_graph = build_runtime_graph

     network_states = {}
     for comp in self.components:
       network_states[comp.name] = component.NetworkState()
       handle = dragnn_ops.init_component_data(
           handle, beam_size=comp.inference_beam_size, component=comp.name)
-      master_state = component.MasterState(handle,
-                                           dragnn_ops.batch_size(
-                                               handle, component=comp.name))
+      if build_runtime_graph:
+        batch_size = 1  # runtime uses singleton batches
+      else:
+        batch_size = dragnn_ops.batch_size(handle, component=comp.name)
+      master_state = component.MasterState(handle, batch_size)
       with tf.control_dependencies([handle]):
         handle = comp.build_greedy_inference(master_state, network_states)
       handle = dragnn_ops.write_annotations(handle, component=comp.name)

     self.read_from_avg = False
+    self.build_runtime_graph = False
     return handle

   def add_training_from_config(self,
...
@@ -625,7 +665,10 @@ class MasterBuilder(object):
     return self._outputs_with_release(handle, {'input_batch': input_batch},
                                       outputs)

-  def add_annotation(self, name_scope='annotation', enable_tracing=False):
+  def add_annotation(self,
+                     name_scope='annotation',
+                     enable_tracing=False,
+                     build_runtime_graph=False):
     """Adds an annotation pipeline to the graph.

     This will create the following additional named targets by default, for use
...
@@ -640,13 +683,17 @@ class MasterBuilder(object):
       enable_tracing: Enabling this will result in two things:
         1. Tracing will be enabled during inference.
         2. A 'traces' node will be added to the outputs.
+      build_runtime_graph: Whether to build a graph for use by the runtime.

     Returns:
       A dictionary of input and output nodes.
     """
     with tf.name_scope(name_scope):
       handle, input_batch = self._get_session_with_reader(enable_tracing)
-      handle = self.build_inference(handle, use_moving_average=True)
+      handle = self.build_inference(
+          handle, use_moving_average=True,
+          build_runtime_graph=build_runtime_graph)

       annotations = dragnn_ops.emit_annotations(
           handle, component=self.spec.component[-1].name)
...
@@ -666,7 +713,7 @@ class MasterBuilder(object):
   def add_saver(self):
     """Adds a Saver for all variables in the graph."""
-    logging.info('Saving variables:\n\t%s',
+    logging.info('Generating op to save variables:\n\t%s',
                  '\n\t'.join([x.name for x in tf.global_variables()]))
     self.saver = tf.train.Saver(
         var_list=[x for x in tf.global_variables()],
...
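The new PBT branch only wires up a named placeholder and assign op; an external population-based-training controller is expected to drive them. A minimal sketch of how such a controller could overwrite the learning rate, assuming the graph was built with pbt_optimize_learning_rate set and that the ops land under the 'master' scope opened in __init__ above:

def set_learning_rate(sess, new_rate):
  # Runs the named assign op, feeding the new value through its placeholder.
  sess.run('master/pbt/assign/learning_rate',
           feed_dict={'master/pbt/assign/learning_rate/Value:0': new_rate})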
research/syntaxnet/dragnn/python/graph_builder_test.py
...
@@ -20,7 +20,6 @@ import os.path

 import numpy as np
-from six.moves import xrange
 import tensorflow as tf

 from google.protobuf import text_format
...
@@ -30,13 +29,12 @@ from dragnn.protos import trace_pb2
 from dragnn.python import dragnn_ops
 from dragnn.python import graph_builder
 from syntaxnet import sentence_pb2
+from syntaxnet import test_flags
 from tensorflow.python.framework import test_util
 from tensorflow.python.platform import googletest
 from tensorflow.python.platform import tf_logging as logging

-FLAGS = tf.app.flags.FLAGS
-
 _DUMMY_GOLD_SENTENCE = """
 token {
...
@@ -151,13 +149,6 @@ token {
 ]


-def setUpModule():
-  if not hasattr(FLAGS, 'test_srcdir'):
-    FLAGS.test_srcdir = ''
-  if not hasattr(FLAGS, 'test_tmpdir'):
-    FLAGS.test_tmpdir = tf.test.get_temp_dir()
-
-
 def _as_op(x):
   """Always returns the tf.Operation associated with a node."""
   return x.op if isinstance(x, tf.Tensor) else x
...
@@ -244,7 +235,7 @@ class GraphBuilderTest(test_util.TensorFlowTestCase):
   def LoadSpec(self, spec_path):
     master_spec = spec_pb2.MasterSpec()
-    testdata = os.path.join(FLAGS.test_srcdir,
+    testdata = os.path.join(test_flags.source_root(),
                             'dragnn/core/testdata')
     with open(os.path.join(testdata, spec_path), 'r') as fin:
       text_format.Parse(fin.read().replace('TESTDATA', testdata), master_spec)
...
@@ -445,7 +436,7 @@ class GraphBuilderTest(test_util.TensorFlowTestCase):
       self.assertEqual(expected_num_actions, correct_val)
       self.assertEqual(expected_num_actions, total_val)

-      builder.saver.save(sess, os.path.join(FLAGS.test_tmpdir, 'model'))
+      builder.saver.save(sess, os.path.join(test_flags.temp_dir(), 'model'))

       logging.info('Running test.')
       logging.info('Printing annotations')
...
research/syntaxnet/dragnn/python/lexicon_test.py
...
@@ -27,8 +27,7 @@ from dragnn.python import lexicon
 from syntaxnet import parser_trainer
 from syntaxnet import task_spec_pb2
+from syntaxnet import test_flags

-FLAGS = tf.app.flags.FLAGS

 _EXPECTED_CONTEXT = r"""
...
@@ -46,13 +45,6 @@ input { name: "known-word-map" Part { file_pattern: "/tmp/known-word-map" } }
 """


-def setUpModule():
-  if not hasattr(FLAGS, 'test_srcdir'):
-    FLAGS.test_srcdir = ''
-  if not hasattr(FLAGS, 'test_tmpdir'):
-    FLAGS.test_tmpdir = tf.test.get_temp_dir()
-
-
 class LexiconTest(tf.test.TestCase):

   def testCreateLexiconContext(self):
...
@@ -62,8 +54,8 @@ class LexiconTest(tf.test.TestCase):
         lexicon.create_lexicon_context('/tmp'), expected_context)

   def testBuildLexicon(self):
-    empty_input_path = os.path.join(FLAGS.test_tmpdir, 'empty-input')
-    lexicon_output_path = os.path.join(FLAGS.test_tmpdir, 'lexicon-output')
+    empty_input_path = os.path.join(test_flags.temp_dir(), 'empty-input')
+    lexicon_output_path = os.path.join(test_flags.temp_dir(), 'lexicon-output')
     with open(empty_input_path, 'w'):
       pass
...
research/syntaxnet/dragnn/python/load_mst_cc_impl.py
new file mode 100644
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loads mst_ops shared library."""

import os.path

import tensorflow as tf

tf.load_op_library(
    os.path.join(tf.resource_loader.get_data_files_path(), 'mst_cc_impl.so'))
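The load happens for its side effect at import time: tf.load_op_library registers the C++ kernels from the shared object (and returns a module exposing them), while tf.resource_loader.get_data_files_path() resolves the path next to this file, so the .so only needs to ship alongside the module. A minimal sketch of the same pattern with a hypothetical library and op name:

import tensorflow as tf

# Hypothetical shared object; the returned module exposes its registered ops.
_my_ops = tf.load_op_library('my_ops.so')
# _my_ops.my_custom_op(...) would invoke a kernel defined in the library.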