Commit 80178fc6 (unverified)
Authored May 11, 2018 by Mark Omernick; committed via GitHub on May 11, 2018
Parents: a84e1ef9, edea2b67

Merge pull request #4153 from terryykoo/master

Export @195097388.
Showing 20 changed files, with 1716 additions and 386 deletions (+1716 / −386).
research/syntaxnet/dragnn/protos/runtime.proto                    +1    −1
research/syntaxnet/dragnn/protos/spec.proto                       +23   −3
research/syntaxnet/dragnn/protos/trace.proto                      +2    −0
research/syntaxnet/dragnn/python/BUILD                            +160  −26
research/syntaxnet/dragnn/python/biaffine_units.py                +36   −11
research/syntaxnet/dragnn/python/biaffine_units_test.py           +151  −0
research/syntaxnet/dragnn/python/bulk_component.py                +114  −35
research/syntaxnet/dragnn/python/bulk_component_test.py           +222  −165
research/syntaxnet/dragnn/python/component.py                     +216  −43
research/syntaxnet/dragnn/python/component_test.py                +153  −0
research/syntaxnet/dragnn/python/digraph_ops.py                   +132  −12
research/syntaxnet/dragnn/python/digraph_ops_test.py              +134  −23
research/syntaxnet/dragnn/python/dragnn_model_saver.py            +11   −7
research/syntaxnet/dragnn/python/dragnn_model_saver_lib.py        +3    −1
research/syntaxnet/dragnn/python/dragnn_model_saver_lib_test.py   +213  −14
research/syntaxnet/dragnn/python/file_diff_test.py                +48   −0
research/syntaxnet/dragnn/python/graph_builder.py                 +69   −22
research/syntaxnet/dragnn/python/graph_builder_test.py            +3    −12
research/syntaxnet/dragnn/python/lexicon_test.py                  +3    −11
research/syntaxnet/dragnn/python/load_mst_cc_impl.py              +22   −0
research/syntaxnet/dragnn/protos/runtime.proto

@@ -16,7 +16,7 @@ message MasterPerformanceSettings {
   // Maximum size of the free list in the SessionStatePool.  NB: The default
   // value may occasionally change.
-  optional uint64 session_state_pool_max_free_states = 1 [default = 4];
+  optional uint64 session_state_pool_max_free_states = 1 [default = 16];
 }

 // As above, but for component-specific performance tuning settings.
research/syntaxnet/dragnn/protos/spec.proto

-// DRAGNN Configuration proto. See go/dragnn-design for more information.
+// DRAGNN Configuration proto.
 syntax = "proto2";

@@ -93,7 +94,7 @@ message Part {
 // are extracted, embedded, and then concatenated together as a group.

 // Specification for a feature channel that is a *fixed* function of the input.
-// NEXT_ID: 10
+// NEXT_ID: 12
 message FixedFeatureChannel {
   // Interpretable name for this feature channel. NN builders might depend on
   // this to determine how to hook different channels up internally.

@@ -129,6 +130,19 @@ message FixedFeatureChannel {
   // Vocab file, containing all vocabulary words one per line.
   optional Resource vocab = 8;

+  // Settings for feature ID dropout:
+  // If non-negative, enables feature ID dropout, and dropped feature IDs will
+  // be replaced with this ID.
+  optional int64 dropout_id = 10 [default = -1];
+
+  // Probability of keeping each of the |vocabulary_size| feature IDs.  Only
+  // used if |dropout_id| is non-negative, and must not be empty in that case.
+  // If this has fewer than |vocabulary_size| values, then the final value is
+  // tiled onto the remaining IDs.  For example, specifying a single value is
+  // equivalent to setting all IDs to that value.
+  repeated float dropout_keep_probability = 11 [packed = true];
 }

 // Specification for a feature channel that *links* to component

@@ -173,11 +187,17 @@ message TrainingGridSpec {
 }

 // A hyperparameter configuration for a training run.
-// NEXT ID: 22
+// NEXT ID: 23
 message GridPoint {
   // Global learning rate initialization point.
   optional double learning_rate = 1 [default = 0.1];

+  // Whether to use PBT (population-based training) to optimize the learning
+  // rate.  Population-based training is not currently open-source, so this
+  // will just create a tf.assign op which external frameworks can use to
+  // adjust the learning rate.
+  optional bool pbt_optimize_learning_rate = 22 [default = false];
+
   // Momentum coefficient when using MomentumOptimizer.
   optional double momentum = 2 [default = 0.9];
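The tiling rule for dropout_keep_probability is the subtle part of the new fields above. The real lookup-time logic lives in network_units.apply_feature_id_dropout (invoked from bulk_component.py later in this commit); the following NumPy sketch only illustrates the semantics documented in the proto comments, with all names and shapes assumed:

import numpy as np

def feature_id_dropout_sketch(ids, keep_probability, dropout_id,
                              vocabulary_size, rng):
  # Tile the final value of |keep_probability| onto the remaining IDs, as
  # described in the proto comment; a single value therefore applies to all.
  keep = np.asarray(keep_probability, dtype=np.float32)
  if keep.size < vocabulary_size:
    keep = np.concatenate(
        [keep, np.full(vocabulary_size - keep.size, keep[-1])])
  # Drop each feature ID independently with probability 1 - keep[id], and
  # replace dropped IDs with the designated |dropout_id|.
  dropped = rng.uniform(size=len(ids)) >= keep[np.asarray(ids)]
  return np.where(dropped, dropout_id, ids)

# Keep ID 0 always; keep every other ID with probability 0.5.
rng = np.random.RandomState(0)
print(feature_id_dropout_sketch([0, 1, 2, 3], [1.0, 0.5],
                                dropout_id=99, vocabulary_size=4, rng=rng))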
research/syntaxnet/dragnn/protos/trace.proto

@@ -53,6 +53,8 @@ message ComponentStepTrace {
   // Set to true once the step is finished. (This allows us to open a step
   // after each transition, without having to know if it will be used.)
   optional bool step_finished = 6 [default = false];
+
+  extensions 1000 to max;
 }

 // The traces for all steps for a single Component.
research/syntaxnet/dragnn/python/BUILD

@@ -16,6 +16,17 @@ cc_binary(
     ],
 )

+cc_binary(
+    name = "mst_cc_impl.so",
+    linkopts = select({
+        "//conditions:default": ["-lm"],
+        "@org_tensorflow//tensorflow:darwin": [],
+    }),
+    linkshared = 1,
+    linkstatic = 1,
+    deps = ["//dragnn/mst:mst_ops_cc"],
+)
+
 filegroup(
     name = "testdata",
     data = glob(["testdata/**"]),

@@ -27,6 +38,12 @@ py_library(
     data = [":dragnn_cc_impl.so"],
 )

+py_library(
+    name = "load_mst_cc_impl_py",
+    srcs = ["load_mst_cc_impl.py"],
+    data = [":mst_cc_impl.so"],
+)
+
 py_library(
     name = "bulk_component",
     srcs = [

@@ -50,6 +67,8 @@ py_library(
         ":bulk_component",
         ":dragnn_ops",
        ":network_units",
+        ":runtime_support",
+        "//dragnn/protos:export_pb2_py",
         "//syntaxnet/util:check",
         "//syntaxnet/util:pyregistry",
         "@org_tensorflow//tensorflow:tensorflow_py",

@@ -85,9 +104,9 @@ py_library(
         ":graph_builder",
         ":load_dragnn_cc_impl_py",
         ":network_units",
-        "//dragnn/protos:spec_py_pb2",
+        "//dragnn/protos:spec_pb2_py",
         "//syntaxnet:load_parser_ops_py",
-        "//syntaxnet:sentence_py_pb2",
+        "//syntaxnet:sentence_pb2_py",
         "@org_tensorflow//tensorflow:tensorflow_py",
         "@org_tensorflow//tensorflow/core:protos_all_py",
     ],

@@ -99,7 +118,9 @@ py_test(
     data = [":testdata"],
     deps = [
         ":dragnn_model_saver_lib",
-        "//dragnn/protos:spec_py_pb2",
+        "//dragnn/protos:export_pb2_py",
+        "//dragnn/protos:spec_pb2_py",
+        "//syntaxnet:test_flags",
         "@org_tensorflow//tensorflow:tensorflow_py",
     ],
 )

@@ -110,7 +131,9 @@ py_binary(
     deps = [
         ":dragnn_model_saver_lib",
         ":spec_builder",
-        "//dragnn/protos:spec_py_pb2",
+        "//dragnn/protos:spec_pb2_py",
+        "@absl_py//absl:app",
+        "@absl_py//absl/flags",
         "@org_tensorflow//tensorflow:tensorflow_py",
         "@org_tensorflow//tensorflow/core:protos_all_py",
     ],

@@ -127,7 +150,7 @@ py_library(
         ":network_units",
         ":transformer_units",
         ":wrapped_units",
-        "//dragnn/protos:spec_py_pb2",
+        "//dragnn/protos:spec_pb2_py",
         "//syntaxnet/util:check",
         "@org_tensorflow//tensorflow:tensorflow_py",
         "@org_tensorflow//tensorflow/core:protos_all_py",

@@ -159,7 +182,7 @@ py_test(
     srcs = ["render_parse_tree_graphviz_test.py"],
     deps = [
         ":render_parse_tree_graphviz",
-        "//syntaxnet:sentence_py_pb2",
+        "//syntaxnet:sentence_pb2_py",
         "@org_tensorflow//tensorflow:tensorflow_py",
     ],
 )

@@ -168,7 +191,7 @@ py_library(
     name = "render_spec_with_graphviz",
     srcs = ["render_spec_with_graphviz.py"],
     deps = [
-        "//dragnn/protos:spec_py_pb2",
+        "//dragnn/protos:spec_pb2_py",
     ],
 )

@@ -197,7 +220,7 @@ py_binary(
         "//dragnn/viz:viz-min-js-gz",
     ],
     deps = [
-        "//dragnn/protos:trace_py_pb2",
+        "//dragnn/protos:trace_pb2_py",
     ],
 )

@@ -206,8 +229,8 @@ py_test(
     srcs = ["visualization_test.py"],
     deps = [
         ":visualization",
-        "//dragnn/protos:spec_py_pb2",
-        "//dragnn/protos:trace_py_pb2",
+        "//dragnn/protos:spec_pb2_py",
+        "//dragnn/protos:trace_pb2_py",
         "@org_tensorflow//tensorflow:tensorflow_py",
     ],
 )

@@ -225,6 +248,18 @@ py_library(
 # Tests

+py_test(
+    name = "component_test",
+    srcs = [
+        "component_test.py",
+    ],
+    deps = [
+        ":components",
+        "//dragnn/protos:spec_pb2_py",
+        "@org_tensorflow//tensorflow:tensorflow_py",
+    ],
+)
+
 py_test(
     name = "bulk_component_test",
     srcs = [

@@ -235,9 +270,9 @@ py_test(
         ":components",
         ":dragnn_ops",
         ":network_units",
-        "//dragnn/protos:spec_py_pb2",
+        "//dragnn/protos:spec_pb2_py",
         "//syntaxnet:load_parser_ops_py",
-        "//syntaxnet:sentence_py_pb2",
+        "//syntaxnet:sentence_pb2_py",
         "@org_tensorflow//tensorflow:tensorflow_py",
         "@org_tensorflow//tensorflow/core:protos_all_py",
     ],

@@ -270,10 +305,11 @@ py_test(
     deps = [
         ":dragnn_ops",
         ":graph_builder",
-        "//dragnn/protos:spec_py_pb2",
-        "//dragnn/protos:trace_py_pb2",
+        "//dragnn/protos:spec_pb2_py",
+        "//dragnn/protos:trace_pb2_py",
         "//syntaxnet:load_parser_ops_py",
-        "//syntaxnet:sentence_py_pb2",
+        "//syntaxnet:sentence_pb2_py",
+        "//syntaxnet:test_flags",
         "@org_tensorflow//tensorflow:tensorflow_py",
         "@org_tensorflow//tensorflow/core:protos_all_py",
     ],

@@ -287,7 +323,7 @@ py_test(
         ":network_units",
         "//dragnn/core:dragnn_bulk_ops",
         "//dragnn/core:dragnn_ops",
-        "//dragnn/protos:spec_py_pb2",
+        "//dragnn/protos:spec_pb2_py",
         "//syntaxnet:load_parser_ops_py",
         "@org_tensorflow//tensorflow:tensorflow_py",
         "@org_tensorflow//tensorflow/core:protos_all_py",

@@ -303,7 +339,8 @@ py_test(
         ":sentence_io",
         "//syntaxnet:load_parser_ops_py",
         "//syntaxnet:parser_ops",
-        "//syntaxnet:sentence_py_pb2",
+        "//syntaxnet:sentence_pb2_py",
+        "//syntaxnet:test_flags",
         "@org_tensorflow//tensorflow:tensorflow_py",
         "@org_tensorflow//tensorflow/core:protos_all_py",
     ],

@@ -313,21 +350,31 @@ py_library(
     name = "trainer_lib",
     srcs = ["trainer_lib.py"],
     deps = [
-        "//dragnn/protos:spec_py_pb2",
+        "//dragnn/protos:spec_pb2_py",
         "//syntaxnet:parser_ops",
-        "//syntaxnet:sentence_py_pb2",
-        "//syntaxnet:task_spec_py_pb2",
+        "//syntaxnet:sentence_pb2_py",
+        "//syntaxnet:task_spec_pb2_py",
         "//syntaxnet/util:check",
         "@org_tensorflow//tensorflow:tensorflow_py",
         "@org_tensorflow//tensorflow/core:protos_all_py",
     ],
 )

+py_test(
+    name = "trainer_lib_test",
+    srcs = ["trainer_lib_test.py"],
+    deps = [
+        ":trainer_lib",
+        "@org_tensorflow//tensorflow:tensorflow_py",
+    ],
+)
+
 py_library(
     name = "lexicon",
     srcs = ["lexicon.py"],
     deps = [
         "//syntaxnet:parser_ops",
-        "//syntaxnet:task_spec_py_pb2",
+        "//syntaxnet:task_spec_pb2_py",
         "@org_tensorflow//tensorflow:tensorflow_py",
     ],
 )

@@ -340,6 +387,7 @@ py_test(
         "//syntaxnet:load_parser_ops_py",
         "//syntaxnet:parser_ops",
         "//syntaxnet:parser_trainer",
+        "//syntaxnet:test_flags",
         "@org_tensorflow//tensorflow:tensorflow_py",
     ],
 )

@@ -348,7 +396,7 @@ py_library(
     name = "evaluation",
     srcs = ["evaluation.py"],
     deps = [
-        "//syntaxnet:sentence_py_pb2",
+        "//syntaxnet:sentence_pb2_py",
         "//syntaxnet/util:check",
         "@org_tensorflow//tensorflow:tensorflow_py",
     ],

@@ -359,7 +407,7 @@ py_test(
     srcs = ["evaluation_test.py"],
     deps = [
         ":evaluation",
-        "//syntaxnet:sentence_py_pb2",
+        "//syntaxnet:sentence_pb2_py",
         "@org_tensorflow//tensorflow:tensorflow_py",
     ],
 )

@@ -369,7 +417,7 @@ py_library(
     srcs = ["spec_builder.py"],
     deps = [
         ":lexicon",
-        "//dragnn/protos:spec_py_pb2",
+        "//dragnn/protos:spec_pb2_py",
         "//syntaxnet:parser_ops",
         "//syntaxnet/util:check",
         "@org_tensorflow//tensorflow:tensorflow_py",

@@ -381,7 +429,7 @@ py_test(
     srcs = ["spec_builder_test.py"],
     deps = [
         ":spec_builder",
-        "//dragnn/protos:spec_py_pb2",
+        "//dragnn/protos:spec_pb2_py",
         "//syntaxnet:load_parser_ops_py",
         "//syntaxnet:parser_ops",
         "//syntaxnet:parser_trainer",

@@ -418,6 +466,17 @@ py_library(
     ],
 )

+py_test(
+    name = "biaffine_units_test",
+    srcs = ["biaffine_units_test.py"],
+    deps = [
+        ":biaffine_units",
+        ":network_units",
+        "//dragnn/protos:spec_pb2_py",
+        "@org_tensorflow//tensorflow:tensorflow_py",
+    ],
+)
+
 py_library(
     name = "transformer_units",
     srcs = ["transformer_units.py"],

@@ -437,10 +496,85 @@ py_test(
         ":transformer_units",
         "//dragnn/core:dragnn_bulk_ops",
         "//dragnn/core:dragnn_ops",
-        "//dragnn/protos:spec_py_pb2",
+        "//dragnn/protos:spec_pb2_py",
         "//syntaxnet:load_parser_ops_py",
         "@org_tensorflow//tensorflow:tensorflow_py",
         "@org_tensorflow//tensorflow/core:protos_all_py",
     ],
 )

+py_library(
+    name = "runtime_support",
+    srcs = ["runtime_support.py"],
+    deps = [
+        ":network_units",
+        "//syntaxnet/util:check",
+        "@org_tensorflow//tensorflow:tensorflow_py",
+    ],
+)
+
+py_test(
+    name = "runtime_support_test",
+    srcs = ["runtime_support_test.py"],
+    deps = [
+        ":network_units",
+        ":runtime_support",
+        "//dragnn/protos:export_pb2_py",
+        "//dragnn/protos:spec_pb2_py",
+        "@org_tensorflow//tensorflow:tensorflow_py",
+    ],
+)
+
+py_library(
+    name = "file_diff_test",
+    srcs = ["file_diff_test.py"],
+    deps = [
+        "@absl_py//absl/flags",
+        "@org_tensorflow//tensorflow:tensorflow_py",
+    ],
+)
+
+py_library(
+    name = "mst_ops",
+    srcs = ["mst_ops.py"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":digraph_ops",
+        ":load_mst_cc_impl_py",
+        "//dragnn/mst:mst_ops",
+        "//syntaxnet/util:check",
+        "@org_tensorflow//tensorflow:tensorflow_py",
+    ],
+)
+
+py_test(
+    name = "mst_ops_test",
+    srcs = ["mst_ops_test.py"],
+    deps = [
+        ":mst_ops",
+        "@org_tensorflow//tensorflow:tensorflow_py",
+    ],
+)
+
+py_library(
+    name = "mst_units",
+    srcs = ["mst_units.py"],
+    deps = [
+        ":mst_ops",
+        ":network_units",
+        "//syntaxnet/util:check",
+        "@org_tensorflow//tensorflow:tensorflow_py",
+    ],
+)
+
+py_test(
+    name = "mst_units_test",
+    size = "small",
+    srcs = ["mst_units_test.py"],
+    deps = [
+        ":mst_units",
+        ":network_units",
+        "//dragnn/protos:spec_pb2_py",
+        "@org_tensorflow//tensorflow:tensorflow_py",
+    ],
+)
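Most of the new BUILD targets above wire the MST ops into Python. The new load_mst_cc_impl.py itself (+22 lines) is not expanded in this view; by analogy with the existing load_dragnn_cc_impl.py loader that these rules mirror, it presumably boils down to loading the mst_cc_impl.so data dependency. A hedged sketch of that pattern, not the actual file contents:

import os.path

import tensorflow as tf

# Load the shared object declared as a data dependency of
# :load_mst_cc_impl_py, registering the MST kernels with TensorFlow.
tf.load_op_library(
    os.path.join(tf.resource_loader.get_data_files_path(),
                 'mst_cc_impl.so'))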
research/syntaxnet/dragnn/python/biaffine_units.py

@@ -79,24 +79,44 @@ class BiaffineDigraphNetwork(network_units.NetworkUnitInterface):
     self._source_dim = self._linked_feature_dims['sources']
     self._target_dim = self._linked_feature_dims['targets']

     # TODO(googleuser): Make parameter initialization configurable.
     self._weights = []
-    self._weights.append(tf.get_variable(
-        'weights_arc', [self._source_dim, self._target_dim], tf.float32,
-        tf.random_normal_initializer(stddev=1e-4)))
-    self._weights.append(tf.get_variable(
-        'weights_source', [self._source_dim], tf.float32,
-        tf.random_normal_initializer(stddev=1e-4)))
-    self._weights.append(tf.get_variable(
-        'root', [self._source_dim], tf.float32,
-        tf.random_normal_initializer(stddev=1e-4)))
+    self._weights.append(tf.get_variable(
+        'weights_arc', [self._source_dim, self._target_dim], tf.float32,
+        tf.orthogonal_initializer()))
+    self._weights.append(tf.get_variable(
+        'weights_source', [self._source_dim], tf.float32,
+        tf.zeros_initializer()))
+    self._weights.append(tf.get_variable(
+        'root', [self._source_dim], tf.float32,
+        tf.zeros_initializer()))

     self._params.extend(self._weights)
     self._regularized_weights.extend(self._weights)

+    # Add runtime hooks for pre-computed weights.
+    self._derived_params.append(self._get_root_weights)
+    self._derived_params.append(self._get_root_bias)
+
     # Negative Layer.dim indicates that the dimension is dynamic.
     self._layers.append(network_units.Layer(component, 'adjacency', -1))

+  def _get_root_weights(self):
+    """Pre-computes the product of the root embedding and arc weights."""
+    weights_arc = self._component.get_variable('weights_arc')
+    root = self._component.get_variable('root')
+    name = self._component.name + '/root_weights'
+    with tf.name_scope(None):
+      return tf.matmul(tf.expand_dims(root, 0), weights_arc, name=name)
+
+  def _get_root_bias(self):
+    """Pre-computes the product of the root embedding and source weights."""
+    weights_source = self._component.get_variable('weights_source')
+    root = self._component.get_variable('root')
+    name = self._component.name + '/root_bias'
+    with tf.name_scope(None):
+      return tf.matmul(tf.expand_dims(root, 0),
+                       tf.expand_dims(weights_source, 1), name=name)
+
   def create(self,
              fixed_embeddings,
              linked_embeddings,

@@ -133,12 +153,17 @@ class BiaffineDigraphNetwork(network_units.NetworkUnitInterface):
     sources_bxnxn = digraph_ops.ArcSourcePotentialsFromTokens(
         source_tokens_bxnxs, weights_source)
     roots_bxn = digraph_ops.RootPotentialsFromTokens(
-        root, target_tokens_bxnxt, weights_arc)
+        root, target_tokens_bxnxt, weights_arc, weights_source)

     # Combine them into a single matrix with the roots on the diagonal.
     adjacency_bxnxn = digraph_ops.CombineArcAndRootPotentials(
         arcs_bxnxn + sources_bxnxn, roots_bxn)

+    # The adjacency matrix currently has sources on rows and targets on
+    # columns, but we want targets on rows so that maximizing within a row
+    # corresponds to selecting sources for a given target.
+    adjacency_bxnxn = tf.matrix_transpose(adjacency_bxnxn)
+
     return [tf.reshape(adjacency_bxnxn, [-1, num_tokens])]
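The two derived parameters added above fold the trained root embedding into the arc scorer ahead of time, so the runtime never needs the raw root vector. A quick NumPy sketch of the shapes involved, with dimensions assumed purely for illustration:

import numpy as np

source_dim, target_dim = 4, 3
root = np.random.randn(source_dim).astype(np.float32)             # 'root'
weights_arc = np.random.randn(source_dim, target_dim).astype(np.float32)
weights_source = np.random.randn(source_dim).astype(np.float32)

# _get_root_weights: [1, source_dim] x [source_dim, target_dim]
root_weights = root[np.newaxis, :] @ weights_arc                  # (1, target_dim)
# _get_root_bias: [1, source_dim] x [source_dim, 1]
root_bias = root[np.newaxis, :] @ weights_source[:, np.newaxis]   # (1, 1)

assert root_weights.shape == (1, target_dim)
assert root_bias.shape == (1, 1)

These are exactly the shapes the new biaffine_units_test.py (next file) asserts: [1, _TOKEN_DIM] and [1, 1].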
research/syntaxnet/dragnn/python/biaffine_units_test.py (new file, mode 100644)

# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for biaffine_units."""

import tensorflow as tf

from google.protobuf import text_format

from dragnn.protos import spec_pb2
from dragnn.python import biaffine_units
from dragnn.python import network_units

_BATCH_SIZE = 11
_NUM_TOKENS = 22
_TOKEN_DIM = 33


class MockNetwork(object):

  def __init__(self):
    pass

  def get_layer_size(self, unused_name):
    return _TOKEN_DIM


class MockComponent(object):

  def __init__(self, master, component_spec):
    self.master = master
    self.spec = component_spec
    self.name = component_spec.name
    self.network = MockNetwork()
    self.beam_size = 1
    self.num_actions = 45
    self._attrs = {}

  def attr(self, name):
    return self._attrs[name]

  def get_variable(self, name):
    return tf.get_variable(name)


class MockMaster(object):

  def __init__(self):
    self.spec = spec_pb2.MasterSpec()
    self.hyperparams = spec_pb2.GridPoint()
    self.lookup_component = {
        'previous': MockComponent(self, spec_pb2.ComponentSpec())
    }


def _make_biaffine_spec():
  """Returns a ComponentSpec that the BiaffineDigraphNetwork works on."""
  component_spec = spec_pb2.ComponentSpec()
  text_format.Parse("""
      name: "test_component"
      backend { registered_name: "TestComponent" }
      linked_feature {
        name: "sources"
        fml: "input.focus"
        source_translator: "identity"
        source_component: "previous"
        source_layer: "sources"
        size: 1
        embedding_dim: -1
      }
      linked_feature {
        name: "targets"
        fml: "input.focus"
        source_translator: "identity"
        source_component: "previous"
        source_layer: "targets"
        size: 1
        embedding_dim: -1
      }
      network_unit {
        registered_name: "biaffine_units.BiaffineDigraphNetwork"
      }
      """, component_spec)
  return component_spec


class BiaffineDigraphNetworkTest(tf.test.TestCase):

  def setUp(self):
    # Clear the graph and all existing variables.  Otherwise, variables
    # created in different tests may collide with each other.
    tf.reset_default_graph()

  def testCanCreate(self):
    """Tests that create() works on a good spec."""
    with tf.Graph().as_default(), self.test_session():
      master = MockMaster()
      component = MockComponent(master, _make_biaffine_spec())
      with tf.variable_scope(component.name, reuse=None):
        component.network = biaffine_units.BiaffineDigraphNetwork(component)
      with tf.variable_scope(component.name, reuse=True):
        sources = network_units.NamedTensor(
            tf.zeros([_BATCH_SIZE * _NUM_TOKENS, _TOKEN_DIM]), 'sources')
        targets = network_units.NamedTensor(
            tf.zeros([_BATCH_SIZE * _NUM_TOKENS, _TOKEN_DIM]), 'targets')

        # No assertions on the result, just don't crash.
        component.network.create(
            fixed_embeddings=[],
            linked_embeddings=[sources, targets],
            context_tensor_arrays=None,
            attention_tensor=None,
            during_training=True,
            stride=_BATCH_SIZE)

  def testDerivedParametersForRuntime(self):
    """Test generation of derived parameters for the runtime."""
    with tf.Graph().as_default(), self.test_session():
      master = MockMaster()
      component = MockComponent(master, _make_biaffine_spec())
      with tf.variable_scope(component.name, reuse=None):
        component.network = biaffine_units.BiaffineDigraphNetwork(component)
      with tf.variable_scope(component.name, reuse=True):
        self.assertEqual(len(component.network.derived_params), 2)
        root_weights = component.network.derived_params[0]()
        root_bias = component.network.derived_params[1]()

        # Only check shape; values are random due to initialization.
        self.assertAllEqual(root_weights.shape.as_list(), [1, _TOKEN_DIM])
        self.assertAllEqual(root_bias.shape.as_list(), [1, 1])


if __name__ == '__main__':
  tf.test.main()
research/syntaxnet/dragnn/python/bulk_component.py

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
 """Component builders for non-recurrent networks in DRAGNN."""

@@ -51,10 +50,8 @@ def fetch_linked_embedding(comp, network_states, feature_spec):
              feature_spec.name)
   source = comp.master.lookup_component[feature_spec.source_component]
   return network_units.NamedTensor(
       network_states[source.name].activations[feature_spec.source_layer]
       .bulk_tensor, feature_spec.name)

@@ -63,17 +60,20 @@ def _validate_embedded_fixed_features(comp):
     check.Gt(feature.embedding_dim, 0,
              'Embeddings requested for non-embedded feature: %s' % feature)
     if feature.is_constant:
-      check.IsTrue(feature.HasField('pretrained_embedding_matrix'),
-                   'Constant embeddings must be pretrained: %s' % feature)
+      check.IsTrue(
+          feature.HasField('pretrained_embedding_matrix'),
+          'Constant embeddings must be pretrained: %s' % feature)


-def fetch_differentiable_fixed_embeddings(comp, state, stride):
+def fetch_differentiable_fixed_embeddings(comp, state, stride,
+                                          during_training):
   """Looks up fixed features with a separate, differentiable embedding lookup.

   Args:
     comp: Component whose fixed features we wish to look up.
     state: live MasterState object for the component.
     stride: Tensor containing current batch * beam size.
+    during_training: True if this is being called from a training code path.
+      This controls, e.g., the use of feature ID dropout.

   Returns:
     state handle: updated state handle to be used after this call

@@ -93,6 +93,11 @@ def fetch_differentiable_fixed_embeddings(comp, state, stride):
                      'differentiable')
     tf.logging.info('[%s] Adding %s fixed feature "%s"', comp.name,
                     differentiable_or_constant, feature_spec.name)
+
+    if during_training and feature_spec.dropout_id >= 0:
+      ids[channel], weights[channel] = network_units.apply_feature_id_dropout(
+          ids[channel], weights[channel], feature_spec)
+
     size = stride * num_steps * feature_spec.size
     fixed_embedding = network_units.embedding_lookup(
         comp.get_variable(network_units.fixed_embeddings_name(channel)),

@@ -105,16 +110,22 @@ def fetch_differentiable_fixed_embeddings(comp, state, stride):
   return state.handle, fixed_embeddings


-def fetch_fast_fixed_embeddings(comp, state):
+def fetch_fast_fixed_embeddings(comp,
+                                state,
+                                pad_to_batch=None,
+                                pad_to_steps=None):
   """Looks up fixed features with a fast, non-differentiable op.

   Since BulkFixedEmbeddings is non-differentiable with respect to the
   embeddings, the idea is to call this function only when the graph is
-  not being used for training.
+  not being used for training. If the function is being called with fixed step
+  and batch sizes, it will use the most efficient possible extractor.

   Args:
     comp: Component whose fixed features we wish to look up.
     state: live MasterState object for the component.
+    pad_to_batch: Optional; the number of batch elements to pad to.
+    pad_to_steps: Optional; the number of steps to pad to.

   Returns:
     state handle: updated state handle to be used after this call

@@ -126,19 +137,50 @@ def fetch_fast_fixed_embeddings(comp, state):
     return state.handle, []
   tf.logging.info('[%s] Adding %d fast fixed features', comp.name,
                   num_channels)
-  state.handle, bulk_embeddings, _ = dragnn_ops.bulk_fixed_embeddings(
-      state.handle, [
-          comp.get_variable(network_units.fixed_embeddings_name(c))
-          for c in range(num_channels)
-      ],
-      component=comp.name)
-  bulk_embeddings = network_units.NamedTensor(
-      bulk_embeddings, 'bulk-%s-fixed-features' % comp.name)
+  features = [
+      comp.get_variable(network_units.fixed_embeddings_name(c))
+      for c in range(num_channels)
+  ]
+  if pad_to_batch is not None and pad_to_steps is not None:
+    # If we have fixed padding numbers, we can use
+    # 'bulk_embed_fixed_features', which is the fastest embedding extractor.
+    state.handle, bulk_embeddings, _ = dragnn_ops.bulk_embed_fixed_features(
+        state.handle,
+        features,
+        component=comp.name,
+        pad_to_batch=pad_to_batch,
+        pad_to_steps=pad_to_steps)
+  else:
+    state.handle, bulk_embeddings, _ = dragnn_ops.bulk_fixed_embeddings(
+        state.handle, features, component=comp.name)
+
+  bulk_embeddings = network_units.NamedTensor(
+      bulk_embeddings, 'bulk-%s-fixed-features' % comp.name)
   return state.handle, [bulk_embeddings]


+def fetch_dense_ragged_embeddings(comp, state):
+  """Gets embeddings in RaggedTensor format."""
+  _validate_embedded_fixed_features(comp)
+  num_channels = len(comp.spec.fixed_feature)
+  if not num_channels:
+    return state.handle, []
+
+  tf.logging.info('[%s] Adding %d fast fixed features', comp.name,
+                  num_channels)
+  features = [
+      comp.get_variable(network_units.fixed_embeddings_name(c))
+      for c in range(num_channels)
+  ]
+  state.handle, data, offsets = dragnn_ops.bulk_embed_dense_fixed_features(
+      state.handle, features, component=comp.name)
+  data = network_units.NamedTensor(data, 'dense-%s-data' % comp.name)
+  offsets = network_units.NamedTensor(offsets, 'dense-%s-offsets' % comp.name)
+  return state.handle, [data, offsets]
+
+
 def extract_fixed_feature_ids(comp, state, stride):
   """Extracts fixed feature IDs.

@@ -194,8 +236,10 @@ def update_network_states(comp, tensors, network_states, stride):
   with tf.name_scope(comp.name + '/stored_act'):
     for index, network_tensor in enumerate(tensors):
       network_state.activations[comp.network.layers[index].name] = (
           network_units.StoredActivations(
               tensor=network_tensor,
               stride=stride,
               dim=comp.network.layers[index].dim))

@@ -205,7 +249,7 @@ def build_cross_entropy_loss(logits, gold):

   Args:
     logits: float Tensor of scores.
-    gold: int Tensor of one-hot labels.
+    gold: int Tensor of gold label ids.

   Returns:
     cost, correct, total: the total cost, the total number of correctly

@@ -251,9 +295,10 @@ class BulkFeatureExtractorComponentBuilder(component.ComponentBuilderBase):
     """
     logging.info('Building component: %s', self.spec.name)
     stride = state.current_batch_size * self.training_beam_size
+    self.network.pre_create(stride)
     with tf.variable_scope(self.name, reuse=True):
       state.handle, fixed_embeddings = fetch_differentiable_fixed_embeddings(
-          self, state, stride)
+          self, state, stride, True)
     linked_embeddings = [
         fetch_linked_embedding(self, network_states, spec)

@@ -307,14 +352,29 @@ class BulkFeatureExtractorComponentBuilder(component.ComponentBuilderBase):
       stride = state.current_batch_size * self.training_beam_size
     else:
       stride = state.current_batch_size * self.inference_beam_size
+    self.network.pre_create(stride)
     with tf.variable_scope(self.name, reuse=True):
       if during_training:
         state.handle, fixed_embeddings = fetch_differentiable_fixed_embeddings(
-            self, state, stride)
+            self, state, stride, during_training)
       else:
-        state.handle, fixed_embeddings = fetch_fast_fixed_embeddings(
-            self, state)
+        if 'use_densors' in self.spec.network_unit.parameters:
+          state.handle, fixed_embeddings = fetch_dense_ragged_embeddings(
+              self, state)
+        else:
+          if ('padded_batch_size' in self.spec.network_unit.parameters and
+              'padded_sentence_length' in self.spec.network_unit.parameters):
+            state.handle, fixed_embeddings = fetch_fast_fixed_embeddings(
+                self,
+                state,
+                pad_to_batch=-1,
+                pad_to_steps=int(self.spec.network_unit
+                                 .parameters['padded_sentence_length']))
+          else:
+            state.handle, fixed_embeddings = fetch_fast_fixed_embeddings(
+                self, state)
     linked_embeddings = [
         fetch_linked_embedding(self, network_states, spec)

@@ -331,6 +391,7 @@ class BulkFeatureExtractorComponentBuilder(component.ComponentBuilderBase):
           stride=stride)
     update_network_states(self, tensors, network_states, stride)
+    self._add_runtime_hooks()
     return state.handle

@@ -367,7 +428,9 @@ class BulkFeatureIdExtractorComponentBuilder(component.ComponentBuilderBase):
   def build_greedy_inference(self, state, network_states,
                              during_training=False):
     """See base class."""
-    return self._extract_feature_ids(state, network_states, during_training)
+    handle = self._extract_feature_ids(state, network_states, during_training)
+    self._add_runtime_hooks()
+    return handle

   def _extract_feature_ids(self, state, network_states, during_training):
     """Extracts feature IDs and advances a batch using the oracle path.

@@ -387,6 +450,7 @@ class BulkFeatureIdExtractorComponentBuilder(component.ComponentBuilderBase):
       stride = state.current_batch_size * self.training_beam_size
     else:
       stride = state.current_batch_size * self.inference_beam_size
+    self.network.pre_create(stride)
     with tf.variable_scope(self.name, reuse=True):
       state.handle, ids = extract_fixed_feature_ids(self, state, stride)

@@ -438,17 +502,21 @@ class BulkAnnotatorComponentBuilder(component.ComponentBuilderBase):
     ]
     stride = state.current_batch_size * self.training_beam_size
+    self.network.pre_create(stride)
     with tf.variable_scope(self.name, reuse=True):
       network_tensors = self.network.create([], linked_embeddings, None, None,
                                             True, stride)
       update_network_states(self, network_tensors, network_states, stride)
-      logits = self.network.get_logits(network_tensors)
       state.handle, gold = dragnn_ops.bulk_advance_from_oracle(
           state.handle, component=self.name)
-      cost, correct, total = build_cross_entropy_loss(logits, gold)
+      cost, correct, total = self.network.compute_bulk_loss(
+          stride, network_tensors, gold)
+      if cost is None:
+        # The network does not have a custom bulk loss; default to softmax.
+        logits = self.network.get_logits(network_tensors)
+        cost, correct, total = build_cross_entropy_loss(logits, gold)
     cost = self.add_regularizer(cost)
     return state.handle, cost, correct, total

@@ -483,13 +551,24 @@ class BulkAnnotatorComponentBuilder(component.ComponentBuilderBase):
       stride = state.current_batch_size * self.training_beam_size
     else:
       stride = state.current_batch_size * self.inference_beam_size
+    self.network.pre_create(stride)
     with tf.variable_scope(self.name, reuse=True):
       network_tensors = self.network.create(
           [], linked_embeddings, None, None, during_training, stride)
       update_network_states(self, network_tensors, network_states, stride)
-      logits = self.network.get_logits(network_tensors)
-    return dragnn_ops.bulk_advance_from_prediction(
+      logits = self.network.get_bulk_predictions(stride, network_tensors)
+      if logits is None:
+        # The network does not produce custom bulk predictions; default to
+        # logits.
+        logits = self.network.get_logits(network_tensors)
+        logits = tf.cond(self.locally_normalize,
+                         lambda: tf.nn.log_softmax(logits), lambda: logits)
+        if self._output_as_probabilities:
+          logits = tf.nn.softmax(logits)
+    handle = dragnn_ops.bulk_advance_from_prediction(
         state.handle, logits, component=self.name)
+    self._add_runtime_hooks()
+    return handle
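The docstring correction in build_cross_entropy_loss is worth dwelling on: gold is a vector of integer label ids, not a one-hot matrix, and ids of -1 mark finished or padded positions that must be masked out before the loss. A minimal TF sketch of that contract, with toy sizes assumed:

import tensorflow as tf

num_steps, num_actions = 4, 3
logits = tf.random_normal([num_steps, num_actions])  # float scores
gold = tf.constant([2, 0, -1, 1])  # int label ids; -1 marks an invalid step

# Keep only entries with gold > -1, as the loss builders in this change do.
valid_ix = tf.reshape(tf.where(tf.greater(gold, -1)), [-1])
cost = tf.reduce_sum(
    tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tf.gather(gold, valid_ix),
        logits=tf.gather(logits, valid_ix)))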
research/syntaxnet/dragnn/python/bulk_component_test.py

@@ -41,8 +41,6 @@ from dragnn.python import dragnn_ops
 from dragnn.python import network_units
 from syntaxnet import sentence_pb2

-FLAGS = tf.app.flags.FLAGS
-

 class MockNetworkUnit(object):

@@ -63,6 +61,7 @@ class MockMaster(object):
     self.spec = spec_pb2.MasterSpec()
     self.hyperparams = spec_pb2.GridPoint()
     self.lookup_component = {'mock': MockComponent()}
+    self.build_runtime_graph = False


 def _create_fake_corpus():

@@ -84,9 +83,12 @@ def _create_fake_corpus():
 class BulkComponentTest(test_util.TensorFlowTestCase):

   def setUp(self):
+    # Clear the graph and all existing variables.  Otherwise, variables
+    # created in different tests may collide with each other.
+    tf.reset_default_graph()
+
     self.master = MockMaster()
     self.master_state = component.MasterState(
-        handle='handle', current_batch_size=2)
+        handle=tf.constant(['foo', 'bar']), current_batch_size=2)
     self.network_states = {
         'mock': component.NetworkState(),
         'test': component.NetworkState(),

@@ -107,22 +109,21 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
     """, component_spec)

     # For feature extraction:
-    with tf.Graph().as_default():
-      comp = bulk_component.BulkFeatureExtractorComponentBuilder(
-          self.master, component_spec)
+    comp = bulk_component.BulkFeatureExtractorComponentBuilder(
+        self.master, component_spec)

     # Expect feature extraction to generate an error due to the "history"
     # translator.
     with self.assertRaises(NotImplementedError):
       comp.build_greedy_training(self.master_state, self.network_states)

     # As well as annotation:
-    with tf.Graph().as_default():
-      comp = bulk_component.BulkAnnotatorComponentBuilder(
-          self.master, component_spec)
+    self.setUp()
+    comp = bulk_component.BulkAnnotatorComponentBuilder(
+        self.master, component_spec)

     with self.assertRaises(NotImplementedError):
       comp.build_greedy_training(self.master_state, self.network_states)

  def testFailsOnRecurrentLinkedFeature(self):
    component_spec = spec_pb2.ComponentSpec()

@@ -143,22 +144,21 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
     """, component_spec)

     # For feature extraction:
-    with tf.Graph().as_default():
-      comp = bulk_component.BulkFeatureExtractorComponentBuilder(
-          self.master, component_spec)
+    comp = bulk_component.BulkFeatureExtractorComponentBuilder(
+        self.master, component_spec)

     # Expect feature extraction to generate an error due to the "history"
     # translator.
     with self.assertRaises(RuntimeError):
       comp.build_greedy_training(self.master_state, self.network_states)

     # As well as annotation:
-    with tf.Graph().as_default():
-      comp = bulk_component.BulkAnnotatorComponentBuilder(
-          self.master, component_spec)
+    self.setUp()
+    comp = bulk_component.BulkAnnotatorComponentBuilder(
+        self.master, component_spec)

     with self.assertRaises(RuntimeError):
       comp.build_greedy_training(self.master_state, self.network_states)

  def testConstantFixedFeatureFailsIfNotPretrained(self):
    component_spec = spec_pb2.ComponentSpec()

@@ -175,21 +175,20 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
       registered_name: "bulk_component.BulkFeatureExtractorComponentBuilder"
     }
     """, component_spec)
-    with tf.Graph().as_default():
-      comp = bulk_component.BulkFeatureExtractorComponentBuilder(
-          self.master, component_spec)
+    comp = bulk_component.BulkFeatureExtractorComponentBuilder(
+        self.master, component_spec)
     with self.assertRaisesRegexp(ValueError,
                                  'Constant embeddings must be pretrained'):
       comp.build_greedy_training(self.master_state, self.network_states)
     with self.assertRaisesRegexp(ValueError,
                                  'Constant embeddings must be pretrained'):
       comp.build_greedy_inference(
           self.master_state, self.network_states, during_training=True)
     with self.assertRaisesRegexp(ValueError,
                                  'Constant embeddings must be pretrained'):
       comp.build_greedy_inference(
           self.master_state, self.network_states, during_training=False)

  def testNormalFixedFeaturesAreDifferentiable(self):
    component_spec = spec_pb2.ComponentSpec()

@@ -207,25 +206,24 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
       registered_name: "bulk_component.BulkFeatureExtractorComponentBuilder"
     }
     """, component_spec)
-    with tf.Graph().as_default():
-      comp = bulk_component.BulkFeatureExtractorComponentBuilder(
-          self.master, component_spec)
+    comp = bulk_component.BulkFeatureExtractorComponentBuilder(
+        self.master, component_spec)

     # Get embedding matrix variables.
     with tf.variable_scope(comp.name, reuse=True):
       fixed_embedding_matrix = tf.get_variable(
           network_units.fixed_embeddings_name(0))

     # Get output layer.
     comp.build_greedy_training(self.master_state, self.network_states)
     activations = self.network_states[comp.name].activations
     outputs = activations[comp.network.layers[0].name].bulk_tensor

     # Compute the gradient of the output layer w.r.t. the embedding matrix.
     # This should be well-defined in the normal case.
     gradients = tf.gradients(outputs, fixed_embedding_matrix)
     self.assertEqual(len(gradients), 1)
     self.assertFalse(gradients[0] is None)

  def testConstantFixedFeaturesAreNotDifferentiableButOthersAre(self):
    component_spec = spec_pb2.ComponentSpec()

@@ -249,31 +247,30 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
       registered_name: "bulk_component.BulkFeatureExtractorComponentBuilder"
     }
     """, component_spec)
-    with tf.Graph().as_default():
-      comp = bulk_component.BulkFeatureExtractorComponentBuilder(
-          self.master, component_spec)
+    comp = bulk_component.BulkFeatureExtractorComponentBuilder(
+        self.master, component_spec)

     # Get embedding matrix variables.
     with tf.variable_scope(comp.name, reuse=True):
       constant_embedding_matrix = tf.get_variable(
           network_units.fixed_embeddings_name(0))
       trainable_embedding_matrix = tf.get_variable(
           network_units.fixed_embeddings_name(1))

     # Get output layer.
     comp.build_greedy_training(self.master_state, self.network_states)
     activations = self.network_states[comp.name].activations
     outputs = activations[comp.network.layers[0].name].bulk_tensor

     # The constant embeddings are non-differentiable.
     constant_gradients = tf.gradients(outputs, constant_embedding_matrix)
     self.assertEqual(len(constant_gradients), 1)
     self.assertTrue(constant_gradients[0] is None)

     # The trainable embeddings are differentiable.
     trainable_gradients = tf.gradients(outputs, trainable_embedding_matrix)
     self.assertEqual(len(trainable_gradients), 1)
     self.assertFalse(trainable_gradients[0] is None)

  def testFailsOnFixedFeature(self):
    component_spec = spec_pb2.ComponentSpec()

@@ -306,15 +303,14 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
       name: "fixed" embedding_dim: -1 size: 1
     }
     """, component_spec)
-    with tf.Graph().as_default():
-      comp = bulk_component.BulkFeatureIdExtractorComponentBuilder(
-          self.master, component_spec)
+    comp = bulk_component.BulkFeatureIdExtractorComponentBuilder(
+        self.master, component_spec)
     # Should not raise errors.
     self.network_states[component_spec.name] = component.NetworkState()
     comp.build_greedy_training(self.master_state, self.network_states)
     self.network_states[component_spec.name] = component.NetworkState()
     comp.build_greedy_inference(self.master_state, self.network_states)

  def testBulkFeatureIdExtractorFailsOnLinkedFeature(self):
    component_spec = spec_pb2.ComponentSpec()

@@ -332,10 +328,9 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
       source_component: "mock"
     }
     """, component_spec)
-    with tf.Graph().as_default():
-      with self.assertRaises(ValueError):
-        unused_comp = bulk_component.BulkFeatureIdExtractorComponentBuilder(
-            self.master, component_spec)
+    with self.assertRaises(ValueError):
+      unused_comp = bulk_component.BulkFeatureIdExtractorComponentBuilder(
+          self.master, component_spec)

  def testBulkFeatureIdExtractorOkWithMultipleFixedFeatures(self):
    component_spec = spec_pb2.ComponentSpec()

@@ -354,15 +349,14 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
       name: "fixed3" embedding_dim: -1 size: 1
     }
     """, component_spec)
-    with tf.Graph().as_default():
-      comp = bulk_component.BulkFeatureIdExtractorComponentBuilder(
-          self.master, component_spec)
+    comp = bulk_component.BulkFeatureIdExtractorComponentBuilder(
+        self.master, component_spec)
     # Should not raise errors.
     self.network_states[component_spec.name] = component.NetworkState()
     comp.build_greedy_training(self.master_state, self.network_states)
     self.network_states[component_spec.name] = component.NetworkState()
     comp.build_greedy_inference(self.master_state, self.network_states)

  def testBulkFeatureIdExtractorFailsOnEmbeddedFixedFeature(self):
    component_spec = spec_pb2.ComponentSpec()

@@ -375,10 +369,9 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
       name: "fixed" embedding_dim: 2 size: 1
     }
     """, component_spec)
-    with tf.Graph().as_default():
-      with self.assertRaises(ValueError):
-        unused_comp = bulk_component.BulkFeatureIdExtractorComponentBuilder(
-            self.master, component_spec)
+    with self.assertRaises(ValueError):
+      unused_comp = bulk_component.BulkFeatureIdExtractorComponentBuilder(
+          self.master, component_spec)

  def testBulkFeatureIdExtractorExtractFocusWithOffset(self):
    path = os.path.join(tf.test.get_temp_dir(), 'label-map')

@@ -420,67 +413,131 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
     }
     """ % path, master_spec)

-    with tf.Graph().as_default():
-      corpus = _create_fake_corpus()
+    corpus = _create_fake_corpus()
     corpus = tf.constant(corpus, shape=[len(corpus)])
     handle = dragnn_ops.get_session(
         container='test',
         master_spec=master_spec.SerializeToString(),
         grid_point='')
     handle = dragnn_ops.attach_data_reader(handle, corpus)
     handle = dragnn_ops.init_component_data(
         handle, beam_size=1, component='test')
     batch_size = dragnn_ops.batch_size(handle, component='test')
     master_state = component.MasterState(handle, batch_size)

     extractor = bulk_component.BulkFeatureIdExtractorComponentBuilder(
         self.master, master_spec.component[0])
     network_state = component.NetworkState()
     self.network_states['test'] = network_state
     handle = extractor.build_greedy_inference(master_state,
                                               self.network_states)
     focus1 = network_state.activations['focus1'].bulk_tensor
     focus2 = network_state.activations['focus2'].bulk_tensor
     focus3 = network_state.activations['focus3'].bulk_tensor

     with self.test_session() as sess:
       focus1, focus2, focus3 = sess.run([focus1, focus2, focus3])
       tf.logging.info('focus1=\n%s', focus1)
       tf.logging.info('focus2=\n%s', focus2)
       tf.logging.info('focus3=\n%s', focus3)

       self.assertAllEqual(focus1,
                           [[0], [-1], [-1], [-1],
                            [0], [1], [-1], [-1],
                            [0], [1], [2], [-1],
                            [0], [1], [2], [3]])  # pyformat: disable

       self.assertAllEqual(focus2,
                           [[-1], [-1], [-1], [-1],
                            [1], [-1], [-1], [-1],
                            [1], [2], [-1], [-1],
                            [1], [2], [3], [-1]])  # pyformat: disable

       self.assertAllEqual(focus3,
                           [[-1], [-1], [-1], [-1],
                            [-1], [-1], [-1], [-1],
                            [2], [-1], [-1], [-1],
                            [2], [3], [-1], [-1]])  # pyformat: disable

   def testBuildLossFailsOnNoExamples(self):
-    with tf.Graph().as_default():
-      logits = tf.constant([[0.5], [-0.5], [0.5], [-0.5]])
+    logits = tf.constant([[0.5], [-0.5], [0.5], [-0.5]])
     gold = tf.constant([-1, -1, -1, -1])
     result = bulk_component.build_cross_entropy_loss(logits, gold)

     # Expect loss computation to generate a runtime error due to the gold
     # tensor containing no valid examples.
     with self.test_session() as sess:
       with self.assertRaises(tf.errors.InvalidArgumentError):
         sess.run(result)

+  def testPreCreateCalledBeforeCreate(self):
+    component_spec = spec_pb2.ComponentSpec()
+    text_format.Parse("""
+        name: "test"
+        network_unit {
+          registered_name: "IdentityNetwork"
+        }
+        """, component_spec)
+
+    class AssertPreCreateBeforeCreateNetwork(
+        network_units.NetworkUnitInterface):
+      """Mock that asserts that .pre_create() is called before .create()."""
+
+      def __init__(self, comp, test_fixture):
+        super(AssertPreCreateBeforeCreateNetwork, self).__init__(comp)
+        self._test_fixture = test_fixture
+        self._pre_create_called = False
+
+      def get_logits(self, network_tensors):
+        return tf.zeros([2, 1], dtype=tf.float32)
+
+      def pre_create(self, *unused_args):
+        self._pre_create_called = True
+
+      def create(self, *unused_args, **unused_kwargs):
+        self._test_fixture.assertTrue(self._pre_create_called)
+        return []
+
+    builder = bulk_component.BulkFeatureExtractorComponentBuilder(
+        self.master, component_spec)
+    builder.network = AssertPreCreateBeforeCreateNetwork(builder, self)
+    builder.build_greedy_training(
+        component.MasterState(['foo', 'bar'], 2), self.network_states)
+
+    self.setUp()
+    builder = bulk_component.BulkFeatureExtractorComponentBuilder(
+        self.master, component_spec)
+    builder.network = AssertPreCreateBeforeCreateNetwork(builder, self)
+    builder.build_greedy_inference(
+        component.MasterState(['foo', 'bar'], 2), self.network_states)
+
+    self.setUp()
+    builder = bulk_component.BulkFeatureIdExtractorComponentBuilder(
+        self.master, component_spec)
+    builder.network = AssertPreCreateBeforeCreateNetwork(builder, self)
+    builder.build_greedy_training(
+        component.MasterState(['foo', 'bar'], 2), self.network_states)
+
+    self.setUp()
+    builder = bulk_component.BulkFeatureIdExtractorComponentBuilder(
+        self.master, component_spec)
+    builder.network = AssertPreCreateBeforeCreateNetwork(builder, self)
+    builder.build_greedy_inference(
+        component.MasterState(['foo', 'bar'], 2), self.network_states)
+
+    self.setUp()
+    builder = bulk_component.BulkAnnotatorComponentBuilder(
+        self.master, component_spec)
+    builder.network = AssertPreCreateBeforeCreateNetwork(builder, self)
+    builder.build_greedy_training(
+        component.MasterState(['foo', 'bar'], 2), self.network_states)
+
+    self.setUp()
+    builder = bulk_component.BulkAnnotatorComponentBuilder(
+        self.master, component_spec)
+    builder.network = AssertPreCreateBeforeCreateNetwork(builder, self)
+    builder.build_greedy_inference(
+        component.MasterState(['foo', 'bar'], 2), self.network_states)


 if __name__ == '__main__':
   googletest.main()
research/syntaxnet/dragnn/python/component.py

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
 """Builds a DRAGNN graph for local training."""

 from abc import ABCMeta

@@ -21,12 +20,79 @@ from abc import abstractmethod

 import tensorflow as tf
 from tensorflow.python.platform import tf_logging as logging

+from dragnn.protos import export_pb2
 from dragnn.python import dragnn_ops
 from dragnn.python import network_units
+from dragnn.python import runtime_support
 from syntaxnet.util import check
 from syntaxnet.util import registry


+def build_softmax_cross_entropy_loss(logits, gold):
+  """Builds softmax cross entropy loss."""
+  # A gold label > -1 determines that the sentence is still
+  # in a valid state.  Otherwise, the sentence has ended.
+  #
+  # We add only the valid sentences to the loss, in the following way:
+  #   1. We compute 'valid_ix', the indices in gold that contain
+  #      valid oracle actions.
+  #   2. We compute the cost function by comparing logits and gold
+  #      only for the valid indices.
+  valid = tf.greater(gold, -1)
+  valid_ix = tf.reshape(tf.where(valid), [-1])
+  valid_gold = tf.gather(gold, valid_ix)
+  valid_logits = tf.gather(logits, valid_ix)
+  cost = tf.reduce_sum(
+      tf.nn.sparse_softmax_cross_entropy_with_logits(
+          labels=tf.cast(valid_gold, tf.int64),
+          logits=valid_logits,
+          name='sparse_softmax_cross_entropy_with_logits'))
+  correct = tf.reduce_sum(
+      tf.to_int32(tf.nn.in_top_k(valid_logits, valid_gold, 1)))
+  total = tf.size(valid_gold)
+  return cost, correct, total, valid_logits, valid_gold
+
+
+def build_sigmoid_cross_entropy_loss(logits, gold, indices, probs):
+  """Builds sigmoid cross entropy loss."""
+  # Filter out entries where gold <= -1, which are batch padding entries.
+  valid = tf.greater(gold, -1)
+  valid_ix = tf.reshape(tf.where(valid), [-1])
+  valid_gold = tf.gather(gold, valid_ix)
+  valid_indices = tf.gather(indices, valid_ix)
+  valid_probs = tf.gather(probs, valid_ix)
+
+  # NB: tf.gather_nd() raises an error on CPU for out-of-bounds indices.
+  # That's why we need to filter out the gold=-1 batch padding above.
+  valid_pairs = tf.stack([valid_indices, valid_gold], axis=1)
+  valid_logits = tf.gather_nd(logits, valid_pairs)
+
+  cost = tf.reduce_sum(
+      tf.nn.sigmoid_cross_entropy_with_logits(
+          labels=valid_probs,
+          logits=valid_logits,
+          name='sigmoid_cross_entropy_with_logits'))
+
+  gold_bool = valid_probs > 0.5
+  predicted_bool = valid_logits > 0.0
+  total = tf.size(gold_bool)
+  with tf.control_dependencies([
+      tf.assert_equal(
+          total, tf.size(predicted_bool), name='num_predicted_gold_mismatch')
+  ]):
+    agreement_bool = tf.logical_not(tf.logical_xor(gold_bool, predicted_bool))
+  correct = tf.reduce_sum(tf.to_int32(agreement_bool))
+  cost.set_shape([])
+  correct.set_shape([])
+  total.set_shape([])
+  return cost, correct, total, gold
class
NetworkState
(
object
):
"""Simple utility to manage the state of a DRAGNN network.
...
...
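As a quick orientation for reviewers, here is a minimal, self-contained sketch of how the new softmax helper treats padded steps (toy values; it mirrors the gold > -1 masking described in the comments above, not the DRAGNN pipeline itself):

import tensorflow as tf

# Toy batch: three steps, one of which (gold == -1) is padding and must be
# excluded from cost/correct/total. Values are arbitrary.
logits = tf.constant([[0.0, 2.0], [1.0, -1.0], [3.0, 0.0]])
gold = tf.constant([1, -1, 0])

valid_ix = tf.reshape(tf.where(tf.greater(gold, -1)), [-1])  # -> [0, 2]
valid_gold = tf.gather(gold, valid_ix)
valid_logits = tf.gather(logits, valid_ix)
cost = tf.reduce_sum(
    tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tf.cast(valid_gold, tf.int64), logits=valid_logits))

with tf.Session() as sess:
  print(sess.run([valid_ix, cost]))  # the padding step contributes nothing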
@@ -69,6 +135,13 @@ class ComponentBuilderBase(object):
   As part of the specification, ComponentBuilder will wrap an underlying
   NetworkUnit which generates the actual network layout.
+
+  Attributes:
+    master: dragnn.MasterBuilder that owns this component.
+    num_actions: Number of actions in the transition system.
+    name: Name of this component.
+    spec: dragnn.ComponentSpec that configures this component.
+    moving_average: True if moving-average parameters should be used.
   """

   __metaclass__ = ABCMeta  # required for @abstractmethod
...
@@ -96,16 +169,23 @@ class ComponentBuilderBase(object):
     # Extract component attributes before make_network(), so the network unit
     # can access them.
-    self._attrs = {}
+    global_attr_defaults = {
+        'locally_normalize': False,
+        'output_as_probabilities': False
+    }
     if attr_defaults:
-      self._attrs = network_units.get_attrs_with_defaults(
-          self.spec.component_builder.parameters, attr_defaults)
+      global_attr_defaults.update(attr_defaults)
+    self._attrs = network_units.get_attrs_with_defaults(
+        self.spec.component_builder.parameters, global_attr_defaults)
+
+    do_local_norm = self._attrs['locally_normalize']
+    self._output_as_probabilities = self._attrs['output_as_probabilities']

     with tf.variable_scope(self.name):
       self.training_beam_size = tf.constant(
           self.spec.training_beam_size, name='TrainingBeamSize')
       self.inference_beam_size = tf.constant(
           self.spec.inference_beam_size, name='InferenceBeamSize')
-      self.locally_normalize = tf.constant(False, name='LocallyNormalize')
+      self.locally_normalize = tf.constant(do_local_norm,
+                                           name='LocallyNormalize')
       self._step = tf.get_variable(
           'step', [], initializer=tf.zeros_initializer(), dtype=tf.int32)
       self._total = tf.get_variable(
...
@@ -120,6 +200,9 @@ class ComponentBuilderBase(object):
           decay=self.master.hyperparams.average_weight, num_updates=self._step)
       self.avg_ops = [self.moving_average.apply(self.network.params)]

+    # Used to export the cell; see add_cell_input() and add_cell_output().
+    self._cell_subgraph_spec = export_pb2.CellSubgraphSpec()
+
   def make_network(self, network_unit):
     """Makes a NetworkUnitInterface object based on the network_unit spec.
...
@@ -276,7 +359,7 @@ class ComponentBuilderBase(object):
     Returns:
       tf.Variable object corresponding to original or averaged version.
     """
-    if var_params:
+    if var_params is not None:
       var_name = var_params.name
     else:
       check.NotNone(var_name, 'specify at least one of var_name or var_params')
...
@@ -341,6 +424,79 @@ class ComponentBuilderBase(object):
     """Returns the value of the component attribute with the |name|."""
     return self._attrs[name]

+  def has_attr(self, name):
+    """Checks whether a component attribute with the given |name| exists.
+
+    Arguments:
+      name: attribute name
+
+    Returns:
+      'true' if the name exists and 'false' otherwise.
+    """
+    return name in self._attrs
+
+  def _add_runtime_hooks(self):
+    """Adds "hook" nodes to the graph for use by the runtime, if enabled.
+
+    Does nothing if master.build_runtime_graph is False. Subclasses should call
+    this at the end of build_*_inference(). For details on the runtime hooks,
+    see runtime_support.py.
+    """
+    if self.master.build_runtime_graph:
+      with tf.variable_scope(self.name, reuse=True):
+        runtime_support.add_hooks(self, self._cell_subgraph_spec)
+    self._cell_subgraph_spec = None  # prevent further exports
+
+  def add_cell_input(self, dtype, shape, name, input_type='TYPE_FEATURE'):
+    """Adds an input to the current CellSubgraphSpec.
+
+    Creates a tf.placeholder() with the given |dtype| and |shape|, adds it as a
+    cell input with the |name| and |input_type|, and returns the placeholder to
+    be used in place of the actual input tensor.
+
+    Args:
+      dtype: DType of the cell input.
+      shape: Static shape of the cell input.
+      name: Logical name of the cell input.
+      input_type: Name of the appropriate CellSubgraphSpec.Input.Type enum.
+
+    Returns:
+      A tensor to use in place of the actual input tensor.
+
+    Raises:
+      TypeError: If the |shape| is the wrong type.
+      RuntimeError: If the cell has already been exported.
+    """
+    if not (isinstance(shape, list) and
+            all(isinstance(dim, int) for dim in shape)):
+      raise TypeError('shape must be a list of int')
+    if not self._cell_subgraph_spec:
+      raise RuntimeError('already exported a CellSubgraphSpec')
+    with tf.name_scope(None):
+      tensor = tf.placeholder(
+          dtype, shape, name='{}/INPUT/{}'.format(self.name, name))
+    self._cell_subgraph_spec.input.add(
+        name=name,
+        tensor=tensor.name,
+        type=export_pb2.CellSubgraphSpec.Input.Type.Value(input_type))
+    return tensor
+
+  def add_cell_output(self, tensor, name):
+    """Adds an output to the current CellSubgraphSpec.
+
+    Args:
+      tensor: Tensor to add as a cell output.
+      name: Logical name of the cell output.
+
+    Raises:
+      RuntimeError: If the cell has already been exported.
+    """
+    if not self._cell_subgraph_spec:
+      raise RuntimeError('already exported a CellSubgraphSpec')
+    self._cell_subgraph_spec.output.add(name=name, tensor=tensor.name)
+

 def update_tensor_arrays(network_tensors, arrays):
   """Updates a list of tensor arrays from the network's output tensors.
...
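The cell-export protocol above is easiest to see end to end. A simplified, dependency-free sketch of the same bookkeeping (plain dicts stand in for the CellSubgraphSpec proto; all names are illustrative):

import tensorflow as tf

# Stand-in for CellSubgraphSpec: record (name, tensor) for inputs and outputs,
# analogously to what add_cell_input()/add_cell_output() do.
cell_spec = {'input': [], 'output': []}

def add_cell_input(dtype, shape, name, component='tagger'):
  # The placeholder replaces the real input tensor inside the extracted cell.
  tensor = tf.placeholder(
      dtype, shape, name='{}/INPUT/{}'.format(component, name))
  cell_spec['input'].append({'name': name, 'tensor': tensor.name})
  return tensor

def add_cell_output(tensor, name):
  cell_spec['output'].append({'name': name, 'tensor': tensor.name})

# Build a tiny "cell": one feature-ID input feeding a fixed embedding lookup.
ids = add_cell_input(tf.int32, [1], 'fixed_channel_0_index_0_ids')
embeddings = tf.get_variable('embeddings', [100, 8])
hidden = tf.nn.embedding_lookup(embeddings, ids)
add_cell_output(hidden, 'hidden')

print(cell_spec)  # the runtime would consume this mapping to wire the cell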
@@ -370,6 +526,18 @@ class DynamicComponentBuilder(ComponentBuilderBase):
   so fixed and linked features can be recurrent.
   """

+  def __init__(self, master, component_spec):
+    """Initializes the DynamicComponentBuilder from specifications.
+
+    Args:
+      master: dragnn.MasterBuilder object.
+      component_spec: dragnn.ComponentSpec proto to be built.
+    """
+    super(DynamicComponentBuilder, self).__init__(
+        master,
+        component_spec,
+        attr_defaults={'loss_function': 'softmax_cross_entropy'})
+
   def build_greedy_training(self, state, network_states):
     """Builds a training loop for this component.
...
@@ -392,9 +560,10 @@ class DynamicComponentBuilder(ComponentBuilderBase):
     # Add 0 to training_beam_size to disable eager static evaluation.
     # This is possible because tensorflow's constant_value does not
     # propagate arithmetic operations.
     with tf.control_dependencies(
         [tf.assert_equal(self.training_beam_size + 0, 1)]):
       stride = state.current_batch_size * self.training_beam_size
+      self.network.pre_create(stride)
       cost = tf.constant(0.)
       correct = tf.constant(0)
...
@@ -416,40 +585,35 @@ class DynamicComponentBuilder(ComponentBuilderBase):
       # Every layer is written to a TensorArray, so that it can be backprop'd.
       next_arrays = update_tensor_arrays(network_tensors, arrays)
+      loss_function = self.attr('loss_function')
       with tf.control_dependencies([x.flow for x in next_arrays]):
         with tf.name_scope('compute_loss'):
-          # A gold label > -1 determines that the sentence is still
-          # in a valid state. Otherwise, the sentence has ended.
-          #
-          # We add only the valid sentences to the loss, in the following way:
-          #   1. We compute 'valid_ix', the indices in gold that contain
-          #      valid oracle actions.
-          #   2. We compute the cost function by comparing logits and gold
-          #      only for the valid indices.
-          gold = dragnn_ops.emit_oracle_labels(handle, component=self.name)
-          gold.set_shape([None])
-          valid = tf.greater(gold, -1)
-          valid_ix = tf.reshape(tf.where(valid), [-1])
-          gold = tf.gather(gold, valid_ix)
           logits = self.network.get_logits(network_tensors)
-          logits = tf.gather(logits, valid_ix)
-          cost += tf.reduce_sum(
-              tf.nn.sparse_softmax_cross_entropy_with_logits(
-                  labels=tf.cast(gold, tf.int64), logits=logits))
-          if (self.eligible_for_self_norm and
-              self.master.hyperparams.self_norm_alpha > 0):
-            log_z = tf.reduce_logsumexp(logits, [1])
-            cost += (self.master.hyperparams.self_norm_alpha *
-                     tf.nn.l2_loss(log_z))
-          correct += tf.reduce_sum(tf.to_int32(tf.nn.in_top_k(logits, gold, 1)))
-          total += tf.size(gold)
-      with tf.control_dependencies([cost, correct, total, gold]):
+          if loss_function == 'softmax_cross_entropy':
+            gold = dragnn_ops.emit_oracle_labels(handle, component=self.name)
+            new_cost, new_correct, new_total, valid_logits, valid_gold = (
+                build_softmax_cross_entropy_loss(logits, gold))
+            if (self.eligible_for_self_norm and
+                self.master.hyperparams.self_norm_alpha > 0):
+              log_z = tf.reduce_logsumexp(valid_logits, [1])
+              new_cost += (self.master.hyperparams.self_norm_alpha *
+                           tf.nn.l2_loss(log_z))
+          elif loss_function == 'sigmoid_cross_entropy':
+            indices, gold, probs = (
+                dragnn_ops.emit_oracle_labels_and_probabilities(
+                    handle, component=self.name))
+            new_cost, new_correct, new_total, valid_gold = (
+                build_sigmoid_cross_entropy_loss(logits, gold, indices, probs))
+          else:
+            RuntimeError("Unknown loss function '%s'" % loss_function)
+
+          cost += new_cost
+          correct += new_correct
+          total += new_total
+      with tf.control_dependencies([cost, correct, total, valid_gold]):
         handle = dragnn_ops.advance_from_oracle(handle, component=self.name)
       return [handle, cost, correct, total] + next_arrays
...
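A reviewer-level note on the hunk above: the fallback branch constructs RuntimeError("Unknown loss function ...") but never raises it, so a misspelled loss_function attribute would surface later as an undefined new_cost (a NameError) rather than as the intended error. With the registered default of 'softmax_cross_entropy', the branch is only reachable when a spec explicitly overrides the attribute.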
@@ -480,6 +644,7 @@ class DynamicComponentBuilder(ComponentBuilderBase):
       # Normalize the objective by the total # of steps taken.
       # Note: Total could be zero by a number of reasons, including:
       #  * Oracle labels not being emitted.
+      #  * All oracle labels for a batch are unknown (-1).
       #  * No steps being taken if component is terminal at the start of a batch.
       with tf.control_dependencies([tf.assert_greater(total, 0)]):
         cost /= tf.to_float(total)
...
@@ -511,6 +676,7 @@ class DynamicComponentBuilder(ComponentBuilderBase):
       stride = state.current_batch_size * self.training_beam_size
     else:
       stride = state.current_batch_size * self.inference_beam_size
+    self.network.pre_create(stride)

     def cond(handle, *_):
       all_final = dragnn_ops.emit_all_final(handle, component=self.name)
...
@@ -559,6 +725,7 @@ class DynamicComponentBuilder(ComponentBuilderBase):
     for index, layer in enumerate(self.network.layers):
       network_state.activations[layer.name] = network_units.StoredActivations(
           array=arrays[index])
+    self._add_runtime_hooks()

     with tf.control_dependencies([x.flow for x in arrays]):
       return tf.identity(state.handle)
...
@@ -587,7 +754,7 @@ class DynamicComponentBuilder(ComponentBuilderBase):
     fixed_embeddings = []
     for channel_id, feature_spec in enumerate(self.spec.fixed_feature):
       fixed_embedding = network_units.fixed_feature_lookup(
-          self, state, channel_id, stride)
+          self, state, channel_id, stride, during_training)
       if feature_spec.is_constant:
         fixed_embedding.tensor = tf.stop_gradient(fixed_embedding.tensor)
       fixed_embeddings.append(fixed_embedding)
...
@@ -633,6 +800,12 @@ class DynamicComponentBuilder(ComponentBuilderBase):
     else:
       attention_tensor = None

-    return self.network.create(fixed_embeddings, linked_embeddings,
-                               context_tensor_arrays, attention_tensor,
-                               during_training)
+    tensors = self.network.create(fixed_embeddings, linked_embeddings,
+                                  context_tensor_arrays, attention_tensor,
+                                  during_training)
+    if self.master.build_runtime_graph:
+      for index, layer in enumerate(self.network.layers):
+        self.add_cell_output(tensors[index], layer.name)
+    return tensors
research/syntaxnet/dragnn/python/component_test.py  0 → 100644  View file @ 80178fc6
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for component.py.
"""

import tensorflow as tf

from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest

from google.protobuf import text_format

from dragnn.protos import spec_pb2
from dragnn.python import component


class MockNetworkUnit(object):

  def get_layer_size(self, unused_layer_name):
    return 64


class MockComponent(object):

  def __init__(self):
    self.name = 'mock'
    self.network = MockNetworkUnit()


class MockMaster(object):

  def __init__(self):
    self.spec = spec_pb2.MasterSpec()
    self.hyperparams = spec_pb2.GridPoint()
    self.lookup_component = {'mock': MockComponent()}
    self.build_runtime_graph = False


class ComponentTest(test_util.TensorFlowTestCase):

  def setUp(self):
    # Clear the graph and all existing variables.  Otherwise, variables created
    # in different tests may collide with each other.
    tf.reset_default_graph()

    self.master = MockMaster()
    self.master_state = component.MasterState(
        handle=tf.constant(['foo', 'bar']), current_batch_size=2)
    self.network_states = {
        'mock': component.NetworkState(),
        'test': component.NetworkState(),
    }

  def testSoftmaxCrossEntropyLoss(self):
    logits = tf.constant([[0.0, 2.0, -1.0],
                          [-5.0, 1.0, -1.0],
                          [3.0, 1.0, -2.0]])  # pyformat: disable
    gold_labels = tf.constant([1, -1, 1])

    cost, correct, total, logits, gold_labels = (
        component.build_softmax_cross_entropy_loss(logits, gold_labels))

    with self.test_session() as sess:
      cost, correct, total, logits, gold_labels = (
          sess.run([cost, correct, total, logits, gold_labels]))

      # Cost = -2 + ln(1 + exp(2) + exp(-1))
      #        -1 + ln(exp(3) + exp(1) + exp(-2))
      self.assertAlmostEqual(cost, 2.3027, 4)
      self.assertEqual(correct, 1)
      self.assertEqual(total, 2)

      # Entries corresponding to gold labels equal to -1 are skipped.
      self.assertAllEqual(logits, [[0.0, 2.0, -1.0], [3.0, 1.0, -2.0]])
      self.assertAllEqual(gold_labels, [1, 1])

  def testSigmoidCrossEntropyLoss(self):
    indices = tf.constant([0, 0, 1])
    gold_labels = tf.constant([0, 1, 2])
    probs = tf.constant([0.6, 0.7, 0.2])
    logits = tf.constant([[0.9, -0.3, 0.1], [-0.5, 0.4, 2.0]])

    cost, correct, total, gold_labels = (
        component.build_sigmoid_cross_entropy_loss(logits, gold_labels,
                                                   indices, probs))

    with self.test_session() as sess:
      cost, correct, total, gold_labels = (
          sess.run([cost, correct, total, gold_labels]))

      # The cost corresponding to the three entries is, respectively,
      # 0.7012, 0.7644, and 1.7269.  Each of them is computed using the formula
      # -prob_i * log(sigmoid(logit_i)) - (1-prob_i) * log(1-sigmoid(logit_i))
      self.assertAlmostEqual(cost, 3.1924, 4)
      self.assertEqual(correct, 1)
      self.assertEqual(total, 3)
      self.assertAllEqual(gold_labels, [0, 1, 2])

  def testGraphConstruction(self):
    component_spec = spec_pb2.ComponentSpec()
    text_format.Parse("""
        name: "test"
        network_unit {
          registered_name: "IdentityNetwork"
        }
        fixed_feature {
          name: "fixed" embedding_dim: 32 size: 1
        }
        component_builder {
          registered_name: "component.DynamicComponentBuilder"
        }
        """, component_spec)
    comp = component.DynamicComponentBuilder(self.master, component_spec)
    comp.build_greedy_training(self.master_state, self.network_states)

  def testGraphConstructionWithSigmoidLoss(self):
    component_spec = spec_pb2.ComponentSpec()
    text_format.Parse("""
        name: "test"
        network_unit {
          registered_name: "IdentityNetwork"
        }
        fixed_feature {
          name: "fixed" embedding_dim: 32 size: 1
        }
        component_builder {
          registered_name: "component.DynamicComponentBuilder"
          parameters {
            key: "loss_function"
            value: "sigmoid_cross_entropy"
          }
        }
        """, component_spec)
    comp = component.DynamicComponentBuilder(self.master, component_spec)
    comp.build_greedy_training(self.master_state, self.network_states)

    # Check that the loss op is present.
    op_names = [op.name for op in tf.get_default_graph().get_operations()]
    self.assertTrue('train_test/compute_loss/'
                    'sigmoid_cross_entropy_with_logits' in op_names)


if __name__ == '__main__':
  googletest.main()
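The expected values in testSigmoidCrossEntropyLoss can be reproduced by hand. A small sketch using only the standard library (the three selected logits are logits[0,0]=0.9, logits[0,1]=-0.3, and logits[1,2]=2.0, per the (index, gold) pairs):

import math

def sigmoid(x):
  return 1.0 / (1.0 + math.exp(-x))

# (logit, target probability) pairs selected by the (index, gold) coordinates
# (0,0), (0,1), (1,2) of the test above.
entries = [(0.9, 0.6), (-0.3, 0.7), (2.0, 0.2)]
costs = [-p * math.log(sigmoid(l)) - (1 - p) * math.log(1 - sigmoid(l))
         for l, p in entries]
print([round(c, 4) for c in costs])  # [0.7012, 0.7644, 1.7269]
print(round(sum(costs), 4))          # 3.1924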
research/syntaxnet/dragnn/python/digraph_ops.py  View file @ 80178fc6
...
@@ -15,6 +15,10 @@
 """TensorFlow ops for directed graphs."""

+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
 import tensorflow as tf

 from syntaxnet.util import check
...
@@ -150,7 +154,7 @@ def ArcSourcePotentialsFromTokens(tokens, weights):
   return sources_bxnxn


-def RootPotentialsFromTokens(root, tokens, weights):
+def RootPotentialsFromTokens(root, tokens, weights_arc, weights_source):
   r"""Returns root selection potentials computed from tokens and weights.

   For each batch of token activations, computes a scalar potential for each root
...
@@ -162,7 +166,8 @@ def RootPotentialsFromTokens(root, tokens, weights):
   Args:
     root: [S] vector of activations for the artificial root token.
     tokens: [B,N,T] tensor of batched activations for root tokens.
-    weights: [S,T] matrix of weights.
+    weights_arc: [S,T] matrix of weights.
+    weights_source: [S] vector of weights.

     B,N may be statically-unknown, but S,T must be statically-known.  The dtype
     of all arguments must be compatible.
...
@@ -174,25 +179,30 @@ def RootPotentialsFromTokens(root, tokens, weights):
   # All arguments must have statically-known rank.
   check.Eq(root.get_shape().ndims, 1, 'root must be a vector')
   check.Eq(tokens.get_shape().ndims, 3, 'tokens must be rank 3')
-  check.Eq(weights.get_shape().ndims, 2, 'weights must be a matrix')
+  check.Eq(weights_arc.get_shape().ndims, 2, 'weights_arc must be a matrix')
+  check.Eq(weights_source.get_shape().ndims, 1,
+           'weights_source must be a vector')

   # All activation dimensions must be statically-known.
-  num_source_activations = weights.get_shape().as_list()[0]
-  num_target_activations = weights.get_shape().as_list()[1]
+  num_source_activations = weights_arc.get_shape().as_list()[0]
+  num_target_activations = weights_arc.get_shape().as_list()[1]
   check.NotNone(num_source_activations, 'unknown source activation dimension')
   check.NotNone(num_target_activations, 'unknown target activation dimension')
   check.Eq(root.get_shape().as_list()[0], num_source_activations,
-           'dimension mismatch between weights and root')
+           'dimension mismatch between weights_arc and root')
   check.Eq(tokens.get_shape().as_list()[2], num_target_activations,
-           'dimension mismatch between weights and tokens')
+           'dimension mismatch between weights_arc and tokens')
+  check.Eq(weights_source.get_shape().as_list()[0], num_source_activations,
+           'dimension mismatch between weights_arc and weights_source')

   # All arguments must share the same type.
-  check.Same([weights.dtype.base_dtype,
-              root.dtype.base_dtype,
-              tokens.dtype.base_dtype],
-             'dtype mismatch')
+  check.Same([weights_arc.dtype.base_dtype,
+              weights_source.dtype.base_dtype,
+              root.dtype.base_dtype,
+              tokens.dtype.base_dtype],
+             'dtype mismatch')

   root_1xs = tf.expand_dims(root, 0)
+  weights_source_sx1 = tf.expand_dims(weights_source, 1)

   tokens_shape = tf.shape(tokens)
   batch_size = tokens_shape[0]
...
@@ -200,9 +210,12 @@ def RootPotentialsFromTokens(root, tokens, weights):
   # Flatten out the batch dimension so we can use a couple big matmuls.
   tokens_bnxt = tf.reshape(tokens, [-1, num_target_activations])
-  weights_targets_bnxs = tf.matmul(tokens_bnxt, weights, transpose_b=True)
+  weights_targets_bnxs = tf.matmul(tokens_bnxt, weights_arc, transpose_b=True)
   roots_1xbn = tf.matmul(root_1xs, weights_targets_bnxs, transpose_b=True)

+  # Add in the score for selecting the root as a source.
+  roots_1xbn += tf.matmul(root_1xs, weights_source_sx1)
+
   # Restore the batch dimension in the output.
   roots_bxn = tf.reshape(roots_1xbn, [batch_size, num_tokens])
   return roots_bxn
...
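The new source term is a single scalar added uniformly to every token's root score. A small numpy sketch reproducing the updated expectations in testRootPotentialsFromTokens below (the root activation vector is not visible in this excerpt; [1, 2] is the value consistent with both the old and new expected outputs):

import numpy as np

root = np.array([1.0, 2.0])  # inferred from the test's expected values
tokens = np.array([[4, 5, 6], [5, 6, 7], [6, 7, 8]], dtype=float)
weights_arc = np.array([[2, 3, 5], [7, 11, 13]], dtype=float)
weights_source = np.array([11.0, 10.0])

# Arc term: root' * W_arc * token, per token.
arc_scores = root.dot(weights_arc.dot(tokens.T))  # -> [375, 447, 519]
# New source term: one scalar root' * w_source, added to every token.
source_score = root.dot(weights_source)           # -> 31
print(arc_scores + source_score)                  # -> [406, 478, 550]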
@@ -354,3 +367,110 @@ def LabelPotentialsFromTokenPairs(sources, targets, weights):
                              transpose_b=True)
   labels_bxnxl = tf.squeeze(labels_bxnxlx1, [3])
   return labels_bxnxl
+
+
+def ValidArcAndTokenMasks(lengths, max_length, dtype=tf.float32):
+  r"""Returns 0/1 masks for valid arcs and tokens.
+
+  Args:
+    lengths: [B] vector of input sequence lengths.
+    max_length: Scalar maximum input sequence length, aka M.
+    dtype: Data type for output mask.
+
+  Returns:
+    [B,M,M] tensor A with 0/1 indicators of valid arcs.  Specifically,
+      A_{b,t,s} = t,s < lengths[b] ? 1 : 0
+    [B,M] matrix T with 0/1 indicators of valid tokens.  Specifically,
+      T_{b,t} = t < lengths[b] ? 1 : 0
+  """
+  lengths_bx1 = tf.expand_dims(lengths, 1)
+  sequence_m = tf.range(tf.cast(max_length, lengths.dtype.base_dtype))
+  sequence_1xm = tf.expand_dims(sequence_m, 0)
+
+  # Create vectors of 0/1 indicators for valid tokens.  Note that the
+  # comparison operator will broadcast from [1,M] and [B,1] to [B,M].
+  valid_token_bxm = tf.cast(sequence_1xm < lengths_bx1, dtype)
+
+  # Compute matrices of 0/1 indicators for valid arcs as the outer product of
+  # the valid token indicator vector with itself.
+  valid_arc_bxmxm = tf.matmul(
+      tf.expand_dims(valid_token_bxm, 2), tf.expand_dims(valid_token_bxm, 1))
+
+  return valid_arc_bxmxm, valid_token_bxm
+
+
+def LaplacianMatrix(lengths, arcs, forest=False):
+  r"""Returns the (root-augmented) Laplacian matrix for a batch of digraphs.
+
+  Args:
+    lengths: [B] vector of input sequence lengths.
+    arcs: [B,M,M] tensor of arc potentials where entry b,t,s is the potential
+      of the arc from s to t in the b'th digraph, while b,t,t is the potential
+      of t as a root.  Entries b,t,s where t or s >= lengths[b] are ignored.
+    forest: Whether to produce a Laplacian for trees or forests.
+
+  Returns:
+    [B,M,M] tensor L with the Laplacian of each digraph, padded with an
+    identity matrix.  More concretely, the padding entries
+    (t or s >= lengths[b]) are:
+      L_{b,t,t} = 1.0
+      L_{b,t,s} = 0.0
+    Note that this "identity matrix padding" ensures that the determinant of
+    each padded matrix equals the determinant of the unpadded matrix.  The
+    non-padding entries (t,s < lengths[b]) depend on whether the Laplacian is
+    constructed for trees or forests.  For trees:
+      L_{b,t,0} = arcs[b,t,t]
+      L_{b,t,t} = \sum_{s < lengths[b], t != s} arcs[b,t,s]
+      L_{b,t,s} = -arcs[b,t,s]
+    For forests:
+      L_{b,t,t} = \sum_{s < lengths[b]} arcs[b,t,s]
+      L_{b,t,s} = -arcs[b,t,s]
+
+    See http://www.aclweb.org/anthology/D/D07/D07-1015.pdf for details, though
+    note that our matrices are transposed from their notation.
+  """
+  check.Eq(arcs.get_shape().ndims, 3, 'arcs must be rank 3')
+  dtype = arcs.dtype.base_dtype
+
+  arcs_shape = tf.shape(arcs)
+  batch_size = arcs_shape[0]
+  max_length = arcs_shape[1]
+  with tf.control_dependencies([tf.assert_equal(max_length, arcs_shape[2])]):
+    valid_arc_bxmxm, valid_token_bxm = ValidArcAndTokenMasks(
+        lengths, max_length, dtype=dtype)
+  invalid_token_bxm = tf.constant(1, dtype=dtype) - valid_token_bxm
+
+  # Zero out all invalid arcs, to avoid polluting bulk summations.
+  arcs_bxmxm = arcs * valid_arc_bxmxm
+
+  zeros_bxm = tf.zeros([batch_size, max_length], dtype)
+  if not forest:
+    # For trees, extract the root potentials and exclude them from the sums
+    # computed below.
+    roots_bxm = tf.matrix_diag_part(arcs_bxmxm)  # only defined for trees
+    arcs_bxmxm = tf.matrix_set_diag(arcs_bxmxm, zeros_bxm)
+
+  # Sum inbound arc potentials for each target token.  These sums will form
+  # the diagonal of the Laplacian matrix.  Note that these sums are zero for
+  # invalid tokens, since their arc potentials were masked out above.
+  sums_bxm = tf.reduce_sum(arcs_bxmxm, 2)
+
+  if forest:
+    # For forests, zero out the root potentials after computing the sums above
+    # so we don't cancel them out when we subtract the arc potentials.
+    arcs_bxmxm = tf.matrix_set_diag(arcs_bxmxm, zeros_bxm)
+
+  # The diagonal of the result is the combination of the arc sums, which are
+  # non-zero only on valid tokens, and the invalid token indicators, which are
+  # non-zero only on invalid tokens.  Note that the latter form the diagonal
+  # of the identity matrix padding.
+  diagonal_bxm = sums_bxm + invalid_token_bxm
+
+  # Combine sums and negative arc potentials.  Note that the off-diagonal
+  # padding entries will be zero thanks to the arc mask.
+  laplacian_bxmxm = tf.matrix_diag(diagonal_bxm) - arcs_bxmxm
+
+  if not forest:
+    # For trees, replace the first column with the root potentials.
+    roots_bxmx1 = tf.expand_dims(roots_bxm, 2)
+    laplacian_bxmxm = tf.concat([roots_bxmx1, laplacian_bxmxm[:, :, 1:]], 2)
+
+  return laplacian_bxmxm
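To make the tree-mode construction concrete, a small numpy sketch following the docstring formulas for a single length-3 digraph (padding omitted; it reproduces the third case of testLaplacianMatrixTree in the test file that follows):

import numpy as np

arcs = np.array([[2, 3, 5],
                 [7, 11, 13],
                 [17, 19, 23]], dtype=float)

roots = np.diag(arcs).copy()         # arcs[t, t] holds the root potentials
off = arcs - np.diag(np.diag(arcs))  # drop the diagonal
laplacian = np.diag(off.sum(axis=1)) - off
laplacian[:, 0] = roots              # trees: first column <- root potentials
print(laplacian)  # -> [[2, -3, -5], [11, 20, -13], [23, -19, 36]]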
research/syntaxnet/dragnn/python/digraph_ops_test.py  View file @ 80178fc6
...
@@ -31,16 +31,18 @@ class DigraphOpsTest(tf.test.TestCase):
                                     [3, 4]],
                                    [[3, 4],
                                     [2, 3],
-                                    [1, 2]]], tf.float32)
+                                    [1, 2]]], tf.float32)  # pyformat: disable
       target_tokens = tf.constant([[[4, 5, 6],
                                     [5, 6, 7],
                                     [6, 7, 8]],
                                    [[6, 7, 8],
                                     [5, 6, 7],
-                                    [4, 5, 6]]], tf.float32)
+                                    [4, 5, 6]]], tf.float32)  # pyformat: disable
       weights = tf.constant([[2, 3, 5],
-                             [7, 11, 13]], tf.float32)
+                             [7, 11, 13]], tf.float32)  # pyformat: disable

       arcs = digraph_ops.ArcPotentialsFromTokens(source_tokens, target_tokens,
                                                  weights)
...
@@ -54,7 +56,7 @@ class DigraphOpsTest(tf.test.TestCase):
                            [803, 957, 1111]],
                           [[1111, 957, 803],  # reflected through the center
                            [815, 702, 589],
-                           [519, 447, 375]]])
+                           [519, 447, 375]]])  # pyformat: disable

   def testArcSourcePotentialsFromTokens(self):
     with self.test_session():
...
@@ -63,7 +65,7 @@ class DigraphOpsTest(tf.test.TestCase):
                             [6, 7, 8]],
                            [[6, 7, 8],
                             [5, 6, 7],
-                            [4, 5, 6]]], tf.float32)
+                            [4, 5, 6]]], tf.float32)  # pyformat: disable
       weights = tf.constant([2, 3, 5], tf.float32)

       arcs = digraph_ops.ArcSourcePotentialsFromTokens(tokens, weights)
...
@@ -73,7 +75,7 @@ class DigraphOpsTest(tf.test.TestCase):
                            [73, 73, 73]],
                           [[73, 73, 73],
                            [63, 63, 63],
-                           [53, 53, 53]]])
+                           [53, 53, 53]]])  # pyformat: disable

   def testRootPotentialsFromTokens(self):
     with self.test_session():
...
@@ -83,15 +85,17 @@ class DigraphOpsTest(tf.test.TestCase):
                             [6, 7, 8]],
                            [[6, 7, 8],
                             [5, 6, 7],
-                            [4, 5, 6]]], tf.float32)
-      weights = tf.constant([[2, 3, 5],
-                             [7, 11, 13]], tf.float32)
+                            [4, 5, 6]]], tf.float32)  # pyformat: disable
+      weights_arc = tf.constant([[2, 3, 5],
+                                 [7, 11, 13]], tf.float32)  # pyformat: disable
+      weights_source = tf.constant([11, 10], tf.float32)

-      roots = digraph_ops.RootPotentialsFromTokens(root, tokens, weights)
+      roots = digraph_ops.RootPotentialsFromTokens(root, tokens, weights_arc,
+                                                   weights_source)

-      self.assertAllEqual(roots.eval(), [[375, 447, 519],
-                                         [519, 447, 375]])
+      self.assertAllEqual(roots.eval(), [[406, 478, 550],
+                                         [550, 478, 406]])  # pyformat: disable

   def testCombineArcAndRootPotentials(self):
     with self.test_session():
...
@@ -100,9 +104,9 @@ class DigraphOpsTest(tf.test.TestCase):
                          [3, 4, 5]],
                         [[3, 4, 5],
                          [2, 3, 4],
-                         [1, 2, 3]]], tf.float32)
+                         [1, 2, 3]]], tf.float32)  # pyformat: disable
       roots = tf.constant([[6, 7, 8],
-                           [8, 7, 6]], tf.float32)
+                           [8, 7, 6]], tf.float32)  # pyformat: disable

       potentials = digraph_ops.CombineArcAndRootPotentials(arcs, roots)
...
@@ -111,7 +115,7 @@ class DigraphOpsTest(tf.test.TestCase):
                                [3, 4, 8]],
                               [[8, 4, 5],
                                [2, 7, 4],
-                               [1, 2, 6]]])
+                               [1, 2, 6]]])  # pyformat: disable

   def testLabelPotentialsFromTokens(self):
     with self.test_session():
...
@@ -120,12 +124,12 @@ class DigraphOpsTest(tf.test.TestCase):
                             [5, 6]],
                            [[6, 5],
                             [4, 3],
-                            [2, 1]]], tf.float32)
+                            [2, 1]]], tf.float32)  # pyformat: disable
       weights = tf.constant([[2, 3],
                              [5, 7],
-                             [11, 13]], tf.float32)
+                             [11, 13]], tf.float32)  # pyformat: disable

       labels = digraph_ops.LabelPotentialsFromTokens(tokens, weights)
...
@@ -136,7 +140,7 @@ class DigraphOpsTest(tf.test.TestCase):
                            [28, 67, 133]],
                           [[27, 65, 131],
                            [17, 41, 83],
-                           [7, 17, 35]]])
+                           [7, 17, 35]]])  # pyformat: disable

   def testLabelPotentialsFromTokenPairs(self):
     with self.test_session():
...
@@ -145,13 +149,13 @@ class DigraphOpsTest(tf.test.TestCase):
                              [5, 6]],
                             [[6, 5],
                              [4, 3],
-                             [2, 1]]], tf.float32)
+                             [2, 1]]], tf.float32)  # pyformat: disable
       targets = tf.constant([[[3, 4],
                               [5, 6],
                               [7, 8]],
                              [[8, 7],
                               [6, 5],
-                              [4, 3]]], tf.float32)
+                              [4, 3]]], tf.float32)  # pyformat: disable
       weights = tf.constant([[[2, 3],
...
@@ -159,7 +163,7 @@ class DigraphOpsTest(tf.test.TestCase):
                              [[11, 13],
                               [17, 19]],
                              [[23, 29],
-                              [31, 37]]], tf.float32)
+                              [31, 37]]], tf.float32)  # pyformat: disable

       labels = digraph_ops.LabelPotentialsFromTokenPairs(sources, targets,
                                                          weights)
...
@@ -171,7 +175,114 @@ class DigraphOpsTest(tf.test.TestCase):
                            [736, 2531, 5043]],
                           [[667, 2419, 4857],
                            [303, 1115, 2245],
-                           [75, 291, 593]]])
+                           [75, 291, 593]]])  # pyformat: disable
+
+  def testValidArcAndTokenMasks(self):
+    with self.test_session():
+      lengths = tf.constant([1, 2, 3], tf.int64)
+      max_length = 4
+      valid_arcs, valid_tokens = digraph_ops.ValidArcAndTokenMasks(
+          lengths, max_length)
+
+      self.assertAllEqual(valid_arcs.eval(),
+                          [[[1, 0, 0, 0],
+                            [0, 0, 0, 0],
+                            [0, 0, 0, 0],
+                            [0, 0, 0, 0]],
+                           [[1, 1, 0, 0],
+                            [1, 1, 0, 0],
+                            [0, 0, 0, 0],
+                            [0, 0, 0, 0]],
+                           [[1, 1, 1, 0],
+                            [1, 1, 1, 0],
+                            [1, 1, 1, 0],
+                            [0, 0, 0, 0]]])  # pyformat: disable
+      self.assertAllEqual(valid_tokens.eval(),
+                          [[1, 0, 0, 0],
+                           [1, 1, 0, 0],
+                           [1, 1, 1, 0]])  # pyformat: disable
+
+  def testLaplacianMatrixTree(self):
+    with self.test_session():
+      pad = 12345.6
+      arcs = tf.constant([[[2, pad, pad, pad],
+                           [pad, pad, pad, pad],
+                           [pad, pad, pad, pad],
+                           [pad, pad, pad, pad]],
+                          [[2, 3, pad, pad],
+                           [5, 7, pad, pad],
+                           [pad, pad, pad, pad],
+                           [pad, pad, pad, pad]],
+                          [[2, 3, 5, pad],
+                           [7, 11, 13, pad],
+                           [17, 19, 23, pad],
+                           [pad, pad, pad, pad]],
+                          [[2, 3, 5, 7],
+                           [11, 13, 17, 19],
+                           [23, 29, 31, 37],
+                           [41, 43, 47, 53]]], tf.float32)  # pyformat: disable
+      lengths = tf.constant([1, 2, 3, 4], tf.int64)
+      laplacian = digraph_ops.LaplacianMatrix(lengths, arcs)
+
+      self.assertAllEqual(laplacian.eval(),
+                          [[[2, 0, 0, 0],
+                            [0, 1, 0, 0],
+                            [0, 0, 1, 0],
+                            [0, 0, 0, 1]],
+                           [[2, -3, 0, 0],
+                            [7, 5, 0, 0],
+                            [0, 0, 1, 0],
+                            [0, 0, 0, 1]],
+                           [[2, -3, -5, 0],
+                            [11, 20, -13, 0],
+                            [23, -19, 36, 0],
+                            [0, 0, 0, 1]],
+                           [[2, -3, -5, -7],
+                            [13, 47, -17, -19],
+                            [31, -29, 89, -37],
+                            [53, -43, -47, 131]]])  # pyformat: disable
+
+  def testLaplacianMatrixForest(self):
+    with self.test_session():
+      pad = 12345.6
+      arcs = tf.constant([[[2, pad, pad, pad],
+                           [pad, pad, pad, pad],
+                           [pad, pad, pad, pad],
+                           [pad, pad, pad, pad]],
+                          [[2, 3, pad, pad],
+                           [5, 7, pad, pad],
+                           [pad, pad, pad, pad],
+                           [pad, pad, pad, pad]],
+                          [[2, 3, 5, pad],
+                           [7, 11, 13, pad],
+                           [17, 19, 23, pad],
+                           [pad, pad, pad, pad]],
+                          [[2, 3, 5, 7],
+                           [11, 13, 17, 19],
+                           [23, 29, 31, 37],
+                           [41, 43, 47, 53]]], tf.float32)  # pyformat: disable
+      lengths = tf.constant([1, 2, 3, 4], tf.int64)
+      laplacian = digraph_ops.LaplacianMatrix(lengths, arcs, forest=True)
+
+      self.assertAllEqual(laplacian.eval(),
+                          [[[2, 0, 0, 0],
+                            [0, 1, 0, 0],
+                            [0, 0, 1, 0],
+                            [0, 0, 0, 1]],
+                           [[5, -3, 0, 0],
+                            [-5, 12, 0, 0],
+                            [0, 0, 1, 0],
+                            [0, 0, 0, 1]],
+                           [[10, -3, -5, 0],
+                            [-7, 31, -13, 0],
+                            [-17, -19, 59, 0],
+                            [0, 0, 0, 1]],
+                           [[17, -3, -5, -7],
+                            [-11, 60, -17, -19],
+                            [-23, -29, 120, -37],
+                            [-41, -43, -47, 184]]])  # pyformat: disable
+
 if __name__ == "__main__":
...
research/syntaxnet/dragnn/python/dragnn_model_saver.py  View file @ 80178fc6
...
@@ -25,13 +25,14 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from absl import app
+from absl import flags
 import tensorflow as tf

 from google.protobuf import text_format
 from dragnn.protos import spec_pb2
 from dragnn.python import dragnn_model_saver_lib as saver_lib

-flags = tf.app.flags
 FLAGS = flags.FLAGS

 flags.DEFINE_string('master_spec', None, 'Path to task context with '
...
@@ -40,10 +41,12 @@ flags.DEFINE_string('params_path', None, 'Path to trained model parameters.')
 flags.DEFINE_string('export_path', '', 'Output path for exported servo model.')
 flags.DEFINE_bool('export_moving_averages', False,
                   'Whether to export the moving average parameters.')
+flags.DEFINE_bool('build_runtime_graph', False,
+                  'Whether to build a graph for use by the runtime.')


-def export(master_spec_path, params_path, export_path,
-           export_moving_averages):
+def export(master_spec_path, params_path, export_path, export_moving_averages,
+           build_runtime_graph):
   """Restores a model and exports it in SavedModel form.

   This method loads a graph specified by the spec at master_spec_path and the
...
@@ -55,6 +58,7 @@ def export(master_spec_path, params_path, export_path,
     params_path: Path to the parameters file to export.
     export_path: Path to export the SavedModel to.
     export_moving_averages: Whether to export the moving average parameters.
+    build_runtime_graph: Whether to build a graph for use by the runtime.
   """
   graph = tf.Graph()
...
@@ -70,16 +74,16 @@ def export(master_spec_path, params_path, export_path,
   short_to_original = saver_lib.shorten_resource_paths(master_spec)
   saver_lib.export_master_spec(master_spec, graph)
   saver_lib.export_to_graph(master_spec, params_path, stripped_path, graph,
-                            export_moving_averages)
+                            export_moving_averages, build_runtime_graph)
   saver_lib.export_assets(master_spec, short_to_original, stripped_path)


 def main(unused_argv):
   # Run the exporter.
-  export(FLAGS.master_spec, FLAGS.params_path, FLAGS.export_path,
-         FLAGS.export_moving_averages)
+  export(FLAGS.master_spec, FLAGS.params_path, FLAGS.export_path,
+         FLAGS.export_moving_averages, FLAGS.build_runtime_graph)
   tf.logging.info('Export complete.')


 if __name__ == '__main__':
-  tf.app.run()
+  app.run(main)
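For orientation, a typical invocation of the updated exporter might look like the following (the paths are placeholders, not files from this repository; export() is the function defined in this file):

# Hypothetical paths; keyword names mirror export()'s signature above.
from dragnn.python import dragnn_model_saver

dragnn_model_saver.export(
    master_spec_path='/tmp/model/master-spec',
    params_path='/tmp/model/params',
    export_path='/tmp/model/export',
    export_moving_averages=True,
    build_runtime_graph=True)  # also emits the runtime hook nodes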
research/syntaxnet/dragnn/python/dragnn_model_saver_lib.py  View file @ 80178fc6
...
@@ -164,6 +164,7 @@ def export_to_graph(master_spec,
                     export_path,
                     external_graph,
                     export_moving_averages,
+                    build_runtime_graph,
                     signature_name='model'):
   """Restores a model and exports it in SavedModel form.
...
@@ -177,6 +178,7 @@ def export_to_graph(master_spec,
     export_path: Path to export the SavedModel to.
     external_graph: A tf.Graph() object to build the graph inside.
     export_moving_averages: Whether to export the moving average parameters.
+    build_runtime_graph: Whether to build a graph for use by the runtime.
     signature_name: Name of the signature to insert.
   """
   tf.logging.info(
...
@@ -189,7 +191,7 @@ def export_to_graph(master_spec,
   hyperparam_config.use_moving_average = export_moving_averages
   builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
   post_restore_hook = builder.build_post_restore_hook()
-  annotation = builder.add_annotation()
+  annotation = builder.add_annotation(build_runtime_graph=build_runtime_graph)
   builder.add_saver()

   # Resets session.
...
research/syntaxnet/dragnn/python/dragnn_model_saver_lib_test.py  View file @ 80178fc6
...
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
 """Test for dragnn.python.dragnn_model_saver_lib."""
 from __future__ import absolute_import
...
@@ -26,24 +25,30 @@ import tensorflow as tf
 from google.protobuf import text_format
 from tensorflow.python.framework import test_util
 from tensorflow.python.platform import googletest

+from dragnn.protos import export_pb2
 from dragnn.protos import spec_pb2
 from dragnn.python import dragnn_model_saver_lib
 from syntaxnet import sentence_pb2
+from syntaxnet import test_flags

-FLAGS = tf.app.flags.FLAGS
-
-
-def setUpModule():
-  if not hasattr(FLAGS, 'test_srcdir'):
-    FLAGS.test_srcdir = ''
-  if not hasattr(FLAGS, 'test_tmpdir'):
-    FLAGS.test_tmpdir = tf.test.get_temp_dir()
+_DUMMY_TEST_SENTENCE = """
+token {
+  word: "sentence" start: 0 end: 7 break_level: NO_BREAK
+}
+token {
+  word: "0" start: 9 end: 9 break_level: SPACE_BREAK
+}
+token {
+  word: "." start: 10 end: 10 break_level: NO_BREAK
+}
+"""


 class DragnnModelSaverLibTest(test_util.TensorFlowTestCase):

   def LoadSpec(self, spec_path):
     master_spec = spec_pb2.MasterSpec()
-    root_dir = os.path.join(FLAGS.test_srcdir,
+    root_dir = os.path.join(test_flags.source_root(),
                             'dragnn/python')
     with open(os.path.join(root_dir, 'testdata', spec_path), 'r') as fin:
       text_format.Parse(fin.read().replace('TOPDIR', root_dir), master_spec)
...
@@ -52,7 +57,7 @@ class DragnnModelSaverLibTest(test_util.TensorFlowTestCase):
   def CreateLocalSpec(self, spec_path):
     master_spec = self.LoadSpec(spec_path)
     master_spec_name = os.path.basename(spec_path)
-    outfile = os.path.join(FLAGS.test_tmpdir, master_spec_name)
+    outfile = os.path.join(test_flags.temp_dir(), master_spec_name)
     fout = open(outfile, 'w')
     fout.write(text_format.MessageToString(master_spec))
     return outfile
...
@@ -80,16 +85,50 @@ class DragnnModelSaverLibTest(test_util.TensorFlowTestCase):
     # Return a set of all unique paths.
     return set(path_list)

+  def GetHookNodeNames(self, master_spec):
+    """Returns hook node names to use in tests.
+
+    Args:
+      master_spec: MasterSpec proto from which to infer hook node names.
+
+    Returns:
+      Tuple of (averaged hook node name, non-averaged hook node name, cell
+      subgraph hook node name).
+
+    Raises:
+      ValueError: If hook nodes cannot be inferred from the |master_spec|.
+    """
+    # Find an op name we can use for testing runtime hooks.  Assume that at
+    # least one component has a fixed feature (else what is the model doing?).
+    component_name = None
+    for component_spec in master_spec.component:
+      if component_spec.fixed_feature:
+        component_name = component_spec.name
+        break
+
+    if not component_name:
+      raise ValueError('Cannot infer hook node names')
+
+    non_averaged_hook_name = '{}/fixed_embedding_matrix_0/trimmed'.format(
+        component_name)
+    averaged_hook_name = '{}/ExponentialMovingAverage'.format(
+        non_averaged_hook_name)
+    cell_subgraph_hook_name = '{}/EXPORT/CellSubgraphSpec'.format(
+        component_name)
+    return averaged_hook_name, non_averaged_hook_name, cell_subgraph_hook_name
+
   def testModelExport(self):
     # Get the master spec and params for this graph.
     master_spec = self.LoadSpec('ud-hungarian.master-spec')
     params_path = os.path.join(
-        FLAGS.test_srcdir, 'dragnn/python/testdata'
+        test_flags.source_root(), 'dragnn/python/testdata'
         '/ud-hungarian.params')

     # Export the graph via SavedModel. (Here, we maintain a handle to the graph
     # for comparison, but that's usually not necessary.)
-    export_path = os.path.join(FLAGS.test_tmpdir, 'export')
+    export_path = os.path.join(test_flags.temp_dir(), 'export')
     dragnn_model_saver_lib.clean_output_paths(export_path)

     saver_graph = tf.Graph()

     shortened_to_original = dragnn_model_saver_lib.shorten_resource_paths(
...
@@ -102,7 +141,8 @@ class DragnnModelSaverLibTest(test_util.TensorFlowTestCase):
         params_path,
         export_path,
         saver_graph,
-        export_moving_averages=False)
+        export_moving_averages=False,
+        build_runtime_graph=False)

     # Export the assets as well.
     dragnn_model_saver_lib.export_assets(master_spec, shortened_to_original,
...
@@ -126,6 +166,165 @@ class DragnnModelSaverLibTest(test_util.TensorFlowTestCase):
       tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING],
                                  export_path)

+      averaged_hook_name, non_averaged_hook_name, _ = self.GetHookNodeNames(
+          master_spec)
+
+      # Check that the averaged runtime hook node does not exist.
+      with self.assertRaises(KeyError):
+        restored_graph.get_operation_by_name(averaged_hook_name)
+
+      # Check that the non-averaged version also does not exist.
+      with self.assertRaises(KeyError):
+        restored_graph.get_operation_by_name(non_averaged_hook_name)
+
+  def testModelExportWithAveragesAndHooks(self):
+    # Get the master spec and params for this graph.
+    master_spec = self.LoadSpec('ud-hungarian.master-spec')
+    params_path = os.path.join(
+        test_flags.source_root(), 'dragnn/python/testdata'
+        '/ud-hungarian.params')
+
+    # Export the graph via SavedModel. (Here, we maintain a handle to the graph
+    # for comparison, but that's usually not necessary.)  Note that the export
+    # path must not already exist.
+    export_path = os.path.join(test_flags.temp_dir(), 'export2')
+    dragnn_model_saver_lib.clean_output_paths(export_path)
+
+    saver_graph = tf.Graph()
+
+    shortened_to_original = dragnn_model_saver_lib.shorten_resource_paths(
+        master_spec)
+
+    dragnn_model_saver_lib.export_master_spec(master_spec, saver_graph)
+
+    dragnn_model_saver_lib.export_to_graph(
+        master_spec,
+        params_path,
+        export_path,
+        saver_graph,
+        export_moving_averages=True,
+        build_runtime_graph=True)
+
+    # Export the assets as well.
+    dragnn_model_saver_lib.export_assets(master_spec, shortened_to_original,
+                                         export_path)
+
+    # Validate that the assets are all in the exported directory.
+    path_set = self.ValidateAssetExistence(master_spec, export_path)
+
+    # This master-spec has 4 unique assets. If there are more, we have not
+    # uniquified the assets properly.
+    self.assertEqual(len(path_set), 4)
+
+    # Restore the graph from the checkpoint into a new Graph object.
+    restored_graph = tf.Graph()
+    restoration_config = tf.ConfigProto(
+        log_device_placement=False,
+        intra_op_parallelism_threads=10,
+        inter_op_parallelism_threads=10)
+
+    with tf.Session(graph=restored_graph, config=restoration_config) as sess:
+      tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING],
+                                 export_path)
+
+      averaged_hook_name, non_averaged_hook_name, cell_subgraph_hook_name = (
+          self.GetHookNodeNames(master_spec))
+
+      # Check that an averaged runtime hook node exists.
+      restored_graph.get_operation_by_name(averaged_hook_name)
+
+      # Check that the non-averaged version does not exist.
+      with self.assertRaises(KeyError):
+        restored_graph.get_operation_by_name(non_averaged_hook_name)
+
+      # Load the cell subgraph.
+      cell_subgraph_bytes = restored_graph.get_tensor_by_name(
+          cell_subgraph_hook_name + ':0')
+      cell_subgraph_bytes = cell_subgraph_bytes.eval(
+          feed_dict={'annotation/ComputeSession/InputBatch:0': []})
+      cell_subgraph_spec = export_pb2.CellSubgraphSpec()
+      cell_subgraph_spec.ParseFromString(cell_subgraph_bytes)
+      tf.logging.info('cell_subgraph_spec = %s', cell_subgraph_spec)
+
+      # Sanity check inputs.
+      for cell_input in cell_subgraph_spec.input:
+        self.assertGreater(len(cell_input.name), 0)
+        self.assertGreater(len(cell_input.tensor), 0)
+        self.assertNotEqual(cell_input.type,
+                            export_pb2.CellSubgraphSpec.Input.TYPE_UNKNOWN)
+        restored_graph.get_tensor_by_name(cell_input.tensor)  # shouldn't raise
+
+      # Sanity check outputs.
+      for cell_output in cell_subgraph_spec.output:
+        self.assertGreater(len(cell_output.name), 0)
+        self.assertGreater(len(cell_output.tensor), 0)
+        restored_graph.get_tensor_by_name(cell_output.tensor)  # shouldn't raise
+
+      # GetHookNames() finds a component with a fixed feature, so at least the
+      # first feature ID should exist.
+      self.assertTrue(
+          any(cell_input.name == 'fixed_channel_0_index_0_ids'
+              for cell_input in cell_subgraph_spec.input))
+
+      # Most dynamic components produce a logits layer.
+      self.assertTrue(
+          any(cell_output.name == 'logits'
+              for cell_output in cell_subgraph_spec.output))
+
+  def testModelExportProducesRunnableModel(self):
+    # Get the master spec and params for this graph.
+    master_spec = self.LoadSpec('ud-hungarian.master-spec')
+    params_path = os.path.join(
+        test_flags.source_root(), 'dragnn/python/testdata'
+        '/ud-hungarian.params')
+
+    # Export the graph via SavedModel. (Here, we maintain a handle to the graph
+    # for comparison, but that's usually not necessary.)
+    export_path = os.path.join(test_flags.temp_dir(), 'export')
+    dragnn_model_saver_lib.clean_output_paths(export_path)
+
+    saver_graph = tf.Graph()
+
+    shortened_to_original = dragnn_model_saver_lib.shorten_resource_paths(
+        master_spec)
+
+    dragnn_model_saver_lib.export_master_spec(master_spec, saver_graph)
+
+    dragnn_model_saver_lib.export_to_graph(
+        master_spec,
+        params_path,
+        export_path,
+        saver_graph,
+        export_moving_averages=False,
+        build_runtime_graph=False)
+
+    # Export the assets as well.
+    dragnn_model_saver_lib.export_assets(master_spec, shortened_to_original,
+                                         export_path)
+
+    # Restore the graph from the checkpoint into a new Graph object.
+    restored_graph = tf.Graph()
+    restoration_config = tf.ConfigProto(
+        log_device_placement=False,
+        intra_op_parallelism_threads=10,
+        inter_op_parallelism_threads=10)
+
+    with tf.Session(graph=restored_graph, config=restoration_config) as sess:
+      tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING],
+                                 export_path)
+
+      test_doc = sentence_pb2.Sentence()
+      text_format.Parse(_DUMMY_TEST_SENTENCE, test_doc)
+      test_reader_string = test_doc.SerializeToString()
+      test_inputs = [test_reader_string]
+
+      tf_out = sess.run(
+          'annotation/annotations:0',
+          feed_dict={'annotation/ComputeSession/InputBatch:0': test_inputs})
+
+      # We don't care about accuracy, only that the run sessions don't crash.
+      del tf_out
+

 if __name__ == '__main__':
   googletest.main()
research/syntaxnet/dragnn/python/file_diff_test.py  0 → 100644  View file @ 80178fc6
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Diff test that compares two files are identical."""

from absl import flags
import tensorflow as tf

FLAGS = flags.FLAGS

flags.DEFINE_string('actual_file', None, 'File to test.')
flags.DEFINE_string('expected_file', None, 'File with expected contents.')


class DiffTest(tf.test.TestCase):

  def testEqualFiles(self):
    content_actual = None
    content_expected = None

    try:
      with open(FLAGS.actual_file) as actual:
        content_actual = actual.read()
    except IOError as e:
      self.fail("Error opening '%s': %s" % (FLAGS.actual_file, e.strerror))

    try:
      with open(FLAGS.expected_file) as expected:
        content_expected = expected.read()
    except IOError as e:
      self.fail("Error opening '%s': %s" % (FLAGS.expected_file, e.strerror))

    self.assertTrue(content_actual == content_expected)


if __name__ == '__main__':
  tf.test.main()
research/syntaxnet/dragnn/python/graph_builder.py  View file @ 80178fc6
...
@@ -28,7 +28,7 @@ from syntaxnet.util import check
 try:
   tf.NotDifferentiable('ExtractFixedFeatures')
-except KeyError as e:
+except KeyError, e:
   logging.info(str(e))
...
@@ -179,6 +179,8 @@ class MasterBuilder(object):
     optimizer: handle to the tf.train Optimizer object used to train this model.
     master_vars: dictionary of globally shared tf.Variable objects (e.g.
       the global training step and learning rate.)
+    read_from_avg: Whether to use averaged params instead of normal params.
+    build_runtime_graph: Whether to build a graph for use by the runtime.
   """

   def __init__(self, master_spec, hyperparam_config=None, pool_scope='shared'):
...
@@ -197,14 +199,15 @@ class MasterBuilder(object):
       ValueError: if a component is not found in the registry.
     """
     self.spec = master_spec
-    self.hyperparams = (spec_pb2.GridPoint()
-                        if hyperparam_config is None else hyperparam_config)
+    self.hyperparams = (
+        spec_pb2.GridPoint()
+        if hyperparam_config is None else hyperparam_config)
     _validate_grid_point(self.hyperparams)
     self.pool_scope = pool_scope

     # Set the graph-level random seed before creating the Components so the ops
     # they create will use this seed.
-    tf.set_random_seed(hyperparam_config.seed)
+    tf.set_random_seed(self.hyperparams.seed)

     # Construct all utility class and variables for each Component.
     self.components = []
...
@@ -219,19 +222,37 @@ class MasterBuilder(object):
       self.lookup_component[comp.name] = comp
       self.components.append(comp)

-    # Add global step variable.
     self.master_vars = {}
     with tf.variable_scope('master', reuse=False):
+      # Add global step variable.
       self.master_vars['step'] = tf.get_variable(
           'step', [], initializer=tf.zeros_initializer(), dtype=tf.int32)
-      self.master_vars['learning_rate'] = _create_learning_rate(
-          self.hyperparams, self.master_vars['step'])
+
+      # Add learning rate. If the learning rate is optimized externally, then
+      # just create an assign op.
+      if self.hyperparams.pbt_optimize_learning_rate:
+        self.master_vars['learning_rate'] = tf.get_variable(
+            'learning_rate',
+            initializer=tf.constant(
+                self.hyperparams.learning_rate, dtype=tf.float32))
+        lr_assign_input = tf.placeholder(tf.float32, [],
+                                         'pbt/assign/learning_rate/Value')
+        tf.assign(
+            self.master_vars['learning_rate'],
+            value=lr_assign_input,
+            name='pbt/assign/learning_rate')
+      else:
+        self.master_vars['learning_rate'] = _create_learning_rate(
+            self.hyperparams, self.master_vars['step'])

     # Construct optimizer.
     self.optimizer = _create_optimizer(self.hyperparams,
                                        self.master_vars['learning_rate'],
                                        self.master_vars['step'])

     self.read_from_avg = False
+    self.build_runtime_graph = False

   @property
   def component_names(self):
     return tuple(c.name for c in self.components)
...
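The PBT branch above exposes an assign op so an external tuner can overwrite the learning rate through the feed dict. A sketch of how a caller might drive it (the exact op and tensor names depend on the enclosing 'master' scope, so treat them as assumptions):

import tensorflow as tf

# Assumed names, derived from the variable_scope('master') above.
ASSIGN_OP = 'master/pbt/assign/learning_rate'
VALUE_TENSOR = 'master/pbt/assign/learning_rate/Value:0'

def set_learning_rate(sess, new_rate):
  """Overwrites the learning-rate variable from outside the training loop."""
  sess.run(ASSIGN_OP, feed_dict={VALUE_TENSOR: new_rate})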
@@ -366,8 +387,9 @@ class MasterBuilder(object):
max_index
=
len
(
self
.
components
)
else
:
if
not
0
<
max_index
<=
len
(
self
.
components
):
raise
IndexError
(
'Invalid max_index {} for components {}; handle {}'
.
format
(
max_index
,
self
.
component_names
,
handle
.
name
))
raise
IndexError
(
'Invalid max_index {} for components {}; handle {}'
.
format
(
max_index
,
self
.
component_names
,
handle
.
name
))
# By default, we train every component supervised.
if
not
component_weights
:
...
...
@@ -375,6 +397,11 @@ class MasterBuilder(object):
if
not
unroll_using_oracle
:
unroll_using_oracle
=
[
True
]
*
max_index
if
not
max_index
<=
len
(
unroll_using_oracle
):
raise
IndexError
((
'Invalid max_index {} for unroll_using_oracle {}; '
'handle {}'
).
format
(
max_index
,
unroll_using_oracle
,
handle
.
name
))
component_weights
=
component_weights
[:
max_index
]
total_weight
=
(
float
)(
sum
(
component_weights
))
component_weights
=
[
w
/
total_weight
for
w
in
component_weights
]
...
...
@@ -408,10 +435,10 @@ class MasterBuilder(object):
args
=
(
master_state
,
network_states
)
if
unroll_using_oracle
[
component_index
]:
handle
,
component_cost
,
component_correct
,
component_total
=
(
tf
.
cond
(
comp
.
training_beam_size
>
1
,
lambda
:
comp
.
build_structured_training
(
*
args
),
lambda
:
comp
.
build_greedy_training
(
*
args
)))
handle
,
component_cost
,
component_correct
,
component_total
=
(
tf
.
cond
(
comp
.
training_beam_size
>
1
,
lambda
:
comp
.
build_structured_training
(
*
args
),
lambda
:
comp
.
build_greedy_training
(
*
args
)))
else
:
handle
=
comp
.
build_greedy_inference
(
*
args
,
during_training
=
True
)
...
...
@@ -445,6 +472,7 @@ class MasterBuilder(object):
# 1. compute the gradients,
# 2. add an optimizer to update the parameters using the gradients,
# 3. make the ComputeSession handle depend on the optimizer.
gradient_norm
=
tf
.
constant
(
0.
)
if
compute_gradients
:
logging
.
info
(
'Creating train op with %d variables:
\n\t
%s'
,
len
(
params_to_train
),
...
...
@@ -452,8 +480,11 @@ class MasterBuilder(object):
grads_and_vars
=
self
.
optimizer
.
compute_gradients
(
cost
,
var_list
=
params_to_train
)
clipped_gradients
=
[(
self
.
_clip_gradients
(
g
),
v
)
for
g
,
v
in
grads_and_vars
]
clipped_gradients
=
[
(
self
.
_clip_gradients
(
g
),
v
)
for
g
,
v
in
grads_and_vars
]
gradient_norm
=
tf
.
global_norm
(
list
(
zip
(
*
clipped_gradients
))[
0
])
minimize_op
=
self
.
optimizer
.
apply_gradients
(
clipped_gradients
,
global_step
=
self
.
master_vars
[
'step'
])
...
...
@@ -474,6 +505,7 @@ class MasterBuilder(object):
# Returns named access to common outputs.
outputs
=
{
'cost'
:
cost
,
'gradient_norm'
:
gradient_norm
,
'batch'
:
effective_batch
,
'metrics'
:
metrics
,
}
...
...
@@ -520,7 +552,10 @@ class MasterBuilder(object):
with
tf
.
control_dependencies
(
control_ops
):
return
tf
.
no_op
(
name
=
'post_restore_hook_master'
)
def
build_inference
(
self
,
handle
,
use_moving_average
=
False
):
def
build_inference
(
self
,
handle
,
use_moving_average
=
False
,
build_runtime_graph
=
False
):
"""Builds an inference pipeline.
This always uses the whole pipeline.
...
...
@@ -530,25 +565,30 @@ class MasterBuilder(object):
use_moving_average: Whether or not to read from the moving
average variables instead of the true parameters. Note: it is not
possible to make gradient updates when this is True.
build_runtime_graph: Whether to build a graph for use by the runtime.
Returns:
handle: Handle after annotation.
"""
self
.
read_from_avg
=
use_moving_average
self
.
build_runtime_graph
=
build_runtime_graph
network_states
=
{}
for
comp
in
self
.
components
:
network_states
[
comp
.
name
]
=
component
.
NetworkState
()
handle
=
dragnn_ops
.
init_component_data
(
handle
,
beam_size
=
comp
.
inference_beam_size
,
component
=
comp
.
name
)
master_state
=
component
.
MasterState
(
handle
,
dragnn_ops
.
batch_size
(
handle
,
component
=
comp
.
name
))
if
build_runtime_graph
:
batch_size
=
1
# runtime uses singleton batches
else
:
batch_size
=
dragnn_ops
.
batch_size
(
handle
,
component
=
comp
.
name
)
master_state
=
component
.
MasterState
(
handle
,
batch_size
)
with
tf
.
control_dependencies
([
handle
]):
handle
=
comp
.
build_greedy_inference
(
master_state
,
network_states
)
handle
=
dragnn_ops
.
write_annotations
(
handle
,
component
=
comp
.
name
)
self
.
read_from_avg
=
False
self
.
build_runtime_graph
=
False
return
handle
def
add_training_from_config
(
self
,
...
...
@@ -625,7 +665,10 @@ class MasterBuilder(object):
    return self._outputs_with_release(handle, {'input_batch': input_batch},
                                      outputs)

  def add_annotation(self, name_scope='annotation', enable_tracing=False):
  def add_annotation(self,
                     name_scope='annotation',
                     enable_tracing=False,
                     build_runtime_graph=False):
    """Adds an annotation pipeline to the graph.

    This will create the following additional named targets by default, for use
...
...
@@ -640,13 +683,17 @@ class MasterBuilder(object):
      enable_tracing: Enabling this will result in two things:
        1. Tracing will be enabled during inference.
        2. A 'traces' node will be added to the outputs.
      build_runtime_graph: Whether to build a graph for use by the runtime.

    Returns:
      A dictionary of input and output nodes.
    """
    with tf.name_scope(name_scope):
      handle, input_batch = self._get_session_with_reader(enable_tracing)
      handle = self.build_inference(handle, use_moving_average=True)
      handle = self.build_inference(
          handle,
          use_moving_average=True,
          build_runtime_graph=build_runtime_graph)
      annotations = dragnn_ops.emit_annotations(
          handle, component=self.spec.component[-1].name)
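A hedged usage sketch of the threaded flag, from add_annotation down to build_inference; the spec and hyperparameter setup are elided and the constructor arguments are assumptions:

# Hypothetical caller; master_spec and hyperparam_config are assumed built.
builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)

# Standard annotation graph: dynamic batch size from the ComputeSession.
anno = builder.add_annotation(name_scope='annotation')

# Runtime export graph: singleton batches with statically known shapes.
runtime_anno = builder.add_annotation(
    name_scope='annotation', build_runtime_graph=True)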
...
...
@@ -666,7 +713,7 @@ class MasterBuilder(object):
  def add_saver(self):
    """Adds a Saver for all variables in the graph."""
    logging.info('Saving variables:\n\t%s',
    logging.info('Generating op to save variables:\n\t%s',
                 '\n\t'.join([x.name for x in tf.global_variables()]))
    self.saver = tf.train.Saver(
        var_list=[x for x in tf.global_variables()],
...
...
research/syntaxnet/dragnn/python/graph_builder_test.py
View file @
80178fc6
...
...
@@ -20,7 +20,6 @@ import os.path
import numpy as np
from six.moves import xrange
import tensorflow as tf

from google.protobuf import text_format
...
...
@@ -30,13 +29,12 @@ from dragnn.protos import trace_pb2
from dragnn.python import dragnn_ops
from dragnn.python import graph_builder
from syntaxnet import sentence_pb2
from syntaxnet import test_flags
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
from tensorflow.python.platform import tf_logging as logging

FLAGS = tf.app.flags.FLAGS

_DUMMY_GOLD_SENTENCE = """
token {
...
...
@@ -151,13 +149,6 @@ token {
]


def setUpModule():
  if not hasattr(FLAGS, 'test_srcdir'):
    FLAGS.test_srcdir = ''
  if not hasattr(FLAGS, 'test_tmpdir'):
    FLAGS.test_tmpdir = tf.test.get_temp_dir()


def _as_op(x):
  """Always returns the tf.Operation associated with a node."""
  return x.op if isinstance(x, tf.Tensor) else x
...
@@ -244,7 +235,7 @@ class GraphBuilderTest(test_util.TensorFlowTestCase):
  def LoadSpec(self, spec_path):
    master_spec = spec_pb2.MasterSpec()
    testdata = os.path.join(FLAGS.test_srcdir,
    testdata = os.path.join(test_flags.source_root(),
                            'dragnn/core/testdata')
    with open(os.path.join(testdata, spec_path), 'r') as fin:
      text_format.Parse(fin.read().replace('TESTDATA', testdata), master_spec)
...
...
@@ -445,7 +436,7 @@ class GraphBuilderTest(test_util.TensorFlowTestCase):
    self.assertEqual(expected_num_actions, correct_val)
    self.assertEqual(expected_num_actions, total_val)

    builder.saver.save(sess, os.path.join(FLAGS.test_tmpdir, 'model'))
    builder.saver.save(sess, os.path.join(test_flags.temp_dir(), 'model'))

    logging.info('Running test.')
    logging.info('Printing annotations')
...
...
research/syntaxnet/dragnn/python/lexicon_test.py
View file @
80178fc6
...
...
@@ -27,8 +27,7 @@ from dragnn.python import lexicon
from syntaxnet import parser_trainer
from syntaxnet import task_spec_pb2

FLAGS = tf.app.flags.FLAGS
from syntaxnet import test_flags

_EXPECTED_CONTEXT = r"""
...
...
@@ -46,13 +45,6 @@ input { name: "known-word-map" Part { file_pattern: "/tmp/known-word-map" } }
"""
def
setUpModule
():
if
not
hasattr
(
FLAGS
,
'test_srcdir'
):
FLAGS
.
test_srcdir
=
''
if
not
hasattr
(
FLAGS
,
'test_tmpdir'
):
FLAGS
.
test_tmpdir
=
tf
.
test
.
get_temp_dir
()
class
LexiconTest
(
tf
.
test
.
TestCase
):
def
testCreateLexiconContext
(
self
):
...
...
@@ -62,8 +54,8 @@ class LexiconTest(tf.test.TestCase):
        lexicon.create_lexicon_context('/tmp'), expected_context)

  def testBuildLexicon(self):
    empty_input_path = os.path.join(FLAGS.test_tmpdir, 'empty-input')
    lexicon_output_path = os.path.join(FLAGS.test_tmpdir, 'lexicon-output')
    empty_input_path = os.path.join(test_flags.temp_dir(), 'empty-input')
    lexicon_output_path = os.path.join(test_flags.temp_dir(), 'lexicon-output')
    with open(empty_input_path, 'w'):
      pass
...
...
research/syntaxnet/dragnn/python/load_mst_cc_impl.py
0 → 100644
View file @
80178fc6
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loads mst_ops shared library."""
import os.path

import tensorflow as tf

tf.load_op_library(
    os.path.join(tf.resource_loader.get_data_files_path(), 'mst_cc_impl.so'))
Prev
1
2
3
4
5
6
7
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment