Unverified commit 80178fc6 authored by Mark Omernick, committed by GitHub

Merge pull request #4153 from terryykoo/master

Export @195097388.
parents a84e1ef9 edea2b67
@@ -16,7 +16,7 @@ message MasterPerformanceSettings {
// Maximum size of the free list in the SessionStatePool. NB: The default
// value may occasionally change.
- optional uint64 session_state_pool_max_free_states = 1 [default = 4];
+ optional uint64 session_state_pool_max_free_states = 1 [default = 16];
}

// As above, but for component-specific performance tuning settings.
...
- // DRAGNN Configuration proto. See go/dragnn-design for more information.
+ // DRAGNN Configuration proto.
syntax = "proto2";
@@ -93,7 +94,7 @@ message Part {
// are extracted, embedded, and then concatenated together as a group.

// Specification for a feature channel that is a *fixed* function of the input.
- // NEXT_ID: 10
+ // NEXT_ID: 12
message FixedFeatureChannel {
// Interpretable name for this feature channel. NN builders might depend on
// this to determine how to hook different channels up internally.

@@ -129,6 +130,19 @@ message FixedFeatureChannel {
// Vocab file, containing all vocabulary words one per line.
optional Resource vocab = 8;
// Settings for feature ID dropout:
// If non-negative, enables feature ID dropout, and dropped feature IDs will
// be replaced with this ID.
optional int64 dropout_id = 10 [default = -1];
// Probability of keeping each of the |vocabulary_size| feature IDs. Only
// used if |dropout_id| is non-negative, and must not be empty in that case.
// If this has fewer than |vocabulary_size| values, then the final value is
// tiled onto the remaining IDs. For example, specifying a single value is
// equivalent to setting all IDs to that value.
repeated float dropout_keep_probability = 11 [packed = true];
}
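The two new fields above only configure feature ID dropout; the actual replacement happens in the Python feature-fetching path (see the apply_feature_id_dropout call added in bulk_component.py further down in this diff). As a rough illustration of the documented semantics only, and not the DRAGNN implementation itself, dropping IDs with a tiled keep-probability vector could look like the following sketch (all names here are hypothetical):

    import tensorflow as tf

    def apply_feature_id_dropout_sketch(ids, dropout_id, keep_probabilities,
                                        vocabulary_size):
      """Illustrative only: replaces dropped feature IDs with |dropout_id|."""
      # Tile the final keep probability onto any IDs beyond the given values.
      num_given = len(keep_probabilities)
      keep = tf.constant(
          list(keep_probabilities) +
          [keep_probabilities[-1]] * (vocabulary_size - num_given),
          dtype=tf.float32)
      # Draw one Bernoulli sample per input ID, using that ID's keep probability.
      keep_prob_per_id = tf.gather(keep, ids)
      kept = tf.random_uniform(tf.shape(keep_prob_per_id)) < keep_prob_per_id
      # Dropped IDs are replaced with |dropout_id| rather than removed.
      return tf.where(kept, ids,
                      tf.fill(tf.shape(ids), tf.cast(dropout_id, ids.dtype)))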
// Specification for a feature channel that *links* to component

@@ -173,11 +187,17 @@ message TrainingGridSpec {
}

// A hyperparameter configuration for a training run.
- // NEXT ID: 22
+ // NEXT ID: 23
message GridPoint {
// Global learning rate initialization point.
optional double learning_rate = 1 [default = 0.1];
// Whether to use PBT (population-based training) to optimize the learning
// rate. Population-based training is not currently open-source, so this will
// just create a tf.assign op which external frameworks can use to adjust the
// learning rate.
optional bool pbt_optimize_learning_rate = 22 [default = false];
// Momentum coefficient when using MomentumOptimizer.
optional double momentum = 2 [default = 0.9];
...
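A note on pbt_optimize_learning_rate: since population-based training itself is not released, the flag only promises an assignable learning rate. A minimal sketch of the kind of tf.assign hook an external tuner could drive (variable and op names below are illustrative, not the actual DRAGNN graph):

    import tensorflow as tf

    # The learning rate lives in a variable so it can be changed after graph
    # construction, e.g. by a population-based training controller.
    learning_rate = tf.get_variable(
        'learning_rate', shape=[], dtype=tf.float32,
        initializer=tf.constant_initializer(0.1), trainable=False)

    # External frameworks feed a new value and run this op to adjust it.
    new_learning_rate = tf.placeholder(tf.float32, shape=[], name='new_learning_rate')
    assign_learning_rate = tf.assign(learning_rate, new_learning_rate)

    # Usage: sess.run(assign_learning_rate, feed_dict={new_learning_rate: 0.05})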
@@ -53,6 +53,8 @@ message ComponentStepTrace {
// Set to true once the step is finished. (This allows us to open a step after
// each transition, without having to know if it will be used.)
optional bool step_finished = 6 [default = false];

extensions 1000 to max;
}

// The traces for all steps for a single Component.
...

@@ -16,6 +16,17 @@ cc_binary(
],
)
cc_binary(
name = "mst_cc_impl.so",
linkopts = select({
"//conditions:default": ["-lm"],
"@org_tensorflow//tensorflow:darwin": [],
}),
linkshared = 1,
linkstatic = 1,
deps = ["//dragnn/mst:mst_ops_cc"],
)
filegroup(
name = "testdata",
data = glob(["testdata/**"]),

@@ -27,6 +38,12 @@ py_library(
data = [":dragnn_cc_impl.so"],
)
py_library(
name = "load_mst_cc_impl_py",
srcs = ["load_mst_cc_impl.py"],
data = [":mst_cc_impl.so"],
)
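The load_mst_cc_impl.py source itself is not shown in this diff; by convention the load_*_cc_impl modules in this package simply load the corresponding shared object, presumably along these lines (the exact path handling is an assumption):

    import os.path
    import tensorflow as tf

    # Load the custom MST kernels so the Python op wrappers can resolve them.
    tf.load_op_library(
        os.path.join(tf.resource_loader.get_data_files_path(), 'mst_cc_impl.so'))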
py_library(
name = "bulk_component",
srcs = [

@@ -50,6 +67,8 @@ py_library(
":bulk_component",
":dragnn_ops",
":network_units",
":runtime_support",
"//dragnn/protos:export_pb2_py",
"//syntaxnet/util:check",
"//syntaxnet/util:pyregistry",
"@org_tensorflow//tensorflow:tensorflow_py",

@@ -85,9 +104,9 @@ py_library(
":graph_builder",
":load_dragnn_cc_impl_py",
":network_units",
- "//dragnn/protos:spec_py_pb2",
+ "//dragnn/protos:spec_pb2_py",
"//syntaxnet:load_parser_ops_py",
- "//syntaxnet:sentence_py_pb2",
+ "//syntaxnet:sentence_pb2_py",
"@org_tensorflow//tensorflow:tensorflow_py",
"@org_tensorflow//tensorflow/core:protos_all_py",
],
@@ -99,7 +118,9 @@ py_test(
data = [":testdata"],
deps = [
":dragnn_model_saver_lib",
- "//dragnn/protos:spec_py_pb2",
+ "//dragnn/protos:export_pb2_py",
"//dragnn/protos:spec_pb2_py",
"//syntaxnet:test_flags",
"@org_tensorflow//tensorflow:tensorflow_py",
],
)
@@ -110,7 +131,9 @@ py_binary(
deps = [
":dragnn_model_saver_lib",
":spec_builder",
- "//dragnn/protos:spec_py_pb2",
+ "//dragnn/protos:spec_pb2_py",
"@absl_py//absl:app",
"@absl_py//absl/flags",
"@org_tensorflow//tensorflow:tensorflow_py",
"@org_tensorflow//tensorflow/core:protos_all_py",
],
@@ -127,7 +150,7 @@ py_library(
":network_units",
":transformer_units",
":wrapped_units",
- "//dragnn/protos:spec_py_pb2",
+ "//dragnn/protos:spec_pb2_py",
"//syntaxnet/util:check",
"@org_tensorflow//tensorflow:tensorflow_py",
"@org_tensorflow//tensorflow/core:protos_all_py",
@@ -159,7 +182,7 @@ py_test(
srcs = ["render_parse_tree_graphviz_test.py"],
deps = [
":render_parse_tree_graphviz",
- "//syntaxnet:sentence_py_pb2",
+ "//syntaxnet:sentence_pb2_py",
"@org_tensorflow//tensorflow:tensorflow_py",
],
)
@@ -168,7 +191,7 @@ py_library(
name = "render_spec_with_graphviz",
srcs = ["render_spec_with_graphviz.py"],
deps = [
- "//dragnn/protos:spec_py_pb2",
+ "//dragnn/protos:spec_pb2_py",
],
)
@@ -197,7 +220,7 @@ py_binary(
"//dragnn/viz:viz-min-js-gz",
],
deps = [
- "//dragnn/protos:trace_py_pb2",
+ "//dragnn/protos:trace_pb2_py",
],
)
@@ -206,8 +229,8 @@ py_test(
srcs = ["visualization_test.py"],
deps = [
":visualization",
- "//dragnn/protos:spec_py_pb2",
- "//dragnn/protos:trace_py_pb2",
+ "//dragnn/protos:spec_pb2_py",
+ "//dragnn/protos:trace_pb2_py",
"@org_tensorflow//tensorflow:tensorflow_py",
],
)
@@ -225,6 +248,18 @@ py_library(

# Tests
py_test(
name = "component_test",
srcs = [
"component_test.py",
],
deps = [
":components",
"//dragnn/protos:spec_pb2_py",
"@org_tensorflow//tensorflow:tensorflow_py",
],
)
py_test(
name = "bulk_component_test",
srcs = [

@@ -235,9 +270,9 @@ py_test(
":components",
":dragnn_ops",
":network_units",
- "//dragnn/protos:spec_py_pb2",
+ "//dragnn/protos:spec_pb2_py",
"//syntaxnet:load_parser_ops_py",
- "//syntaxnet:sentence_py_pb2",
+ "//syntaxnet:sentence_pb2_py",
"@org_tensorflow//tensorflow:tensorflow_py",
"@org_tensorflow//tensorflow/core:protos_all_py",
],
@@ -270,10 +305,11 @@ py_test(
deps = [
":dragnn_ops",
":graph_builder",
- "//dragnn/protos:spec_py_pb2",
- "//dragnn/protos:trace_py_pb2",
+ "//dragnn/protos:spec_pb2_py",
+ "//dragnn/protos:trace_pb2_py",
"//syntaxnet:load_parser_ops_py",
- "//syntaxnet:sentence_py_pb2",
+ "//syntaxnet:sentence_pb2_py",
"//syntaxnet:test_flags",
"@org_tensorflow//tensorflow:tensorflow_py",
"@org_tensorflow//tensorflow/core:protos_all_py",
],
@@ -287,7 +323,7 @@ py_test(
":network_units",
"//dragnn/core:dragnn_bulk_ops",
"//dragnn/core:dragnn_ops",
- "//dragnn/protos:spec_py_pb2",
+ "//dragnn/protos:spec_pb2_py",
"//syntaxnet:load_parser_ops_py",
"@org_tensorflow//tensorflow:tensorflow_py",
"@org_tensorflow//tensorflow/core:protos_all_py",
@@ -303,7 +339,8 @@ py_test(
":sentence_io",
"//syntaxnet:load_parser_ops_py",
"//syntaxnet:parser_ops",
- "//syntaxnet:sentence_py_pb2",
+ "//syntaxnet:sentence_pb2_py",
"//syntaxnet:test_flags",
"@org_tensorflow//tensorflow:tensorflow_py",
"@org_tensorflow//tensorflow/core:protos_all_py",
],
@@ -313,21 +350,31 @@ py_library(
name = "trainer_lib",
srcs = ["trainer_lib.py"],
deps = [
- "//dragnn/protos:spec_py_pb2",
+ "//dragnn/protos:spec_pb2_py",
"//syntaxnet:parser_ops",
- "//syntaxnet:sentence_py_pb2",
- "//syntaxnet:task_spec_py_pb2",
+ "//syntaxnet:sentence_pb2_py",
+ "//syntaxnet:task_spec_pb2_py",
"//syntaxnet/util:check",
"@org_tensorflow//tensorflow:tensorflow_py",
"@org_tensorflow//tensorflow/core:protos_all_py",
],
)
py_test(
name = "trainer_lib_test",
srcs = ["trainer_lib_test.py"],
deps = [
":trainer_lib",
"@org_tensorflow//tensorflow:tensorflow_py",
],
)
py_library(
name = "lexicon",
srcs = ["lexicon.py"],
deps = [
"//syntaxnet:parser_ops",
- "//syntaxnet:task_spec_py_pb2",
+ "//syntaxnet:task_spec_pb2_py",
"@org_tensorflow//tensorflow:tensorflow_py",
],
)
@@ -340,6 +387,7 @@ py_test(
"//syntaxnet:load_parser_ops_py",
"//syntaxnet:parser_ops",
"//syntaxnet:parser_trainer",
"//syntaxnet:test_flags",
"@org_tensorflow//tensorflow:tensorflow_py",
],
)
@@ -348,7 +396,7 @@ py_library(
name = "evaluation",
srcs = ["evaluation.py"],
deps = [
- "//syntaxnet:sentence_py_pb2",
+ "//syntaxnet:sentence_pb2_py",
"//syntaxnet/util:check",
"@org_tensorflow//tensorflow:tensorflow_py",
],
@@ -359,7 +407,7 @@ py_test(
srcs = ["evaluation_test.py"],
deps = [
":evaluation",
- "//syntaxnet:sentence_py_pb2",
+ "//syntaxnet:sentence_pb2_py",
"@org_tensorflow//tensorflow:tensorflow_py",
],
)
@@ -369,7 +417,7 @@ py_library(
srcs = ["spec_builder.py"],
deps = [
":lexicon",
- "//dragnn/protos:spec_py_pb2",
+ "//dragnn/protos:spec_pb2_py",
"//syntaxnet:parser_ops",
"//syntaxnet/util:check",
"@org_tensorflow//tensorflow:tensorflow_py",
@@ -381,7 +429,7 @@ py_test(
srcs = ["spec_builder_test.py"],
deps = [
":spec_builder",
- "//dragnn/protos:spec_py_pb2",
+ "//dragnn/protos:spec_pb2_py",
"//syntaxnet:load_parser_ops_py",
"//syntaxnet:parser_ops",
"//syntaxnet:parser_trainer",
@@ -418,6 +466,17 @@ py_library(
],
)
py_test(
name = "biaffine_units_test",
srcs = ["biaffine_units_test.py"],
deps = [
":biaffine_units",
":network_units",
"//dragnn/protos:spec_pb2_py",
"@org_tensorflow//tensorflow:tensorflow_py",
],
)
py_library(
name = "transformer_units",
srcs = ["transformer_units.py"],

@@ -437,10 +496,85 @@ py_test(
":transformer_units",
"//dragnn/core:dragnn_bulk_ops",
"//dragnn/core:dragnn_ops",
- "//dragnn/protos:spec_py_pb2",
+ "//dragnn/protos:spec_pb2_py",
"//syntaxnet:load_parser_ops_py",
"@org_tensorflow//tensorflow:tensorflow_py",
"@org_tensorflow//tensorflow/core:protos_all_py",
],
)
py_library(
name = "runtime_support",
srcs = ["runtime_support.py"],
deps = [
":network_units",
"//syntaxnet/util:check",
"@org_tensorflow//tensorflow:tensorflow_py",
],
)
py_test(
name = "runtime_support_test",
srcs = ["runtime_support_test.py"],
deps = [
":network_units",
":runtime_support",
"//dragnn/protos:export_pb2_py",
"//dragnn/protos:spec_pb2_py",
"@org_tensorflow//tensorflow:tensorflow_py",
],
)
py_library(
name = "file_diff_test",
srcs = ["file_diff_test.py"],
deps = [
"@absl_py//absl/flags",
"@org_tensorflow//tensorflow:tensorflow_py",
],
)
py_library(
name = "mst_ops",
srcs = ["mst_ops.py"],
visibility = ["//visibility:public"],
deps = [
":digraph_ops",
":load_mst_cc_impl_py",
"//dragnn/mst:mst_ops",
"//syntaxnet/util:check",
"@org_tensorflow//tensorflow:tensorflow_py",
],
)
py_test(
name = "mst_ops_test",
srcs = ["mst_ops_test.py"],
deps = [
":mst_ops",
"@org_tensorflow//tensorflow:tensorflow_py",
],
)
py_library(
name = "mst_units",
srcs = ["mst_units.py"],
deps = [
":mst_ops",
":network_units",
"//syntaxnet/util:check",
"@org_tensorflow//tensorflow:tensorflow_py",
],
)
py_test(
name = "mst_units_test",
size = "small",
srcs = ["mst_units_test.py"],
deps = [
":mst_units",
":network_units",
"//dragnn/protos:spec_pb2_py",
"@org_tensorflow//tensorflow:tensorflow_py",
],
)
@@ -79,24 +79,44 @@ class BiaffineDigraphNetwork(network_units.NetworkUnitInterface):
    self._source_dim = self._linked_feature_dims['sources']
    self._target_dim = self._linked_feature_dims['targets']

-   self._weights = []
-   self._weights.append(tf.get_variable(
-       'weights_arc', [self._source_dim, self._target_dim], tf.float32,
-       tf.random_normal_initializer(stddev=1e-4)))
-   self._weights.append(tf.get_variable(
-       'weights_source', [self._source_dim], tf.float32,
-       tf.random_normal_initializer(stddev=1e-4)))
-   self._weights.append(tf.get_variable(
-       'root', [self._source_dim], tf.float32,
-       tf.random_normal_initializer(stddev=1e-4)))
+   # TODO(googleuser): Make parameter initialization configurable.
+   self._weights = []
+   self._weights.append(
+       tf.get_variable('weights_arc', [self._source_dim, self._target_dim],
+                       tf.float32, tf.orthogonal_initializer()))
+   self._weights.append(
+       tf.get_variable('weights_source', [self._source_dim], tf.float32,
+                       tf.zeros_initializer()))
+   self._weights.append(
+       tf.get_variable('root', [self._source_dim], tf.float32,
+                       tf.zeros_initializer()))

    self._params.extend(self._weights)
    self._regularized_weights.extend(self._weights)

    # Add runtime hooks for pre-computed weights.
    self._derived_params.append(self._get_root_weights)
    self._derived_params.append(self._get_root_bias)

    # Negative Layer.dim indicates that the dimension is dynamic.
    self._layers.append(network_units.Layer(component, 'adjacency', -1))
def _get_root_weights(self):
"""Pre-computes the product of the root embedding and arc weights."""
weights_arc = self._component.get_variable('weights_arc')
root = self._component.get_variable('root')
name = self._component.name + '/root_weights'
with tf.name_scope(None):
return tf.matmul(tf.expand_dims(root, 0), weights_arc, name=name)
def _get_root_bias(self):
"""Pre-computes the product of the root embedding and source weights."""
weights_source = self._component.get_variable('weights_source')
root = self._component.get_variable('root')
name = self._component.name + '/root_bias'
with tf.name_scope(None):
return tf.matmul(
tf.expand_dims(root, 0), tf.expand_dims(weights_source, 1), name=name)
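For intuition, the two derived parameters above fold the root embedding into the arc scorer so the runtime can treat the root as a precomputed row of weights plus a bias. With r the root embedding, W the weights_arc matrix, and w_s the weights_source vector:

    root_weights = r^T W      (shape [1, target_dim])
    root_bias    = r^T w_s    (shape [1, 1])

so a token's root-attachment potential is root_weights · target_t + root_bias; these are the shapes asserted in the new biaffine_units_test.py below.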
  def create(self,
             fixed_embeddings,
             linked_embeddings,

@@ -133,12 +153,17 @@ class BiaffineDigraphNetwork(network_units.NetworkUnitInterface):
    sources_bxnxn = digraph_ops.ArcSourcePotentialsFromTokens(
        source_tokens_bxnxs, weights_source)
    roots_bxn = digraph_ops.RootPotentialsFromTokens(
-       root, target_tokens_bxnxt, weights_arc)
+       root, target_tokens_bxnxt, weights_arc, weights_source)

    # Combine them into a single matrix with the roots on the diagonal.
    adjacency_bxnxn = digraph_ops.CombineArcAndRootPotentials(
        arcs_bxnxn + sources_bxnxn, roots_bxn)

    # The adjacency matrix currently has sources on rows and targets on columns,
    # but we want targets on rows so that maximizing within a row corresponds to
    # selecting sources for a given target.
    adjacency_bxnxn = tf.matrix_transpose(adjacency_bxnxn)

    return [tf.reshape(adjacency_bxnxn, [-1, num_tokens])]
...
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for biaffine_units."""
import tensorflow as tf
from google.protobuf import text_format
from dragnn.protos import spec_pb2
from dragnn.python import biaffine_units
from dragnn.python import network_units
_BATCH_SIZE = 11
_NUM_TOKENS = 22
_TOKEN_DIM = 33
class MockNetwork(object):
def __init__(self):
pass
def get_layer_size(self, unused_name):
return _TOKEN_DIM
class MockComponent(object):
def __init__(self, master, component_spec):
self.master = master
self.spec = component_spec
self.name = component_spec.name
self.network = MockNetwork()
self.beam_size = 1
self.num_actions = 45
self._attrs = {}
def attr(self, name):
return self._attrs[name]
def get_variable(self, name):
return tf.get_variable(name)
class MockMaster(object):
def __init__(self):
self.spec = spec_pb2.MasterSpec()
self.hyperparams = spec_pb2.GridPoint()
self.lookup_component = {
'previous': MockComponent(self, spec_pb2.ComponentSpec())
}
def _make_biaffine_spec():
"""Returns a ComponentSpec that the BiaffineDigraphNetwork works on."""
component_spec = spec_pb2.ComponentSpec()
text_format.Parse("""
name: "test_component"
backend { registered_name: "TestComponent" }
linked_feature {
name: "sources"
fml: "input.focus"
source_translator: "identity"
source_component: "previous"
source_layer: "sources"
size: 1
embedding_dim: -1
}
linked_feature {
name: "targets"
fml: "input.focus"
source_translator: "identity"
source_component: "previous"
source_layer: "targets"
size: 1
embedding_dim: -1
}
network_unit {
registered_name: "biaffine_units.BiaffineDigraphNetwork"
}
""", component_spec)
return component_spec
class BiaffineDigraphNetworkTest(tf.test.TestCase):
def setUp(self):
# Clear the graph and all existing variables. Otherwise, variables created
# in different tests may collide with each other.
tf.reset_default_graph()
def testCanCreate(self):
"""Tests that create() works on a good spec."""
with tf.Graph().as_default(), self.test_session():
master = MockMaster()
component = MockComponent(master, _make_biaffine_spec())
with tf.variable_scope(component.name, reuse=None):
component.network = biaffine_units.BiaffineDigraphNetwork(component)
with tf.variable_scope(component.name, reuse=True):
sources = network_units.NamedTensor(
tf.zeros([_BATCH_SIZE * _NUM_TOKENS, _TOKEN_DIM]), 'sources')
targets = network_units.NamedTensor(
tf.zeros([_BATCH_SIZE * _NUM_TOKENS, _TOKEN_DIM]), 'targets')
# No assertions on the result, just don't crash.
component.network.create(
fixed_embeddings=[],
linked_embeddings=[sources, targets],
context_tensor_arrays=None,
attention_tensor=None,
during_training=True,
stride=_BATCH_SIZE)
def testDerivedParametersForRuntime(self):
"""Test generation of derived parameters for the runtime."""
with tf.Graph().as_default(), self.test_session():
master = MockMaster()
component = MockComponent(master, _make_biaffine_spec())
with tf.variable_scope(component.name, reuse=None):
component.network = biaffine_units.BiaffineDigraphNetwork(component)
with tf.variable_scope(component.name, reuse=True):
self.assertEqual(len(component.network.derived_params), 2)
root_weights = component.network.derived_params[0]()
root_bias = component.network.derived_params[1]()
# Only check shape; values are random due to initialization.
self.assertAllEqual(root_weights.shape.as_list(), [1, _TOKEN_DIM])
self.assertAllEqual(root_bias.shape.as_list(), [1, 1])
if __name__ == '__main__':
tf.test.main()
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Component builders for non-recurrent networks in DRAGNN.""" """Component builders for non-recurrent networks in DRAGNN."""
...@@ -51,10 +50,8 @@ def fetch_linked_embedding(comp, network_states, feature_spec): ...@@ -51,10 +50,8 @@ def fetch_linked_embedding(comp, network_states, feature_spec):
feature_spec.name) feature_spec.name)
source = comp.master.lookup_component[feature_spec.source_component] source = comp.master.lookup_component[feature_spec.source_component]
return network_units.NamedTensor( return network_units.NamedTensor(network_states[source.name].activations[
network_states[source.name].activations[ feature_spec.source_layer].bulk_tensor, feature_spec.name)
feature_spec.source_layer].bulk_tensor,
feature_spec.name)
def _validate_embedded_fixed_features(comp): def _validate_embedded_fixed_features(comp):
...@@ -63,17 +60,20 @@ def _validate_embedded_fixed_features(comp): ...@@ -63,17 +60,20 @@ def _validate_embedded_fixed_features(comp):
check.Gt(feature.embedding_dim, 0, check.Gt(feature.embedding_dim, 0,
'Embeddings requested for non-embedded feature: %s' % feature) 'Embeddings requested for non-embedded feature: %s' % feature)
if feature.is_constant: if feature.is_constant:
check.IsTrue(feature.HasField('pretrained_embedding_matrix'), check.IsTrue(
'Constant embeddings must be pretrained: %s' % feature) feature.HasField('pretrained_embedding_matrix'),
'Constant embeddings must be pretrained: %s' % feature)
- def fetch_differentiable_fixed_embeddings(comp, state, stride):
+ def fetch_differentiable_fixed_embeddings(comp, state, stride, during_training):
  """Looks up fixed features with separate, differentiable, embedding lookup.

  Args:
    comp: Component whose fixed features we wish to look up.
    state: live MasterState object for the component.
    stride: Tensor containing current batch * beam size.
    during_training: True if this is being called from a training code path.
      This controls, e.g., the use of feature ID dropout.

  Returns:
    state handle: updated state handle to be used after this call

@@ -93,6 +93,11 @@ def fetch_differentiable_fixed_embeddings(comp, state, stride):
                                   'differentiable')
    tf.logging.info('[%s] Adding %s fixed feature "%s"', comp.name,
                    differentiable_or_constant, feature_spec.name)

    if during_training and feature_spec.dropout_id >= 0:
      ids[channel], weights[channel] = network_units.apply_feature_id_dropout(
          ids[channel], weights[channel], feature_spec)

    size = stride * num_steps * feature_spec.size
    fixed_embedding = network_units.embedding_lookup(
        comp.get_variable(network_units.fixed_embeddings_name(channel)),

@@ -105,16 +110,22 @@ def fetch_differentiable_fixed_embeddings(comp, state, stride):
  return state.handle, fixed_embeddings
- def fetch_fast_fixed_embeddings(comp, state):
+ def fetch_fast_fixed_embeddings(comp,
+                                 state,
+                                 pad_to_batch=None,
+                                 pad_to_steps=None):
  """Looks up fixed features with fast, non-differentiable, op.

  Since BulkFixedEmbeddings is non-differentiable with respect to the
  embeddings, the idea is to call this function only when the graph is
-  not being used for training.
+  not being used for training. If the function is being called with fixed step
+  and batch sizes, it will use the most efficient possible extractor.

  Args:
    comp: Component whose fixed features we wish to look up.
    state: live MasterState object for the component.
    pad_to_batch: Optional; the number of batch elements to pad to.
    pad_to_steps: Optional; the number of steps to pad to.

  Returns:
    state handle: updated state handle to be used after this call

@@ -126,19 +137,50 @@ def fetch_fast_fixed_embeddings(comp, state):
    return state.handle, []
  tf.logging.info('[%s] Adding %d fast fixed features', comp.name, num_channels)

-  state.handle, bulk_embeddings, _ = dragnn_ops.bulk_fixed_embeddings(
-      state.handle, [
-          comp.get_variable(network_units.fixed_embeddings_name(c))
-          for c in range(num_channels)
-      ],
-      component=comp.name)
-  bulk_embeddings = network_units.NamedTensor(bulk_embeddings,
-                                              'bulk-%s-fixed-features' %
-                                              comp.name)
+  features = [
+      comp.get_variable(network_units.fixed_embeddings_name(c))
+      for c in range(num_channels)
+  ]
+
+  if pad_to_batch is not None and pad_to_steps is not None:
+    # If we have fixed padding numbers, we can use 'bulk_embed_fixed_features',
+    # which is the fastest embedding extractor.
+    state.handle, bulk_embeddings, _ = dragnn_ops.bulk_embed_fixed_features(
+        state.handle,
+        features,
+        component=comp.name,
+        pad_to_batch=pad_to_batch,
+        pad_to_steps=pad_to_steps)
+  else:
+    state.handle, bulk_embeddings, _ = dragnn_ops.bulk_fixed_embeddings(
+        state.handle, features, component=comp.name)
+
+  bulk_embeddings = network_units.NamedTensor(
+      bulk_embeddings, 'bulk-%s-fixed-features' % comp.name)

  return state.handle, [bulk_embeddings]
def fetch_dense_ragged_embeddings(comp, state):
"""Gets embeddings in RaggedTensor format."""
_validate_embedded_fixed_features(comp)
num_channels = len(comp.spec.fixed_feature)
if not num_channels:
return state.handle, []
tf.logging.info('[%s] Adding %d fast fixed features', comp.name, num_channels)
features = [
comp.get_variable(network_units.fixed_embeddings_name(c))
for c in range(num_channels)
]
state.handle, data, offsets = dragnn_ops.bulk_embed_dense_fixed_features(
state.handle, features, component=comp.name)
data = network_units.NamedTensor(data, 'dense-%s-data' % comp.name)
offsets = network_units.NamedTensor(offsets, 'dense-%s-offsets' % comp.name)
return state.handle, [data, offsets]
def extract_fixed_feature_ids(comp, state, stride):
  """Extracts fixed feature IDs.

@@ -194,8 +236,10 @@ def update_network_states(comp, tensors, network_states, stride):
  with tf.name_scope(comp.name + '/stored_act'):
    for index, network_tensor in enumerate(tensors):
      network_state.activations[comp.network.layers[index].name] = (
-          network_units.StoredActivations(tensor=network_tensor, stride=stride,
-                                          dim=comp.network.layers[index].dim))
+          network_units.StoredActivations(
+              tensor=network_tensor,
+              stride=stride,
+              dim=comp.network.layers[index].dim))

def build_cross_entropy_loss(logits, gold):

@@ -205,7 +249,7 @@ def build_cross_entropy_loss(logits, gold):
  Args:
    logits: float Tensor of scores.
-    gold: int Tensor of one-hot labels.
+    gold: int Tensor of gold label ids.

  Returns:
    cost, correct, total: the total cost, the total number of correctly
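For context on the reworded docstring: gold now holds one label id per step, and negative ids mark steps that should not contribute to the loss. A hedged sketch of a loss with that convention, not the actual DRAGNN code (which additionally fails when no valid examples remain, as the updated testBuildLossFailsOnNoExamples below relies on):

    import tensorflow as tf

    def cross_entropy_over_valid_gold(logits, gold):
      """Illustrative only: averages softmax cross-entropy over gold ids >= 0."""
      valid = tf.reshape(tf.where(tf.greater_equal(gold, 0)), [-1])
      valid_logits = tf.gather(logits, valid)
      valid_gold = tf.gather(gold, valid)
      cost = tf.reduce_mean(
          tf.nn.sparse_softmax_cross_entropy_with_logits(
              labels=valid_gold, logits=valid_logits))
      correct = tf.reduce_sum(
          tf.to_int32(tf.nn.in_top_k(valid_logits, valid_gold, 1)))
      total = tf.size(valid_gold)
      return cost, correct, total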
@@ -251,9 +295,10 @@ class BulkFeatureExtractorComponentBuilder(component.ComponentBuilderBase):
    """
    logging.info('Building component: %s', self.spec.name)
    stride = state.current_batch_size * self.training_beam_size
    self.network.pre_create(stride)
    with tf.variable_scope(self.name, reuse=True):
      state.handle, fixed_embeddings = fetch_differentiable_fixed_embeddings(
-          self, state, stride)
+          self, state, stride, True)
      linked_embeddings = [
          fetch_linked_embedding(self, network_states, spec)

@@ -307,14 +352,29 @@ class BulkFeatureExtractorComponentBuilder(component.ComponentBuilderBase):
      stride = state.current_batch_size * self.training_beam_size
    else:
      stride = state.current_batch_size * self.inference_beam_size
    self.network.pre_create(stride)
    with tf.variable_scope(self.name, reuse=True):
      if during_training:
        state.handle, fixed_embeddings = fetch_differentiable_fixed_embeddings(
-            self, state, stride)
+            self, state, stride, during_training)
      else:
-        state.handle, fixed_embeddings = fetch_fast_fixed_embeddings(self,
-                                                                     state)
+        if 'use_densors' in self.spec.network_unit.parameters:
+          state.handle, fixed_embeddings = fetch_dense_ragged_embeddings(
+              self, state)
+        else:
+          if ('padded_batch_size' in self.spec.network_unit.parameters and
+              'padded_sentence_length' in self.spec.network_unit.parameters):
+            state.handle, fixed_embeddings = fetch_fast_fixed_embeddings(
+                self,
+                state,
+                pad_to_batch=-1,
+                pad_to_steps=int(self.spec.network_unit.parameters[
+                    'padded_sentence_length']))
+          else:
+            state.handle, fixed_embeddings = fetch_fast_fixed_embeddings(
+                self, state)

      linked_embeddings = [
          fetch_linked_embedding(self, network_states, spec)

@@ -331,6 +391,7 @@ class BulkFeatureExtractorComponentBuilder(component.ComponentBuilderBase):
            stride=stride)
      update_network_states(self, tensors, network_states, stride)
      self._add_runtime_hooks()
      return state.handle
@@ -367,7 +428,9 @@ class BulkFeatureIdExtractorComponentBuilder(component.ComponentBuilderBase):
  def build_greedy_inference(self, state, network_states,
                             during_training=False):
    """See base class."""
-    return self._extract_feature_ids(state, network_states, during_training)
+    handle = self._extract_feature_ids(state, network_states, during_training)
+    self._add_runtime_hooks()
+    return handle

  def _extract_feature_ids(self, state, network_states, during_training):
    """Extracts feature IDs and advances a batch using the oracle path.

@@ -387,6 +450,7 @@ class BulkFeatureIdExtractorComponentBuilder(component.ComponentBuilderBase):
      stride = state.current_batch_size * self.training_beam_size
    else:
      stride = state.current_batch_size * self.inference_beam_size
    self.network.pre_create(stride)
    with tf.variable_scope(self.name, reuse=True):
      state.handle, ids = extract_fixed_feature_ids(self, state, stride)

@@ -438,17 +502,21 @@ class BulkAnnotatorComponentBuilder(component.ComponentBuilderBase):
      ]
      stride = state.current_batch_size * self.training_beam_size
      self.network.pre_create(stride)
      with tf.variable_scope(self.name, reuse=True):
        network_tensors = self.network.create([], linked_embeddings, None, None,
                                               True, stride)
        update_network_states(self, network_tensors, network_states, stride)
-       logits = self.network.get_logits(network_tensors)
        state.handle, gold = dragnn_ops.bulk_advance_from_oracle(
            state.handle, component=self.name)
-       cost, correct, total = build_cross_entropy_loss(logits, gold)
+       cost, correct, total = self.network.compute_bulk_loss(
+           stride, network_tensors, gold)
+       if cost is None:
+         # The network does not have a custom bulk loss; default to softmax.
+         logits = self.network.get_logits(network_tensors)
+         cost, correct, total = build_cross_entropy_loss(logits, gold)
        cost = self.add_regularizer(cost)
        return state.handle, cost, correct, total

@@ -483,13 +551,24 @@ class BulkAnnotatorComponentBuilder(component.ComponentBuilderBase):
      stride = state.current_batch_size * self.training_beam_size
    else:
      stride = state.current_batch_size * self.inference_beam_size
    self.network.pre_create(stride)
    with tf.variable_scope(self.name, reuse=True):
-     network_tensors = self.network.create(
-         [], linked_embeddings, None, None, during_training, stride)
+     network_tensors = self.network.create([], linked_embeddings, None, None,
+                                           during_training, stride)
      update_network_states(self, network_tensors, network_states, stride)
-     logits = self.network.get_logits(network_tensors)
-     return dragnn_ops.bulk_advance_from_prediction(
+     logits = self.network.get_bulk_predictions(stride, network_tensors)
+     if logits is None:
+       # The network does not produce custom bulk predictions; default to logits.
+       logits = self.network.get_logits(network_tensors)
+       logits = tf.cond(self.locally_normalize,
+                        lambda: tf.nn.log_softmax(logits), lambda: logits)
+       if self._output_as_probabilities:
+         logits = tf.nn.softmax(logits)
+     handle = dragnn_ops.bulk_advance_from_prediction(
          state.handle, logits, component=self.name)
+     self._add_runtime_hooks()
+     return handle
@@ -41,8 +41,6 @@ from dragnn.python import dragnn_ops
from dragnn.python import network_units
from syntaxnet import sentence_pb2

- FLAGS = tf.app.flags.FLAGS

class MockNetworkUnit(object):

@@ -63,6 +61,7 @@ class MockMaster(object):
    self.spec = spec_pb2.MasterSpec()
    self.hyperparams = spec_pb2.GridPoint()
    self.lookup_component = {'mock': MockComponent()}
    self.build_runtime_graph = False

def _create_fake_corpus():

@@ -84,9 +83,12 @@ def _create_fake_corpus():
class BulkComponentTest(test_util.TensorFlowTestCase):

  def setUp(self):
    # Clear the graph and all existing variables. Otherwise, variables created
    # in different tests may collide with each other.
    tf.reset_default_graph()
    self.master = MockMaster()
    self.master_state = component.MasterState(
-       handle='handle', current_batch_size=2)
+       handle=tf.constant(['foo', 'bar']), current_batch_size=2)
    self.network_states = {
        'mock': component.NetworkState(),
        'test': component.NetworkState(),

@@ -107,22 +109,21 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
    """, component_spec)

    # For feature extraction:
-   with tf.Graph().as_default():
    comp = bulk_component.BulkFeatureExtractorComponentBuilder(
        self.master, component_spec)

    # Expect feature extraction to generate a error due to the "history"
    # translator.
    with self.assertRaises(NotImplementedError):
      comp.build_greedy_training(self.master_state, self.network_states)

    # As well as annotation:
-   with tf.Graph().as_default():
+   self.setUp()
    comp = bulk_component.BulkAnnotatorComponentBuilder(self.master,
                                                        component_spec)
    with self.assertRaises(NotImplementedError):
      comp.build_greedy_training(self.master_state, self.network_states)
  def testFailsOnRecurrentLinkedFeature(self):
    component_spec = spec_pb2.ComponentSpec()

@@ -143,22 +144,21 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
    """, component_spec)

    # For feature extraction:
-   with tf.Graph().as_default():
    comp = bulk_component.BulkFeatureExtractorComponentBuilder(
        self.master, component_spec)

    # Expect feature extraction to generate a error due to the "history"
    # translator.
    with self.assertRaises(RuntimeError):
      comp.build_greedy_training(self.master_state, self.network_states)

    # As well as annotation:
-   with tf.Graph().as_default():
+   self.setUp()
    comp = bulk_component.BulkAnnotatorComponentBuilder(self.master,
                                                        component_spec)
    with self.assertRaises(RuntimeError):
      comp.build_greedy_training(self.master_state, self.network_states)
  def testConstantFixedFeatureFailsIfNotPretrained(self):
    component_spec = spec_pb2.ComponentSpec()

@@ -175,21 +175,20 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
        registered_name: "bulk_component.BulkFeatureExtractorComponentBuilder"
      }
    """, component_spec)
-   with tf.Graph().as_default():
    comp = bulk_component.BulkFeatureExtractorComponentBuilder(
        self.master, component_spec)
    with self.assertRaisesRegexp(ValueError,
                                 'Constant embeddings must be pretrained'):
      comp.build_greedy_training(self.master_state, self.network_states)
    with self.assertRaisesRegexp(ValueError,
                                 'Constant embeddings must be pretrained'):
      comp.build_greedy_inference(
          self.master_state, self.network_states, during_training=True)
    with self.assertRaisesRegexp(ValueError,
                                 'Constant embeddings must be pretrained'):
      comp.build_greedy_inference(
          self.master_state, self.network_states, during_training=False)
  def testNormalFixedFeaturesAreDifferentiable(self):
    component_spec = spec_pb2.ComponentSpec()

@@ -207,25 +206,24 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
        registered_name: "bulk_component.BulkFeatureExtractorComponentBuilder"
      }
    """, component_spec)
-   with tf.Graph().as_default():
    comp = bulk_component.BulkFeatureExtractorComponentBuilder(
        self.master, component_spec)

    # Get embedding matrix variables.
    with tf.variable_scope(comp.name, reuse=True):
      fixed_embedding_matrix = tf.get_variable(
          network_units.fixed_embeddings_name(0))

    # Get output layer.
    comp.build_greedy_training(self.master_state, self.network_states)
    activations = self.network_states[comp.name].activations
    outputs = activations[comp.network.layers[0].name].bulk_tensor

    # Compute the gradient of the output layer w.r.t. the embedding matrix.
    # This should be well-defined for in the normal case.
    gradients = tf.gradients(outputs, fixed_embedding_matrix)
    self.assertEqual(len(gradients), 1)
    self.assertFalse(gradients[0] is None)
  def testConstantFixedFeaturesAreNotDifferentiableButOthersAre(self):
    component_spec = spec_pb2.ComponentSpec()

@@ -249,31 +247,30 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
        registered_name: "bulk_component.BulkFeatureExtractorComponentBuilder"
      }
    """, component_spec)
-   with tf.Graph().as_default():
    comp = bulk_component.BulkFeatureExtractorComponentBuilder(
        self.master, component_spec)

    # Get embedding matrix variables.
    with tf.variable_scope(comp.name, reuse=True):
      constant_embedding_matrix = tf.get_variable(
          network_units.fixed_embeddings_name(0))
      trainable_embedding_matrix = tf.get_variable(
          network_units.fixed_embeddings_name(1))

    # Get output layer.
    comp.build_greedy_training(self.master_state, self.network_states)
    activations = self.network_states[comp.name].activations
    outputs = activations[comp.network.layers[0].name].bulk_tensor

    # The constant embeddings are non-differentiable.
    constant_gradients = tf.gradients(outputs, constant_embedding_matrix)
    self.assertEqual(len(constant_gradients), 1)
    self.assertTrue(constant_gradients[0] is None)

    # The trainable embeddings are differentiable.
    trainable_gradients = tf.gradients(outputs, trainable_embedding_matrix)
    self.assertEqual(len(trainable_gradients), 1)
    self.assertFalse(trainable_gradients[0] is None)
  def testFailsOnFixedFeature(self):
    component_spec = spec_pb2.ComponentSpec()

@@ -306,15 +303,14 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
        name: "fixed" embedding_dim: -1 size: 1
      }
    """, component_spec)
-   with tf.Graph().as_default():
    comp = bulk_component.BulkFeatureIdExtractorComponentBuilder(
        self.master, component_spec)
    # Should not raise errors.
    self.network_states[component_spec.name] = component.NetworkState()
    comp.build_greedy_training(self.master_state, self.network_states)
    self.network_states[component_spec.name] = component.NetworkState()
    comp.build_greedy_inference(self.master_state, self.network_states)
  def testBulkFeatureIdExtractorFailsOnLinkedFeature(self):
    component_spec = spec_pb2.ComponentSpec()

@@ -332,10 +328,9 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
        source_component: "mock"
      }
    """, component_spec)
-   with tf.Graph().as_default():
    with self.assertRaises(ValueError):
      unused_comp = bulk_component.BulkFeatureIdExtractorComponentBuilder(
          self.master, component_spec)
  def testBulkFeatureIdExtractorOkWithMultipleFixedFeatures(self):
    component_spec = spec_pb2.ComponentSpec()

@@ -354,15 +349,14 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
        name: "fixed3" embedding_dim: -1 size: 1
      }
    """, component_spec)
-   with tf.Graph().as_default():
    comp = bulk_component.BulkFeatureIdExtractorComponentBuilder(
        self.master, component_spec)
    # Should not raise errors.
    self.network_states[component_spec.name] = component.NetworkState()
    comp.build_greedy_training(self.master_state, self.network_states)
    self.network_states[component_spec.name] = component.NetworkState()
    comp.build_greedy_inference(self.master_state, self.network_states)
  def testBulkFeatureIdExtractorFailsOnEmbeddedFixedFeature(self):
    component_spec = spec_pb2.ComponentSpec()

@@ -375,10 +369,9 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
        name: "fixed" embedding_dim: 2 size: 1
      }
    """, component_spec)
-   with tf.Graph().as_default():
    with self.assertRaises(ValueError):
      unused_comp = bulk_component.BulkFeatureIdExtractorComponentBuilder(
          self.master, component_spec)
  def testBulkFeatureIdExtractorExtractFocusWithOffset(self):
    path = os.path.join(tf.test.get_temp_dir(), 'label-map')

@@ -420,67 +413,131 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
      }
    """ % path, master_spec)

-   with tf.Graph().as_default():
    corpus = _create_fake_corpus()
    corpus = tf.constant(corpus, shape=[len(corpus)])
    handle = dragnn_ops.get_session(
        container='test',
        master_spec=master_spec.SerializeToString(),
        grid_point='')
    handle = dragnn_ops.attach_data_reader(handle, corpus)
    handle = dragnn_ops.init_component_data(
        handle, beam_size=1, component='test')
    batch_size = dragnn_ops.batch_size(handle, component='test')
    master_state = component.MasterState(handle, batch_size)

    extractor = bulk_component.BulkFeatureIdExtractorComponentBuilder(
        self.master, master_spec.component[0])
    network_state = component.NetworkState()
    self.network_states['test'] = network_state
    handle = extractor.build_greedy_inference(master_state, self.network_states)
    focus1 = network_state.activations['focus1'].bulk_tensor
    focus2 = network_state.activations['focus2'].bulk_tensor
    focus3 = network_state.activations['focus3'].bulk_tensor

    with self.test_session() as sess:
      focus1, focus2, focus3 = sess.run([focus1, focus2, focus3])
      tf.logging.info('focus1=\n%s', focus1)
      tf.logging.info('focus2=\n%s', focus2)
      tf.logging.info('focus3=\n%s', focus3)

      self.assertAllEqual(focus1,
                          [[0], [-1], [-1], [-1],
                           [0], [1], [-1], [-1],
                           [0], [1], [2], [-1],
                           [0], [1], [2], [3]])  # pyformat: disable

      self.assertAllEqual(focus2,
                          [[-1], [-1], [-1], [-1],
                           [1], [-1], [-1], [-1],
                           [1], [2], [-1], [-1],
                           [1], [2], [3], [-1]])  # pyformat: disable

      self.assertAllEqual(focus3,
                          [[-1], [-1], [-1], [-1],
                           [-1], [-1], [-1], [-1],
                           [2], [-1], [-1], [-1],
                           [2], [3], [-1], [-1]])  # pyformat: disable
def testBuildLossFailsOnNoExamples(self): def testBuildLossFailsOnNoExamples(self):
with tf.Graph().as_default(): logits = tf.constant([[0.5], [-0.5], [0.5], [-0.5]])
logits = tf.constant([[0.5], [-0.5], [0.5], [-0.5]]) gold = tf.constant([-1, -1, -1, -1])
gold = tf.constant([-1, -1, -1, -1]) result = bulk_component.build_cross_entropy_loss(logits, gold)
result = bulk_component.build_cross_entropy_loss(logits, gold)
# Expect loss computation to generate a runtime error due to the gold
# Expect loss computation to generate a runtime error due to the gold # tensor containing no valid examples.
# tensor containing no valid examples. with self.test_session() as sess:
with self.test_session() as sess: with self.assertRaises(tf.errors.InvalidArgumentError):
with self.assertRaises(tf.errors.InvalidArgumentError): sess.run(result)
sess.run(result)
def testPreCreateCalledBeforeCreate(self):
component_spec = spec_pb2.ComponentSpec()
text_format.Parse("""
name: "test"
network_unit {
registered_name: "IdentityNetwork"
}
""", component_spec)
class AssertPreCreateBeforeCreateNetwork(
network_units.NetworkUnitInterface):
"""Mock that asserts that .create() is called before .pre_create()."""
def __init__(self, comp, test_fixture):
super(AssertPreCreateBeforeCreateNetwork, self).__init__(comp)
self._test_fixture = test_fixture
self._pre_create_called = False
def get_logits(self, network_tensors):
return tf.zeros([2, 1], dtype=tf.float32)
def pre_create(self, *unused_args):
self._pre_create_called = True
def create(self, *unused_args, **unused_kwargs):
self._test_fixture.assertTrue(self._pre_create_called)
return []
builder = bulk_component.BulkFeatureExtractorComponentBuilder(
self.master, component_spec)
builder.network = AssertPreCreateBeforeCreateNetwork(builder, self)
builder.build_greedy_training(
component.MasterState(['foo', 'bar'], 2), self.network_states)
self.setUp()
builder = bulk_component.BulkFeatureExtractorComponentBuilder(
self.master, component_spec)
builder.network = AssertPreCreateBeforeCreateNetwork(builder, self)
builder.build_greedy_inference(
component.MasterState(['foo', 'bar'], 2), self.network_states)
self.setUp()
builder = bulk_component.BulkFeatureIdExtractorComponentBuilder(
self.master, component_spec)
builder.network = AssertPreCreateBeforeCreateNetwork(builder, self)
builder.build_greedy_training(
component.MasterState(['foo', 'bar'], 2), self.network_states)
self.setUp()
builder = bulk_component.BulkFeatureIdExtractorComponentBuilder(
self.master, component_spec)
builder.network = AssertPreCreateBeforeCreateNetwork(builder, self)
builder.build_greedy_inference(
component.MasterState(['foo', 'bar'], 2), self.network_states)
self.setUp()
builder = bulk_component.BulkAnnotatorComponentBuilder(
self.master, component_spec)
builder.network = AssertPreCreateBeforeCreateNetwork(builder, self)
builder.build_greedy_training(
component.MasterState(['foo', 'bar'], 2), self.network_states)
self.setUp()
builder = bulk_component.BulkAnnotatorComponentBuilder(
self.master, component_spec)
builder.network = AssertPreCreateBeforeCreateNetwork(builder, self)
builder.build_greedy_inference(
component.MasterState(['foo', 'bar'], 2), self.network_states)
if __name__ == '__main__': if __name__ == '__main__':
googletest.main() googletest.main()
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Builds a DRAGNN graph for local training.""" """Builds a DRAGNN graph for local training."""
from abc import ABCMeta from abc import ABCMeta
...@@ -21,12 +20,79 @@ from abc import abstractmethod ...@@ -21,12 +20,79 @@ from abc import abstractmethod
import tensorflow as tf import tensorflow as tf
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from dragnn.protos import export_pb2
from dragnn.python import dragnn_ops from dragnn.python import dragnn_ops
from dragnn.python import network_units from dragnn.python import network_units
from dragnn.python import runtime_support
from syntaxnet.util import check from syntaxnet.util import check
from syntaxnet.util import registry from syntaxnet.util import registry
def build_softmax_cross_entropy_loss(logits, gold):
"""Builds softmax cross entropy loss."""
# A gold label > -1 determines that the sentence is still
# in a valid state. Otherwise, the sentence has ended.
#
# We add only the valid sentences to the loss, in the following way:
# 1. We compute 'valid_ix', the indices in gold that contain
# valid oracle actions.
# 2. We compute the cost function by comparing logits and gold
# only for the valid indices.
valid = tf.greater(gold, -1)
valid_ix = tf.reshape(tf.where(valid), [-1])
valid_gold = tf.gather(gold, valid_ix)
valid_logits = tf.gather(logits, valid_ix)
cost = tf.reduce_sum(
tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=tf.cast(valid_gold, tf.int64),
logits=valid_logits,
name='sparse_softmax_cross_entropy_with_logits'))
correct = tf.reduce_sum(
tf.to_int32(tf.nn.in_top_k(valid_logits, valid_gold, 1)))
total = tf.size(valid_gold)
return cost, correct, total, valid_logits, valid_gold
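# Illustrative sketch (editorial; values hypothetical): gold labels of -1 mark
# batch padding and are excluded from the cost and the correct/total counts.
def _example_softmax_loss_masking():
  logits = tf.constant([[0.0, 2.0], [1.0, -1.0], [3.0, 0.0]])
  gold = tf.constant([1, -1, 0])  # the middle row is padding
  cost, correct, total, _, _ = build_softmax_cross_entropy_loss(logits, gold)
  with tf.Session() as sess:
    return sess.run([cost, correct, total])  # total == 2: padding is ignored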
def build_sigmoid_cross_entropy_loss(logits, gold, indices, probs):
"""Builds sigmoid cross entropy loss."""
# Filter out entries where gold <= -1, which are batch padding entries.
valid = tf.greater(gold, -1)
valid_ix = tf.reshape(tf.where(valid), [-1])
valid_gold = tf.gather(gold, valid_ix)
valid_indices = tf.gather(indices, valid_ix)
valid_probs = tf.gather(probs, valid_ix)
# NB: tf.gather_nd() raises an error on CPU for out-of-bounds indices. That's
# why we need to filter out the gold=-1 batch padding above.
valid_pairs = tf.stack([valid_indices, valid_gold], axis=1)
valid_logits = tf.gather_nd(logits, valid_pairs)
cost = tf.reduce_sum(
tf.nn.sigmoid_cross_entropy_with_logits(
labels=valid_probs,
logits=valid_logits,
name='sigmoid_cross_entropy_with_logits'))
gold_bool = valid_probs > 0.5
predicted_bool = valid_logits > 0.0
total = tf.size(gold_bool)
with tf.control_dependencies([
tf.assert_equal(
total, tf.size(predicted_bool), name='num_predicted_gold_mismatch')
]):
agreement_bool = tf.logical_not(tf.logical_xor(gold_bool, predicted_bool))
correct = tf.reduce_sum(tf.to_int32(agreement_bool))
cost.set_shape([])
correct.set_shape([])
total.set_shape([])
return cost, correct, total, gold
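# Illustrative sketch (editorial; values hypothetical): each pair
# (indices[i], gold[i]) addresses one entry of |logits| via tf.gather_nd, and
# probs[i] supplies the soft target for that entry.
def _example_sigmoid_loss_inputs():
  logits = tf.constant([[0.9, -0.3], [-0.5, 2.0]])
  indices = tf.constant([0, 1])    # batch row of each labeled entry
  gold = tf.constant([1, 0])       # label id (column) of each labeled entry
  probs = tf.constant([0.7, 0.2])  # soft target probability for each entry
  return build_sigmoid_cross_entropy_loss(logits, gold, indices, probs)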
class NetworkState(object): class NetworkState(object):
"""Simple utility to manage the state of a DRAGNN network. """Simple utility to manage the state of a DRAGNN network.
...@@ -69,6 +135,13 @@ class ComponentBuilderBase(object): ...@@ -69,6 +135,13 @@ class ComponentBuilderBase(object):
As part of the specification, ComponentBuilder will wrap an underlying As part of the specification, ComponentBuilder will wrap an underlying
NetworkUnit which generates the actual network layout. NetworkUnit which generates the actual network layout.
Attributes:
master: dragnn.MasterBuilder that owns this component.
num_actions: Number of actions in the transition system.
name: Name of this component.
spec: dragnn.ComponentSpec that configures this component.
moving_average: True if moving-average parameters should be used.
""" """
__metaclass__ = ABCMeta # required for @abstractmethod __metaclass__ = ABCMeta # required for @abstractmethod
...@@ -96,16 +169,23 @@ class ComponentBuilderBase(object): ...@@ -96,16 +169,23 @@ class ComponentBuilderBase(object):
# Extract component attributes before make_network(), so the network unit # Extract component attributes before make_network(), so the network unit
# can access them. # can access them.
self._attrs = {} self._attrs = {}
global_attr_defaults = {
'locally_normalize': False,
'output_as_probabilities': False
}
if attr_defaults: if attr_defaults:
self._attrs = network_units.get_attrs_with_defaults( global_attr_defaults.update(attr_defaults)
self.spec.component_builder.parameters, attr_defaults) self._attrs = network_units.get_attrs_with_defaults(
self.spec.component_builder.parameters, global_attr_defaults)
do_local_norm = self._attrs['locally_normalize']
self._output_as_probabilities = self._attrs['output_as_probabilities']
with tf.variable_scope(self.name): with tf.variable_scope(self.name):
self.training_beam_size = tf.constant( self.training_beam_size = tf.constant(
self.spec.training_beam_size, name='TrainingBeamSize') self.spec.training_beam_size, name='TrainingBeamSize')
self.inference_beam_size = tf.constant( self.inference_beam_size = tf.constant(
self.spec.inference_beam_size, name='InferenceBeamSize') self.spec.inference_beam_size, name='InferenceBeamSize')
self.locally_normalize = tf.constant(False, name='LocallyNormalize') self.locally_normalize = tf.constant(
do_local_norm, name='LocallyNormalize')
self._step = tf.get_variable( self._step = tf.get_variable(
'step', [], initializer=tf.zeros_initializer(), dtype=tf.int32) 'step', [], initializer=tf.zeros_initializer(), dtype=tf.int32)
self._total = tf.get_variable( self._total = tf.get_variable(
...@@ -120,6 +200,9 @@ class ComponentBuilderBase(object): ...@@ -120,6 +200,9 @@ class ComponentBuilderBase(object):
decay=self.master.hyperparams.average_weight, num_updates=self._step) decay=self.master.hyperparams.average_weight, num_updates=self._step)
self.avg_ops = [self.moving_average.apply(self.network.params)] self.avg_ops = [self.moving_average.apply(self.network.params)]
# Used to export the cell; see add_cell_input() and add_cell_output().
self._cell_subgraph_spec = export_pb2.CellSubgraphSpec()
def make_network(self, network_unit): def make_network(self, network_unit):
"""Makes a NetworkUnitInterface object based on the network_unit spec. """Makes a NetworkUnitInterface object based on the network_unit spec.
...@@ -276,7 +359,7 @@ class ComponentBuilderBase(object): ...@@ -276,7 +359,7 @@ class ComponentBuilderBase(object):
Returns: Returns:
tf.Variable object corresponding to original or averaged version. tf.Variable object corresponding to original or averaged version.
""" """
if var_params: if var_params is not None:
var_name = var_params.name var_name = var_params.name
else: else:
check.NotNone(var_name, 'specify at least one of var_name or var_params') check.NotNone(var_name, 'specify at least one of var_name or var_params')
...@@ -341,6 +424,79 @@ class ComponentBuilderBase(object): ...@@ -341,6 +424,79 @@ class ComponentBuilderBase(object):
"""Returns the value of the component attribute with the |name|.""" """Returns the value of the component attribute with the |name|."""
return self._attrs[name] return self._attrs[name]
def has_attr(self, name):
"""Checks whether a component attribute with the given |name| exists.
Args:
name: Attribute name.
Returns:
True if the attribute exists, False otherwise.
"""
return name in self._attrs
def _add_runtime_hooks(self):
"""Adds "hook" nodes to the graph for use by the runtime, if enabled.
Does nothing if master.build_runtime_graph is False. Subclasses should call
this at the end of build_*_inference(). For details on the runtime hooks,
see runtime_support.py.
"""
if self.master.build_runtime_graph:
with tf.variable_scope(self.name, reuse=True):
runtime_support.add_hooks(self, self._cell_subgraph_spec)
self._cell_subgraph_spec = None # prevent further exports
def add_cell_input(self, dtype, shape, name, input_type='TYPE_FEATURE'):
"""Adds an input to the current CellSubgraphSpec.
Creates a tf.placeholder() with the given |dtype| and |shape|, adds it as a
cell input with the |name| and |input_type|, and returns the placeholder to
be used in place of the actual input tensor.
Args:
dtype: DType of the cell input.
shape: Static shape of the cell input.
name: Logical name of the cell input.
input_type: Name of the appropriate CellSubgraphSpec.Input.Type enum.
Returns:
A tensor to use in place of the actual input tensor.
Raises:
TypeError: If the |shape| is the wrong type.
RuntimeError: If the cell has already been exported.
"""
if not (isinstance(shape, list) and
all(isinstance(dim, int) for dim in shape)):
raise TypeError('shape must be a list of int')
if not self._cell_subgraph_spec:
raise RuntimeError('already exported a CellSubgraphSpec')
with tf.name_scope(None):
tensor = tf.placeholder(
dtype, shape, name='{}/INPUT/{}'.format(self.name, name))
self._cell_subgraph_spec.input.add(
name=name,
tensor=tensor.name,
type=export_pb2.CellSubgraphSpec.Input.Type.Value(input_type))
return tensor
def add_cell_output(self, tensor, name):
"""Adds an output to the current CellSubgraphSpec.
Args:
tensor: Tensor to add as a cell output.
name: Logical name of the cell output.
Raises:
RuntimeError: If the cell has already been exported.
"""
if not self._cell_subgraph_spec:
raise RuntimeError('already exported a CellSubgraphSpec')
self._cell_subgraph_spec.output.add(name=name, tensor=tensor.name)
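# Illustrative sketch (editorial): how a network unit might use the two
# methods above to declare its cell boundary for the runtime. The names
# 'fixed' and 'logits' are hypothetical; real network units would make these
# calls while building their create() computation.
def _example_declare_cell(comp, logits):
  # Feature inputs become placeholders that the runtime feeds directly.
  fixed_input = comp.add_cell_input(tf.float32, [1, 32], 'fixed')
  # Layer outputs are recorded by name so the runtime can fetch them.
  comp.add_cell_output(logits, 'logits')
  return fixed_input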
def update_tensor_arrays(network_tensors, arrays): def update_tensor_arrays(network_tensors, arrays):
"""Updates a list of tensor arrays from the network's output tensors. """Updates a list of tensor arrays from the network's output tensors.
...@@ -370,6 +526,18 @@ class DynamicComponentBuilder(ComponentBuilderBase): ...@@ -370,6 +526,18 @@ class DynamicComponentBuilder(ComponentBuilderBase):
so fixed and linked features can be recurrent. so fixed and linked features can be recurrent.
""" """
def __init__(self, master, component_spec):
"""Initializes the DynamicComponentBuilder from specifications.
Args:
master: dragnn.MasterBuilder object.
component_spec: dragnn.ComponentSpec proto to be built.
"""
super(DynamicComponentBuilder, self).__init__(
master,
component_spec,
attr_defaults={'loss_function': 'softmax_cross_entropy'})
def build_greedy_training(self, state, network_states): def build_greedy_training(self, state, network_states):
"""Builds a training loop for this component. """Builds a training loop for this component.
...@@ -392,9 +560,10 @@ class DynamicComponentBuilder(ComponentBuilderBase): ...@@ -392,9 +560,10 @@ class DynamicComponentBuilder(ComponentBuilderBase):
# Add 0 to training_beam_size to disable eager static evaluation. # Add 0 to training_beam_size to disable eager static evaluation.
# This is possible because tensorflow's constant_value does not # This is possible because tensorflow's constant_value does not
# propagate arithmetic operations. # propagate arithmetic operations.
with tf.control_dependencies([ with tf.control_dependencies(
tf.assert_equal(self.training_beam_size + 0, 1)]): [tf.assert_equal(self.training_beam_size + 0, 1)]):
stride = state.current_batch_size * self.training_beam_size stride = state.current_batch_size * self.training_beam_size
self.network.pre_create(stride)
cost = tf.constant(0.) cost = tf.constant(0.)
correct = tf.constant(0) correct = tf.constant(0)
...@@ -416,40 +585,35 @@ class DynamicComponentBuilder(ComponentBuilderBase): ...@@ -416,40 +585,35 @@ class DynamicComponentBuilder(ComponentBuilderBase):
# Every layer is written to a TensorArray, so that it can be backprop'd. # Every layer is written to a TensorArray, so that it can be backprop'd.
next_arrays = update_tensor_arrays(network_tensors, arrays) next_arrays = update_tensor_arrays(network_tensors, arrays)
loss_function = self.attr('loss_function')
with tf.control_dependencies([x.flow for x in next_arrays]): with tf.control_dependencies([x.flow for x in next_arrays]):
with tf.name_scope('compute_loss'): with tf.name_scope('compute_loss'):
# A gold label > -1 determines that the sentence is still
# in a valid state. Otherwise, the sentence has ended.
#
# We add only the valid sentences to the loss, in the following way:
# 1. We compute 'valid_ix', the indices in gold that contain
# valid oracle actions.
# 2. We compute the cost function by comparing logits and gold
# only for the valid indices.
gold = dragnn_ops.emit_oracle_labels(handle, component=self.name)
gold.set_shape([None])
valid = tf.greater(gold, -1)
valid_ix = tf.reshape(tf.where(valid), [-1])
gold = tf.gather(gold, valid_ix)
logits = self.network.get_logits(network_tensors) logits = self.network.get_logits(network_tensors)
logits = tf.gather(logits, valid_ix) if loss_function == 'softmax_cross_entropy':
gold = dragnn_ops.emit_oracle_labels(handle, component=self.name)
cost += tf.reduce_sum( new_cost, new_correct, new_total, valid_logits, valid_gold = (
tf.nn.sparse_softmax_cross_entropy_with_logits( build_softmax_cross_entropy_loss(logits, gold))
labels=tf.cast(gold, tf.int64), logits=logits))
if (self.eligible_for_self_norm and
if (self.eligible_for_self_norm and self.master.hyperparams.self_norm_alpha > 0):
self.master.hyperparams.self_norm_alpha > 0): log_z = tf.reduce_logsumexp(valid_logits, [1])
log_z = tf.reduce_logsumexp(logits, [1]) new_cost += (self.master.hyperparams.self_norm_alpha *
cost += (self.master.hyperparams.self_norm_alpha * tf.nn.l2_loss(log_z))
tf.nn.l2_loss(log_z)) elif loss_function == 'sigmoid_cross_entropy':
indices, gold, probs = (
correct += tf.reduce_sum( dragnn_ops.emit_oracle_labels_and_probabilities(
tf.to_int32(tf.nn.in_top_k(logits, gold, 1))) handle, component=self.name))
total += tf.size(gold) new_cost, new_correct, new_total, valid_gold = (
build_sigmoid_cross_entropy_loss(logits, gold, indices,
with tf.control_dependencies([cost, correct, total, gold]): probs))
else:
RuntimeError("Unknown loss function '%s'" % loss_function)
cost += new_cost
correct += new_correct
total += new_total
with tf.control_dependencies([cost, correct, total, valid_gold]):
handle = dragnn_ops.advance_from_oracle(handle, component=self.name) handle = dragnn_ops.advance_from_oracle(handle, component=self.name)
return [handle, cost, correct, total] + next_arrays return [handle, cost, correct, total] + next_arrays
...@@ -480,6 +644,7 @@ class DynamicComponentBuilder(ComponentBuilderBase): ...@@ -480,6 +644,7 @@ class DynamicComponentBuilder(ComponentBuilderBase):
# Normalize the objective by the total # of steps taken. # Normalize the objective by the total # of steps taken.
# Note: Total could be zero for a number of reasons, including: # Note: Total could be zero for a number of reasons, including:
# * Oracle labels not being emitted. # * Oracle labels not being emitted.
# * All oracle labels for a batch are unknown (-1).
# * No steps being taken if component is terminal at the start of a batch. # * No steps being taken if component is terminal at the start of a batch.
with tf.control_dependencies([tf.assert_greater(total, 0)]): with tf.control_dependencies([tf.assert_greater(total, 0)]):
cost /= tf.to_float(total) cost /= tf.to_float(total)
...@@ -511,6 +676,7 @@ class DynamicComponentBuilder(ComponentBuilderBase): ...@@ -511,6 +676,7 @@ class DynamicComponentBuilder(ComponentBuilderBase):
stride = state.current_batch_size * self.training_beam_size stride = state.current_batch_size * self.training_beam_size
else: else:
stride = state.current_batch_size * self.inference_beam_size stride = state.current_batch_size * self.inference_beam_size
self.network.pre_create(stride)
def cond(handle, *_): def cond(handle, *_):
all_final = dragnn_ops.emit_all_final(handle, component=self.name) all_final = dragnn_ops.emit_all_final(handle, component=self.name)
...@@ -559,6 +725,7 @@ class DynamicComponentBuilder(ComponentBuilderBase): ...@@ -559,6 +725,7 @@ class DynamicComponentBuilder(ComponentBuilderBase):
for index, layer in enumerate(self.network.layers): for index, layer in enumerate(self.network.layers):
network_state.activations[layer.name] = network_units.StoredActivations( network_state.activations[layer.name] = network_units.StoredActivations(
array=arrays[index]) array=arrays[index])
self._add_runtime_hooks()
with tf.control_dependencies([x.flow for x in arrays]): with tf.control_dependencies([x.flow for x in arrays]):
return tf.identity(state.handle) return tf.identity(state.handle)
...@@ -587,7 +754,7 @@ class DynamicComponentBuilder(ComponentBuilderBase): ...@@ -587,7 +754,7 @@ class DynamicComponentBuilder(ComponentBuilderBase):
fixed_embeddings = [] fixed_embeddings = []
for channel_id, feature_spec in enumerate(self.spec.fixed_feature): for channel_id, feature_spec in enumerate(self.spec.fixed_feature):
fixed_embedding = network_units.fixed_feature_lookup( fixed_embedding = network_units.fixed_feature_lookup(
self, state, channel_id, stride) self, state, channel_id, stride, during_training)
if feature_spec.is_constant: if feature_spec.is_constant:
fixed_embedding.tensor = tf.stop_gradient(fixed_embedding.tensor) fixed_embedding.tensor = tf.stop_gradient(fixed_embedding.tensor)
fixed_embeddings.append(fixed_embedding) fixed_embeddings.append(fixed_embedding)
...@@ -633,6 +800,12 @@ class DynamicComponentBuilder(ComponentBuilderBase): ...@@ -633,6 +800,12 @@ class DynamicComponentBuilder(ComponentBuilderBase):
else: else:
attention_tensor = None attention_tensor = None
return self.network.create(fixed_embeddings, linked_embeddings, tensors = self.network.create(fixed_embeddings, linked_embeddings,
context_tensor_arrays, attention_tensor, context_tensor_arrays, attention_tensor,
during_training) during_training)
if self.master.build_runtime_graph:
for index, layer in enumerate(self.network.layers):
self.add_cell_output(tensors[index], layer.name)
return tensors
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for component.py.
"""
import tensorflow as tf
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
from google.protobuf import text_format
from dragnn.protos import spec_pb2
from dragnn.python import component
class MockNetworkUnit(object):
def get_layer_size(self, unused_layer_name):
return 64
class MockComponent(object):
def __init__(self):
self.name = 'mock'
self.network = MockNetworkUnit()
class MockMaster(object):
def __init__(self):
self.spec = spec_pb2.MasterSpec()
self.hyperparams = spec_pb2.GridPoint()
self.lookup_component = {'mock': MockComponent()}
self.build_runtime_graph = False
class ComponentTest(test_util.TensorFlowTestCase):
def setUp(self):
# Clear the graph and all existing variables. Otherwise, variables created
# in different tests may collide with each other.
tf.reset_default_graph()
self.master = MockMaster()
self.master_state = component.MasterState(
handle=tf.constant(['foo', 'bar']), current_batch_size=2)
self.network_states = {
'mock': component.NetworkState(),
'test': component.NetworkState(),
}
def testSoftmaxCrossEntropyLoss(self):
logits = tf.constant([[0.0, 2.0, -1.0],
[-5.0, 1.0, -1.0],
[3.0, 1.0, -2.0]]) # pyformat: disable
gold_labels = tf.constant([1, -1, 1])
cost, correct, total, logits, gold_labels = (
component.build_softmax_cross_entropy_loss(logits, gold_labels))
with self.test_session() as sess:
cost, correct, total, logits, gold_labels = (
sess.run([cost, correct, total, logits, gold_labels]))
# Cost = -2 + ln(1 + exp(2) + exp(-1))
# -1 + ln(exp(3) + exp(1) + exp(-2))
self.assertAlmostEqual(cost, 2.3027, 4)
self.assertEqual(correct, 1)
self.assertEqual(total, 2)
# Entries corresponding to gold labels equal to -1 are skipped.
self.assertAllEqual(logits, [[0.0, 2.0, -1.0], [3.0, 1.0, -2.0]])
self.assertAllEqual(gold_labels, [1, 1])
def testSigmoidCrossEntropyLoss(self):
indices = tf.constant([0, 0, 1])
gold_labels = tf.constant([0, 1, 2])
probs = tf.constant([0.6, 0.7, 0.2])
logits = tf.constant([[0.9, -0.3, 0.1], [-0.5, 0.4, 2.0]])
cost, correct, total, gold_labels = (
component.build_sigmoid_cross_entropy_loss(logits, gold_labels, indices,
probs))
with self.test_session() as sess:
cost, correct, total, gold_labels = (
sess.run([cost, correct, total, gold_labels]))
# The cost corresponding to the three entries is, respectively,
# 0.7012, 0.7644, and 1.7269. Each of them is computed using the formula
# -prob_i * log(sigmoid(logit_i)) - (1-prob_i) * log(1-sigmoid(logit_i))
self.assertAlmostEqual(cost, 3.1924, 4)
self.assertEqual(correct, 1)
self.assertEqual(total, 3)
self.assertAllEqual(gold_labels, [0, 1, 2])
def testGraphConstruction(self):
component_spec = spec_pb2.ComponentSpec()
text_format.Parse("""
name: "test"
network_unit {
registered_name: "IdentityNetwork"
}
fixed_feature {
name: "fixed" embedding_dim: 32 size: 1
}
component_builder {
registered_name: "component.DynamicComponentBuilder"
}
""", component_spec)
comp = component.DynamicComponentBuilder(self.master, component_spec)
comp.build_greedy_training(self.master_state, self.network_states)
def testGraphConstructionWithSigmoidLoss(self):
component_spec = spec_pb2.ComponentSpec()
text_format.Parse("""
name: "test"
network_unit {
registered_name: "IdentityNetwork"
}
fixed_feature {
name: "fixed" embedding_dim: 32 size: 1
}
component_builder {
registered_name: "component.DynamicComponentBuilder"
parameters {
key: "loss_function"
value: "sigmoid_cross_entropy"
}
}
""", component_spec)
comp = component.DynamicComponentBuilder(self.master, component_spec)
comp.build_greedy_training(self.master_state, self.network_states)
# Check that the loss op is present.
op_names = [op.name for op in tf.get_default_graph().get_operations()]
self.assertTrue('train_test/compute_loss/'
'sigmoid_cross_entropy_with_logits' in op_names)
if __name__ == '__main__':
googletest.main()
...@@ -15,6 +15,10 @@ ...@@ -15,6 +15,10 @@
"""TensorFlow ops for directed graphs.""" """TensorFlow ops for directed graphs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf import tensorflow as tf
from syntaxnet.util import check from syntaxnet.util import check
...@@ -150,7 +154,7 @@ def ArcSourcePotentialsFromTokens(tokens, weights): ...@@ -150,7 +154,7 @@ def ArcSourcePotentialsFromTokens(tokens, weights):
return sources_bxnxn return sources_bxnxn
def RootPotentialsFromTokens(root, tokens, weights): def RootPotentialsFromTokens(root, tokens, weights_arc, weights_source):
r"""Returns root selection potentials computed from tokens and weights. r"""Returns root selection potentials computed from tokens and weights.
For each batch of token activations, computes a scalar potential for each root For each batch of token activations, computes a scalar potential for each root
...@@ -162,7 +166,8 @@ def RootPotentialsFromTokens(root, tokens, weights): ...@@ -162,7 +166,8 @@ def RootPotentialsFromTokens(root, tokens, weights):
Args: Args:
root: [S] vector of activations for the artificial root token. root: [S] vector of activations for the artificial root token.
tokens: [B,N,T] tensor of batched activations for root tokens. tokens: [B,N,T] tensor of batched activations for root tokens.
weights: [S,T] matrix of weights. weights_arc: [S,T] matrix of weights.
weights_source: [S] vector of weights.
B,N may be statically-unknown, but S,T must be statically-known. The dtype B,N may be statically-unknown, but S,T must be statically-known. The dtype
of all arguments must be compatible. of all arguments must be compatible.
...@@ -174,25 +179,30 @@ def RootPotentialsFromTokens(root, tokens, weights): ...@@ -174,25 +179,30 @@ def RootPotentialsFromTokens(root, tokens, weights):
# All arguments must have statically-known rank. # All arguments must have statically-known rank.
check.Eq(root.get_shape().ndims, 1, 'root must be a vector') check.Eq(root.get_shape().ndims, 1, 'root must be a vector')
check.Eq(tokens.get_shape().ndims, 3, 'tokens must be rank 3') check.Eq(tokens.get_shape().ndims, 3, 'tokens must be rank 3')
check.Eq(weights.get_shape().ndims, 2, 'weights must be a matrix') check.Eq(weights_arc.get_shape().ndims, 2, 'weights_arc must be a matrix')
check.Eq(weights_source.get_shape().ndims, 1,
'weights_source must be a vector')
# All activation dimensions must be statically-known. # All activation dimensions must be statically-known.
num_source_activations = weights.get_shape().as_list()[0] num_source_activations = weights_arc.get_shape().as_list()[0]
num_target_activations = weights.get_shape().as_list()[1] num_target_activations = weights_arc.get_shape().as_list()[1]
check.NotNone(num_source_activations, 'unknown source activation dimension') check.NotNone(num_source_activations, 'unknown source activation dimension')
check.NotNone(num_target_activations, 'unknown target activation dimension') check.NotNone(num_target_activations, 'unknown target activation dimension')
check.Eq(root.get_shape().as_list()[0], num_source_activations, check.Eq(root.get_shape().as_list()[0], num_source_activations,
'dimension mismatch between weights and root') 'dimension mismatch between weights_arc and root')
check.Eq(tokens.get_shape().as_list()[2], num_target_activations, check.Eq(tokens.get_shape().as_list()[2], num_target_activations,
'dimension mismatch between weights and tokens') 'dimension mismatch between weights_arc and tokens')
check.Eq(weights_source.get_shape().as_list()[0], num_source_activations,
'dimension mismatch between weights_arc and weights_source')
# All arguments must share the same type. # All arguments must share the same type.
check.Same([weights.dtype.base_dtype, check.Same([
root.dtype.base_dtype, weights_arc.dtype.base_dtype, weights_source.dtype.base_dtype,
tokens.dtype.base_dtype], root.dtype.base_dtype, tokens.dtype.base_dtype
'dtype mismatch') ], 'dtype mismatch')
root_1xs = tf.expand_dims(root, 0) root_1xs = tf.expand_dims(root, 0)
weights_source_sx1 = tf.expand_dims(weights_source, 1)
tokens_shape = tf.shape(tokens) tokens_shape = tf.shape(tokens)
batch_size = tokens_shape[0] batch_size = tokens_shape[0]
...@@ -200,9 +210,12 @@ def RootPotentialsFromTokens(root, tokens, weights): ...@@ -200,9 +210,12 @@ def RootPotentialsFromTokens(root, tokens, weights):
# Flatten out the batch dimension so we can use a couple big matmuls. # Flatten out the batch dimension so we can use a couple big matmuls.
tokens_bnxt = tf.reshape(tokens, [-1, num_target_activations]) tokens_bnxt = tf.reshape(tokens, [-1, num_target_activations])
weights_targets_bnxs = tf.matmul(tokens_bnxt, weights, transpose_b=True) weights_targets_bnxs = tf.matmul(tokens_bnxt, weights_arc, transpose_b=True)
roots_1xbn = tf.matmul(root_1xs, weights_targets_bnxs, transpose_b=True) roots_1xbn = tf.matmul(root_1xs, weights_targets_bnxs, transpose_b=True)
# Add in the score for selecting the root as a source.
roots_1xbn += tf.matmul(root_1xs, weights_source_sx1)
# Restore the batch dimension in the output. # Restore the batch dimension in the output.
roots_bxn = tf.reshape(roots_1xbn, [batch_size, num_tokens]) roots_bxn = tf.reshape(roots_1xbn, [batch_size, num_tokens])
return roots_bxn return roots_bxn
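# In effect (illustrative worked equation, editorial): with the extra argument,
# the potential of token t as the root becomes
#   roots[b, t] = root . weights_arc . tokens[b, t] + root . weights_source,
# i.e. every candidate root also receives the source-side score of the
# artificial root token, which the previous signature omitted.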
...@@ -354,3 +367,110 @@ def LabelPotentialsFromTokenPairs(sources, targets, weights): ...@@ -354,3 +367,110 @@ def LabelPotentialsFromTokenPairs(sources, targets, weights):
transpose_b=True) transpose_b=True)
labels_bxnxl = tf.squeeze(labels_bxnxlx1, [3]) labels_bxnxl = tf.squeeze(labels_bxnxlx1, [3])
return labels_bxnxl return labels_bxnxl
def ValidArcAndTokenMasks(lengths, max_length, dtype=tf.float32):
r"""Returns 0/1 masks for valid arcs and tokens.
Args:
lengths: [B] vector of input sequence lengths.
max_length: Scalar maximum input sequence length, aka M.
dtype: Data type for output mask.
Returns:
[B,M,M] tensor A with 0/1 indicators of valid arcs. Specifically,
A_{b,t,s} = t,s < lengths[b] ? 1 : 0
[B,M] matrix T with 0/1 indicators of valid tokens. Specifically,
T_{b,t} = t < lengths[b] ? 1 : 0
"""
lengths_bx1 = tf.expand_dims(lengths, 1)
sequence_m = tf.range(tf.cast(max_length, lengths.dtype.base_dtype))
sequence_1xm = tf.expand_dims(sequence_m, 0)
# Create vectors of 0/1 indicators for valid tokens. Note that the comparison
# operator will broadcast from [1,M] and [B,1] to [B,M].
valid_token_bxm = tf.cast(sequence_1xm < lengths_bx1, dtype)
# Compute matrices of 0/1 indicators for valid arcs as the outer product of
# the valid token indicator vector with itself.
valid_arc_bxmxm = tf.matmul(
tf.expand_dims(valid_token_bxm, 2), tf.expand_dims(valid_token_bxm, 1))
return valid_arc_bxmxm, valid_token_bxm
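# For example (illustrative, editorial): lengths = [2] and max_length = 3 yield
#   valid_token = [[1, 1, 0]]
#   valid_arc   = [[[1, 1, 0],
#                   [1, 1, 0],
#                   [0, 0, 0]]]
# i.e. the arc mask is the outer product of the token mask with itself.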
def LaplacianMatrix(lengths, arcs, forest=False):
r"""Returns the (root-augmented) Laplacian matrix for a batch of digraphs.
Args:
lengths: [B] vector of input sequence lengths.
arcs: [B,M,M] tensor of arc potentials where entry b,t,s is the potential of
the arc from s to t in the b'th digraph, while b,t,t is the potential of t
as a root. Entries b,t,s where t or s >= lengths[b] are ignored.
forest: Whether to produce a Laplacian for trees or forests.
Returns:
[B,M,M] tensor L with the Laplacian of each digraph, padded with an identity
matrix. More concretely, the padding entries (t or s >= lengths[b]) are:
L_{b,t,t} = 1.0
L_{b,t,s} = 0.0
Note that this "identity matrix padding" ensures that the determinant of
each padded matrix equals the determinant of the unpadded matrix. The
non-padding entries (t,s < lengths[b]) depend on whether the Laplacian is
constructed for trees or forests. For trees:
L_{b,t,0} = arcs[b,t,t]
L_{b,t,t} = \sum_{s < lengths[b], t != s} arcs[b,t,s]
L_{b,t,s} = -arcs[b,t,s]
For forests:
L_{b,t,t} = \sum_{s < lengths[b]} arcs[b,t,s]
L_{b,t,s} = -arcs[b,t,s]
See http://www.aclweb.org/anthology/D/D07/D07-1015.pdf for details, though
note that our matrices are transposed from their notation.
"""
check.Eq(arcs.get_shape().ndims, 3, 'arcs must be rank 3')
dtype = arcs.dtype.base_dtype
arcs_shape = tf.shape(arcs)
batch_size = arcs_shape[0]
max_length = arcs_shape[1]
with tf.control_dependencies([tf.assert_equal(max_length, arcs_shape[2])]):
valid_arc_bxmxm, valid_token_bxm = ValidArcAndTokenMasks(
lengths, max_length, dtype=dtype)
invalid_token_bxm = tf.constant(1, dtype=dtype) - valid_token_bxm
# Zero out all invalid arcs, to avoid polluting bulk summations.
arcs_bxmxm = arcs * valid_arc_bxmxm
zeros_bxm = tf.zeros([batch_size, max_length], dtype)
if not forest:
# For trees, extract the root potentials and exclude them from the sums
# computed below.
roots_bxm = tf.matrix_diag_part(arcs_bxmxm) # only defined for trees
arcs_bxmxm = tf.matrix_set_diag(arcs_bxmxm, zeros_bxm)
# Sum inbound arc potentials for each target token. These sums will form
# the diagonal of the Laplacian matrix. Note that these sums are zero for
# invalid tokens, since their arc potentials were masked out above.
sums_bxm = tf.reduce_sum(arcs_bxmxm, 2)
if forest:
# For forests, zero out the root potentials after computing the sums above
# so we don't cancel them out when we subtract the arc potentials.
arcs_bxmxm = tf.matrix_set_diag(arcs_bxmxm, zeros_bxm)
# The diagonal of the result is the combination of the arc sums, which are
# non-zero only on valid tokens, and the invalid token indicators, which are
# non-zero only on invalid tokens. Note that the latter form the diagonal
# of the identity matrix padding.
diagonal_bxm = sums_bxm + invalid_token_bxm
# Combine sums and negative arc potentials. Note that the off-diagonal
# padding entries will be zero thanks to the arc mask.
laplacian_bxmxm = tf.matrix_diag(diagonal_bxm) - arcs_bxmxm
if not forest:
# For trees, replace the first column with the root potentials.
roots_bxmx1 = tf.expand_dims(roots_bxm, 2)
laplacian_bxmxm = tf.concat([roots_bxmx1, laplacian_bxmxm[:, :, 1:]], 2)
return laplacian_bxmxm
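# Illustrative sketch (editorial): under the matrix-tree construction cited
# above, the determinant of each padded Laplacian equals the partition
# function over trees (or forests, with forest=True), assuming |arcs| holds
# exponentiated, non-negative potentials. The helper name is hypothetical.
def _example_log_partition(lengths, arcs, forest=False):
  laplacian_bxmxm = LaplacianMatrix(lengths, arcs, forest=forest)
  # The identity-matrix padding leaves each determinant unchanged.
  return tf.log(tf.matrix_determinant(laplacian_bxmxm))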
...@@ -31,16 +31,18 @@ class DigraphOpsTest(tf.test.TestCase): ...@@ -31,16 +31,18 @@ class DigraphOpsTest(tf.test.TestCase):
[3, 4]], [3, 4]],
[[3, 4], [[3, 4],
[2, 3], [2, 3],
[1, 2]]], tf.float32) [1, 2]]],
tf.float32) # pyformat: disable
target_tokens = tf.constant([[[4, 5, 6], target_tokens = tf.constant([[[4, 5, 6],
[5, 6, 7], [5, 6, 7],
[6, 7, 8]], [6, 7, 8]],
[[6, 7, 8], [[6, 7, 8],
[5, 6, 7], [5, 6, 7],
[4, 5, 6]]], tf.float32) [4, 5, 6]]],
tf.float32) # pyformat: disable
weights = tf.constant([[2, 3, 5], weights = tf.constant([[2, 3, 5],
[7, 11, 13]], [7, 11, 13]],
tf.float32) tf.float32) # pyformat: disable
arcs = digraph_ops.ArcPotentialsFromTokens(source_tokens, target_tokens, arcs = digraph_ops.ArcPotentialsFromTokens(source_tokens, target_tokens,
weights) weights)
...@@ -54,7 +56,7 @@ class DigraphOpsTest(tf.test.TestCase): ...@@ -54,7 +56,7 @@ class DigraphOpsTest(tf.test.TestCase):
[803, 957, 1111]], [803, 957, 1111]],
[[1111, 957, 803], # reflected through the center [[1111, 957, 803], # reflected through the center
[815, 702, 589], [815, 702, 589],
[519, 447, 375]]]) [519, 447, 375]]]) # pyformat: disable
def testArcSourcePotentialsFromTokens(self): def testArcSourcePotentialsFromTokens(self):
with self.test_session(): with self.test_session():
...@@ -63,7 +65,7 @@ class DigraphOpsTest(tf.test.TestCase): ...@@ -63,7 +65,7 @@ class DigraphOpsTest(tf.test.TestCase):
[6, 7, 8]], [6, 7, 8]],
[[6, 7, 8], [[6, 7, 8],
[5, 6, 7], [5, 6, 7],
[4, 5, 6]]], tf.float32) [4, 5, 6]]], tf.float32) # pyformat: disable
weights = tf.constant([2, 3, 5], tf.float32) weights = tf.constant([2, 3, 5], tf.float32)
arcs = digraph_ops.ArcSourcePotentialsFromTokens(tokens, weights) arcs = digraph_ops.ArcSourcePotentialsFromTokens(tokens, weights)
...@@ -73,7 +75,7 @@ class DigraphOpsTest(tf.test.TestCase): ...@@ -73,7 +75,7 @@ class DigraphOpsTest(tf.test.TestCase):
[73, 73, 73]], [73, 73, 73]],
[[73, 73, 73], [[73, 73, 73],
[63, 63, 63], [63, 63, 63],
[53, 53, 53]]]) [53, 53, 53]]]) # pyformat: disable
def testRootPotentialsFromTokens(self): def testRootPotentialsFromTokens(self):
with self.test_session(): with self.test_session():
...@@ -83,15 +85,17 @@ class DigraphOpsTest(tf.test.TestCase): ...@@ -83,15 +85,17 @@ class DigraphOpsTest(tf.test.TestCase):
[6, 7, 8]], [6, 7, 8]],
[[6, 7, 8], [[6, 7, 8],
[5, 6, 7], [5, 6, 7],
[4, 5, 6]]], tf.float32) [4, 5, 6]]], tf.float32) # pyformat: disable
weights = tf.constant([[2, 3, 5], weights_arc = tf.constant([[2, 3, 5],
[7, 11, 13]], [7, 11, 13]],
tf.float32) tf.float32) # pyformat: disable
weights_source = tf.constant([11, 10], tf.float32)
roots = digraph_ops.RootPotentialsFromTokens(root, tokens, weights) roots = digraph_ops.RootPotentialsFromTokens(root, tokens, weights_arc,
weights_source)
self.assertAllEqual(roots.eval(), [[375, 447, 519], self.assertAllEqual(roots.eval(), [[406, 478, 550],
[519, 447, 375]]) [550, 478, 406]]) # pyformat: disable
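# Relative to the previous expectations, every entry grows by the scalar
# root . weights_source (31 here), the newly added score for selecting the
# artificial root as a source.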
def testCombineArcAndRootPotentials(self): def testCombineArcAndRootPotentials(self):
with self.test_session(): with self.test_session():
...@@ -100,9 +104,9 @@ class DigraphOpsTest(tf.test.TestCase): ...@@ -100,9 +104,9 @@ class DigraphOpsTest(tf.test.TestCase):
[3, 4, 5]], [3, 4, 5]],
[[3, 4, 5], [[3, 4, 5],
[2, 3, 4], [2, 3, 4],
[1, 2, 3]]], tf.float32) [1, 2, 3]]], tf.float32) # pyformat: disable
roots = tf.constant([[6, 7, 8], roots = tf.constant([[6, 7, 8],
[8, 7, 6]], tf.float32) [8, 7, 6]], tf.float32) # pyformat: disable
potentials = digraph_ops.CombineArcAndRootPotentials(arcs, roots) potentials = digraph_ops.CombineArcAndRootPotentials(arcs, roots)
...@@ -111,7 +115,7 @@ class DigraphOpsTest(tf.test.TestCase): ...@@ -111,7 +115,7 @@ class DigraphOpsTest(tf.test.TestCase):
[3, 4, 8]], [3, 4, 8]],
[[8, 4, 5], [[8, 4, 5],
[2, 7, 4], [2, 7, 4],
[1, 2, 6]]]) [1, 2, 6]]]) # pyformat: disable
def testLabelPotentialsFromTokens(self): def testLabelPotentialsFromTokens(self):
with self.test_session(): with self.test_session():
...@@ -120,12 +124,12 @@ class DigraphOpsTest(tf.test.TestCase): ...@@ -120,12 +124,12 @@ class DigraphOpsTest(tf.test.TestCase):
[5, 6]], [5, 6]],
[[6, 5], [[6, 5],
[4, 3], [4, 3],
[2, 1]]], tf.float32) [2, 1]]], tf.float32) # pyformat: disable
weights = tf.constant([[ 2, 3], weights = tf.constant([[ 2, 3],
[ 5, 7], [ 5, 7],
[11, 13]], tf.float32) [11, 13]], tf.float32) # pyformat: disable
labels = digraph_ops.LabelPotentialsFromTokens(tokens, weights) labels = digraph_ops.LabelPotentialsFromTokens(tokens, weights)
...@@ -136,7 +140,7 @@ class DigraphOpsTest(tf.test.TestCase): ...@@ -136,7 +140,7 @@ class DigraphOpsTest(tf.test.TestCase):
[ 28, 67, 133]], [ 28, 67, 133]],
[[ 27, 65, 131], [[ 27, 65, 131],
[ 17, 41, 83], [ 17, 41, 83],
[ 7, 17, 35]]]) [ 7, 17, 35]]]) # pyformat: disable
def testLabelPotentialsFromTokenPairs(self): def testLabelPotentialsFromTokenPairs(self):
with self.test_session(): with self.test_session():
...@@ -145,13 +149,13 @@ class DigraphOpsTest(tf.test.TestCase): ...@@ -145,13 +149,13 @@ class DigraphOpsTest(tf.test.TestCase):
[5, 6]], [5, 6]],
[[6, 5], [[6, 5],
[4, 3], [4, 3],
[2, 1]]], tf.float32) [2, 1]]], tf.float32) # pyformat: disable
targets = tf.constant([[[3, 4], targets = tf.constant([[[3, 4],
[5, 6], [5, 6],
[7, 8]], [7, 8]],
[[8, 7], [[8, 7],
[6, 5], [6, 5],
[4, 3]]], tf.float32) [4, 3]]], tf.float32) # pyformat: disable
weights = tf.constant([[[ 2, 3], weights = tf.constant([[[ 2, 3],
...@@ -159,7 +163,7 @@ class DigraphOpsTest(tf.test.TestCase): ...@@ -159,7 +163,7 @@ class DigraphOpsTest(tf.test.TestCase):
[[11, 13], [[11, 13],
[17, 19]], [17, 19]],
[[23, 29], [[23, 29],
[31, 37]]], tf.float32) [31, 37]]], tf.float32) # pyformat: disable
labels = digraph_ops.LabelPotentialsFromTokenPairs(sources, targets, labels = digraph_ops.LabelPotentialsFromTokenPairs(sources, targets,
weights) weights)
...@@ -171,7 +175,114 @@ class DigraphOpsTest(tf.test.TestCase): ...@@ -171,7 +175,114 @@ class DigraphOpsTest(tf.test.TestCase):
[ 736, 2531, 5043]], [ 736, 2531, 5043]],
[[ 667, 2419, 4857], [[ 667, 2419, 4857],
[ 303, 1115, 2245], [ 303, 1115, 2245],
[ 75, 291, 593]]]) [ 75, 291, 593]]]) # pyformat: disable
def testValidArcAndTokenMasks(self):
with self.test_session():
lengths = tf.constant([1, 2, 3], tf.int64)
max_length = 4
valid_arcs, valid_tokens = digraph_ops.ValidArcAndTokenMasks(
lengths, max_length)
self.assertAllEqual(valid_arcs.eval(),
[[[1, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]],
[[1, 1, 0, 0],
[1, 1, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]],
[[1, 1, 1, 0],
[1, 1, 1, 0],
[1, 1, 1, 0],
[0, 0, 0, 0]]]) # pyformat: disable
self.assertAllEqual(valid_tokens.eval(),
[[1, 0, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 0]]) # pyformat: disable
def testLaplacianMatrixTree(self):
with self.test_session():
pad = 12345.6
arcs = tf.constant([[[ 2, pad, pad, pad],
[pad, pad, pad, pad],
[pad, pad, pad, pad],
[pad, pad, pad, pad]],
[[ 2, 3, pad, pad],
[ 5, 7, pad, pad],
[pad, pad, pad, pad],
[pad, pad, pad, pad]],
[[ 2, 3, 5, pad],
[ 7, 11, 13, pad],
[ 17, 19, 23, pad],
[pad, pad, pad, pad]],
[[ 2, 3, 5, 7],
[ 11, 13, 17, 19],
[ 23, 29, 31, 37],
[ 41, 43, 47, 53]]],
tf.float32) # pyformat: disable
lengths = tf.constant([1, 2, 3, 4], tf.int64)
laplacian = digraph_ops.LaplacianMatrix(lengths, arcs)
self.assertAllEqual(laplacian.eval(),
[[[ 2, 0, 0, 0],
[ 0, 1, 0, 0],
[ 0, 0, 1, 0],
[ 0, 0, 0, 1]],
[[ 2, -3, 0, 0],
[ 7, 5, 0, 0],
[ 0, 0, 1, 0],
[ 0, 0, 0, 1]],
[[ 2, -3, -5, 0],
[ 11, 20, -13, 0],
[ 23, -19, 36, 0],
[ 0, 0, 0, 1]],
[[ 2, -3, -5, -7],
[ 13, 47, -17, -19],
[ 31, -29, 89, -37],
[ 53, -43, -47, 131]]]) # pyformat: disable
def testLaplacianMatrixForest(self):
with self.test_session():
pad = 12345.6
arcs = tf.constant([[[ 2, pad, pad, pad],
[pad, pad, pad, pad],
[pad, pad, pad, pad],
[pad, pad, pad, pad]],
[[ 2, 3, pad, pad],
[ 5, 7, pad, pad],
[pad, pad, pad, pad],
[pad, pad, pad, pad]],
[[ 2, 3, 5, pad],
[ 7, 11, 13, pad],
[ 17, 19, 23, pad],
[pad, pad, pad, pad]],
[[ 2, 3, 5, 7],
[ 11, 13, 17, 19],
[ 23, 29, 31, 37],
[ 41, 43, 47, 53]]],
tf.float32) # pyformat: disable
lengths = tf.constant([1, 2, 3, 4], tf.int64)
laplacian = digraph_ops.LaplacianMatrix(lengths, arcs, forest=True)
self.assertAllEqual(laplacian.eval(),
[[[ 2, 0, 0, 0],
[ 0, 1, 0, 0],
[ 0, 0, 1, 0],
[ 0, 0, 0, 1]],
[[ 5, -3, 0, 0],
[ -5, 12, 0, 0],
[ 0, 0, 1, 0],
[ 0, 0, 0, 1]],
[[ 10, -3, -5, 0],
[ -7, 31, -13, 0],
[-17, -19, 59, 0],
[ 0, 0, 0, 1]],
[[ 17, -3, -5, -7],
[-11, 60, -17, -19],
[-23, -29, 120, -37],
[-41, -43, -47, 184]]]) # pyformat: disable
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -25,13 +25,14 @@ from __future__ import absolute_import ...@@ -25,13 +25,14 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl import app
from absl import flags
import tensorflow as tf import tensorflow as tf
from google.protobuf import text_format from google.protobuf import text_format
from dragnn.protos import spec_pb2 from dragnn.protos import spec_pb2
from dragnn.python import dragnn_model_saver_lib as saver_lib from dragnn.python import dragnn_model_saver_lib as saver_lib
flags = tf.app.flags
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
flags.DEFINE_string('master_spec', None, 'Path to task context with ' flags.DEFINE_string('master_spec', None, 'Path to task context with '
...@@ -40,10 +41,12 @@ flags.DEFINE_string('params_path', None, 'Path to trained model parameters.') ...@@ -40,10 +41,12 @@ flags.DEFINE_string('params_path', None, 'Path to trained model parameters.')
flags.DEFINE_string('export_path', '', 'Output path for exported servo model.') flags.DEFINE_string('export_path', '', 'Output path for exported servo model.')
flags.DEFINE_bool('export_moving_averages', False, flags.DEFINE_bool('export_moving_averages', False,
'Whether to export the moving average parameters.') 'Whether to export the moving average parameters.')
flags.DEFINE_bool('build_runtime_graph', False,
'Whether to build a graph for use by the runtime.')
def export(master_spec_path, params_path, export_path, def export(master_spec_path, params_path, export_path, export_moving_averages,
export_moving_averages): build_runtime_graph):
"""Restores a model and exports it in SavedModel form. """Restores a model and exports it in SavedModel form.
This method loads a graph specified by the spec at master_spec_path and the This method loads a graph specified by the spec at master_spec_path and the
...@@ -55,6 +58,7 @@ def export(master_spec_path, params_path, export_path, ...@@ -55,6 +58,7 @@ def export(master_spec_path, params_path, export_path,
params_path: Path to the parameters file to export. params_path: Path to the parameters file to export.
export_path: Path to export the SavedModel to. export_path: Path to export the SavedModel to.
export_moving_averages: Whether to export the moving average parameters. export_moving_averages: Whether to export the moving average parameters.
build_runtime_graph: Whether to build a graph for use by the runtime.
""" """
graph = tf.Graph() graph = tf.Graph()
...@@ -70,16 +74,16 @@ def export(master_spec_path, params_path, export_path, ...@@ -70,16 +74,16 @@ def export(master_spec_path, params_path, export_path,
short_to_original = saver_lib.shorten_resource_paths(master_spec) short_to_original = saver_lib.shorten_resource_paths(master_spec)
saver_lib.export_master_spec(master_spec, graph) saver_lib.export_master_spec(master_spec, graph)
saver_lib.export_to_graph(master_spec, params_path, stripped_path, graph, saver_lib.export_to_graph(master_spec, params_path, stripped_path, graph,
export_moving_averages) export_moving_averages, build_runtime_graph)
saver_lib.export_assets(master_spec, short_to_original, stripped_path) saver_lib.export_assets(master_spec, short_to_original, stripped_path)
def main(unused_argv): def main(unused_argv):
# Run the exporter. # Run the exporter.
export(FLAGS.master_spec, FLAGS.params_path, export(FLAGS.master_spec, FLAGS.params_path, FLAGS.export_path,
FLAGS.export_path, FLAGS.export_moving_averages) FLAGS.export_moving_averages, FLAGS.build_runtime_graph)
tf.logging.info('Export complete.') tf.logging.info('Export complete.')
if __name__ == '__main__': if __name__ == '__main__':
tf.app.run() app.run(main)
...@@ -164,6 +164,7 @@ def export_to_graph(master_spec, ...@@ -164,6 +164,7 @@ def export_to_graph(master_spec,
export_path, export_path,
external_graph, external_graph,
export_moving_averages, export_moving_averages,
build_runtime_graph,
signature_name='model'): signature_name='model'):
"""Restores a model and exports it in SavedModel form. """Restores a model and exports it in SavedModel form.
...@@ -177,6 +178,7 @@ def export_to_graph(master_spec, ...@@ -177,6 +178,7 @@ def export_to_graph(master_spec,
export_path: Path to export the SavedModel to. export_path: Path to export the SavedModel to.
external_graph: A tf.Graph() object to build the graph inside. external_graph: A tf.Graph() object to build the graph inside.
export_moving_averages: Whether to export the moving average parameters. export_moving_averages: Whether to export the moving average parameters.
build_runtime_graph: Whether to build a graph for use by the runtime.
signature_name: Name of the signature to insert. signature_name: Name of the signature to insert.
""" """
tf.logging.info( tf.logging.info(
...@@ -189,7 +191,7 @@ def export_to_graph(master_spec, ...@@ -189,7 +191,7 @@ def export_to_graph(master_spec,
hyperparam_config.use_moving_average = export_moving_averages hyperparam_config.use_moving_average = export_moving_averages
builder = graph_builder.MasterBuilder(master_spec, hyperparam_config) builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
post_restore_hook = builder.build_post_restore_hook() post_restore_hook = builder.build_post_restore_hook()
annotation = builder.add_annotation() annotation = builder.add_annotation(build_runtime_graph=build_runtime_graph)
builder.add_saver() builder.add_saver()
# Resets session. # Resets session.
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Test for dragnn.python.dragnn_model_saver_lib.""" """Test for dragnn.python.dragnn_model_saver_lib."""
from __future__ import absolute_import from __future__ import absolute_import
...@@ -26,24 +25,30 @@ import tensorflow as tf ...@@ -26,24 +25,30 @@ import tensorflow as tf
from google.protobuf import text_format from google.protobuf import text_format
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest from tensorflow.python.platform import googletest
from dragnn.protos import export_pb2
from dragnn.protos import spec_pb2 from dragnn.protos import spec_pb2
from dragnn.python import dragnn_model_saver_lib from dragnn.python import dragnn_model_saver_lib
from syntaxnet import sentence_pb2
from syntaxnet import test_flags
FLAGS = tf.app.flags.FLAGS _DUMMY_TEST_SENTENCE = """
token {
word: "sentence" start: 0 end: 7 break_level: NO_BREAK
def setUpModule(): }
if not hasattr(FLAGS, 'test_srcdir'): token {
FLAGS.test_srcdir = '' word: "0" start: 9 end: 9 break_level: SPACE_BREAK
if not hasattr(FLAGS, 'test_tmpdir'): }
FLAGS.test_tmpdir = tf.test.get_temp_dir() token {
word: "." start: 10 end: 10 break_level: NO_BREAK
}
"""
class DragnnModelSaverLibTest(test_util.TensorFlowTestCase): class DragnnModelSaverLibTest(test_util.TensorFlowTestCase):
def LoadSpec(self, spec_path): def LoadSpec(self, spec_path):
master_spec = spec_pb2.MasterSpec() master_spec = spec_pb2.MasterSpec()
root_dir = os.path.join(FLAGS.test_srcdir, root_dir = os.path.join(test_flags.source_root(),
'dragnn/python') 'dragnn/python')
with open(os.path.join(root_dir, 'testdata', spec_path), 'r') as fin: with open(os.path.join(root_dir, 'testdata', spec_path), 'r') as fin:
text_format.Parse(fin.read().replace('TOPDIR', root_dir), master_spec) text_format.Parse(fin.read().replace('TOPDIR', root_dir), master_spec)
...@@ -52,7 +57,7 @@ class DragnnModelSaverLibTest(test_util.TensorFlowTestCase): ...@@ -52,7 +57,7 @@ class DragnnModelSaverLibTest(test_util.TensorFlowTestCase):
def CreateLocalSpec(self, spec_path): def CreateLocalSpec(self, spec_path):
master_spec = self.LoadSpec(spec_path) master_spec = self.LoadSpec(spec_path)
master_spec_name = os.path.basename(spec_path) master_spec_name = os.path.basename(spec_path)
outfile = os.path.join(FLAGS.test_tmpdir, master_spec_name) outfile = os.path.join(test_flags.temp_dir(), master_spec_name)
fout = open(outfile, 'w') fout = open(outfile, 'w')
fout.write(text_format.MessageToString(master_spec)) fout.write(text_format.MessageToString(master_spec))
return outfile return outfile
...@@ -80,16 +85,50 @@ class DragnnModelSaverLibTest(test_util.TensorFlowTestCase): ...@@ -80,16 +85,50 @@ class DragnnModelSaverLibTest(test_util.TensorFlowTestCase):
# Return a set of all unique paths. # Return a set of all unique paths.
return set(path_list) return set(path_list)
def GetHookNodeNames(self, master_spec):
"""Returns hook node names to use in tests.
Args:
master_spec: MasterSpec proto from which to infer hook node names.
Returns:
Tuple of (averaged hook node name, non-averaged hook node name, cell
subgraph hook node name).
Raises:
ValueError: If hook nodes cannot be inferred from the |master_spec|.
"""
# Find an op name we can use for testing runtime hooks. Assume that at
# least one component has a fixed feature (else what is the model doing?).
component_name = None
for component_spec in master_spec.component:
if component_spec.fixed_feature:
component_name = component_spec.name
break
if not component_name:
raise ValueError('Cannot infer hook node names')
non_averaged_hook_name = '{}/fixed_embedding_matrix_0/trimmed'.format(
component_name)
averaged_hook_name = '{}/ExponentialMovingAverage'.format(
non_averaged_hook_name)
cell_subgraph_hook_name = '{}/EXPORT/CellSubgraphSpec'.format(
component_name)
return averaged_hook_name, non_averaged_hook_name, cell_subgraph_hook_name
def testModelExport(self): def testModelExport(self):
# Get the master spec and params for this graph. # Get the master spec and params for this graph.
master_spec = self.LoadSpec('ud-hungarian.master-spec') master_spec = self.LoadSpec('ud-hungarian.master-spec')
params_path = os.path.join( params_path = os.path.join(
FLAGS.test_srcdir, 'dragnn/python/testdata' test_flags.source_root(),
'dragnn/python/testdata'
'/ud-hungarian.params') '/ud-hungarian.params')
# Export the graph via SavedModel. (Here, we maintain a handle to the graph # Export the graph via SavedModel. (Here, we maintain a handle to the graph
# for comparison, but that's usually not necessary.) # for comparison, but that's usually not necessary.)
export_path = os.path.join(FLAGS.test_tmpdir, 'export') export_path = os.path.join(test_flags.temp_dir(), 'export')
dragnn_model_saver_lib.clean_output_paths(export_path)
saver_graph = tf.Graph() saver_graph = tf.Graph()
shortened_to_original = dragnn_model_saver_lib.shorten_resource_paths( shortened_to_original = dragnn_model_saver_lib.shorten_resource_paths(
...@@ -102,7 +141,8 @@ class DragnnModelSaverLibTest(test_util.TensorFlowTestCase): ...@@ -102,7 +141,8 @@ class DragnnModelSaverLibTest(test_util.TensorFlowTestCase):
params_path, params_path,
export_path, export_path,
saver_graph, saver_graph,
export_moving_averages=False) export_moving_averages=False,
build_runtime_graph=False)
# Export the assets as well. # Export the assets as well.
dragnn_model_saver_lib.export_assets(master_spec, shortened_to_original, dragnn_model_saver_lib.export_assets(master_spec, shortened_to_original,
...@@ -126,6 +166,165 @@ class DragnnModelSaverLibTest(test_util.TensorFlowTestCase): ...@@ -126,6 +166,165 @@ class DragnnModelSaverLibTest(test_util.TensorFlowTestCase):
tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING],
export_path) export_path)
averaged_hook_name, non_averaged_hook_name, _ = self.GetHookNodeNames(
master_spec)
# Check that the averaged runtime hook node does not exist.
with self.assertRaises(KeyError):
restored_graph.get_operation_by_name(averaged_hook_name)
# Check that the non-averaged version also does not exist.
with self.assertRaises(KeyError):
restored_graph.get_operation_by_name(non_averaged_hook_name)
def testModelExportWithAveragesAndHooks(self):
# Get the master spec and params for this graph.
master_spec = self.LoadSpec('ud-hungarian.master-spec')
params_path = os.path.join(
test_flags.source_root(),
'dragnn/python/testdata'
'/ud-hungarian.params')
# Export the graph via SavedModel. (Here, we maintain a handle to the graph
# for comparison, but that's usually not necessary.) Note that the export
# path must not already exist.
export_path = os.path.join(test_flags.temp_dir(), 'export2')
dragnn_model_saver_lib.clean_output_paths(export_path)
saver_graph = tf.Graph()
shortened_to_original = dragnn_model_saver_lib.shorten_resource_paths(
master_spec)
dragnn_model_saver_lib.export_master_spec(master_spec, saver_graph)
dragnn_model_saver_lib.export_to_graph(
master_spec,
params_path,
export_path,
saver_graph,
export_moving_averages=True,
build_runtime_graph=True)
# Export the assets as well.
dragnn_model_saver_lib.export_assets(master_spec, shortened_to_original,
export_path)
# Validate that the assets are all in the exported directory.
path_set = self.ValidateAssetExistence(master_spec, export_path)
# This master-spec has 4 unique assets. If there are more, we have not
# uniquified the assets properly.
self.assertEqual(len(path_set), 4)
# Restore the graph from the checkpoint into a new Graph object.
restored_graph = tf.Graph()
restoration_config = tf.ConfigProto(
log_device_placement=False,
intra_op_parallelism_threads=10,
inter_op_parallelism_threads=10)
with tf.Session(graph=restored_graph, config=restoration_config) as sess:
tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING],
export_path)
averaged_hook_name, non_averaged_hook_name, cell_subgraph_hook_name = (
self.GetHookNodeNames(master_spec))
# Check that an averaged runtime hook node exists.
restored_graph.get_operation_by_name(averaged_hook_name)
# Check that the non-averaged version does not exist.
with self.assertRaises(KeyError):
restored_graph.get_operation_by_name(non_averaged_hook_name)
# Load the cell subgraph.
cell_subgraph_bytes = restored_graph.get_tensor_by_name(
cell_subgraph_hook_name + ':0')
cell_subgraph_bytes = cell_subgraph_bytes.eval(
feed_dict={'annotation/ComputeSession/InputBatch:0': []})
cell_subgraph_spec = export_pb2.CellSubgraphSpec()
cell_subgraph_spec.ParseFromString(cell_subgraph_bytes)
tf.logging.info('cell_subgraph_spec = %s', cell_subgraph_spec)
# Sanity check inputs.
for cell_input in cell_subgraph_spec.input:
self.assertGreater(len(cell_input.name), 0)
self.assertGreater(len(cell_input.tensor), 0)
self.assertNotEqual(cell_input.type,
export_pb2.CellSubgraphSpec.Input.TYPE_UNKNOWN)
restored_graph.get_tensor_by_name(cell_input.tensor) # shouldn't raise
# Sanity check outputs.
for cell_output in cell_subgraph_spec.output:
self.assertGreater(len(cell_output.name), 0)
self.assertGreater(len(cell_output.tensor), 0)
restored_graph.get_tensor_by_name(cell_output.tensor) # shouldn't raise
# GetHookNodeNames() finds a component with a fixed feature, so at least the
# first feature ID should exist.
self.assertTrue(
any(cell_input.name == 'fixed_channel_0_index_0_ids'
for cell_input in cell_subgraph_spec.input))
# Most dynamic components produce a logits layer.
self.assertTrue(
any(cell_output.name == 'logits'
for cell_output in cell_subgraph_spec.output))
def testModelExportProducesRunnableModel(self):
# Get the master spec and params for this graph.
master_spec = self.LoadSpec('ud-hungarian.master-spec')
params_path = os.path.join(
test_flags.source_root(),
'dragnn/python/testdata'
'/ud-hungarian.params')
# Export the graph via SavedModel. (Here, we maintain a handle to the graph
# for comparison, but that's usually not necessary.)
export_path = os.path.join(test_flags.temp_dir(), 'export')
dragnn_model_saver_lib.clean_output_paths(export_path)
saver_graph = tf.Graph()
shortened_to_original = dragnn_model_saver_lib.shorten_resource_paths(
master_spec)
dragnn_model_saver_lib.export_master_spec(master_spec, saver_graph)
dragnn_model_saver_lib.export_to_graph(
master_spec,
params_path,
export_path,
saver_graph,
export_moving_averages=False,
build_runtime_graph=False)
# Export the assets as well.
dragnn_model_saver_lib.export_assets(master_spec, shortened_to_original,
export_path)
# Restore the graph from the checkpoint into a new Graph object.
restored_graph = tf.Graph()
restoration_config = tf.ConfigProto(
log_device_placement=False,
intra_op_parallelism_threads=10,
inter_op_parallelism_threads=10)
with tf.Session(graph=restored_graph, config=restoration_config) as sess:
tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING],
export_path)
test_doc = sentence_pb2.Sentence()
text_format.Parse(_DUMMY_TEST_SENTENCE, test_doc)
test_reader_string = test_doc.SerializeToString()
test_inputs = [test_reader_string]
tf_out = sess.run(
'annotation/annotations:0',
feed_dict={'annotation/ComputeSession/InputBatch:0': test_inputs})
# We don't care about accuracy, only that the session run doesn't crash.
del tf_out
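As a hedged follow-up to the test above (not part of the test itself): assuming 'annotation/annotations:0' yields one serialized Sentence proto per input, the annotations could be decoded like this.
# Hedged sketch; assumes tf_out holds serialized syntaxnet Sentence protos.
annotated = sentence_pb2.Sentence.FromString(tf_out[0])
for token in annotated.token:
  tf.logging.info('word=%s tag=%s head=%d', token.word, token.tag, token.head)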
if __name__ == '__main__': if __name__ == '__main__':
googletest.main() googletest.main()
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Diff test that compares two files are identical."""
from absl import flags
import tensorflow as tf
FLAGS = flags.FLAGS
flags.DEFINE_string('actual_file', None, 'File to test.')
flags.DEFINE_string('expected_file', None, 'File with expected contents.')
class DiffTest(tf.test.TestCase):
def testEqualFiles(self):
content_actual = None
content_expected = None
try:
with open(FLAGS.actual_file) as actual:
content_actual = actual.read()
except IOError as e:
self.fail("Error opening '%s': %s" % (FLAGS.actual_file, e.strerror))
try:
with open(FLAGS.expected_file) as expected:
content_expected = expected.read()
except IOError as e:
self.fail("Error opening '%s': %s" % (FLAGS.expected_file, e.strerror))
self.assertEqual(content_actual, content_expected)
if __name__ == '__main__':
tf.test.main()
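A hedged usage note for the diff test above: the two flags defined at the top select the files to compare, so a direct invocation (the script name and paths are placeholders) might look like the following.
# Hypothetical invocation; script name and file paths are placeholders.
#   python diff_test.py \
#     --actual_file=/tmp/generated_output.txt \
#     --expected_file=testdata/golden_output.txt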
...@@ -28,7 +28,7 @@ from syntaxnet.util import check ...@@ -28,7 +28,7 @@ from syntaxnet.util import check
try: try:
tf.NotDifferentiable('ExtractFixedFeatures') tf.NotDifferentiable('ExtractFixedFeatures')
except KeyError as e: except KeyError, e:
logging.info(str(e)) logging.info(str(e))
...@@ -179,6 +179,8 @@ class MasterBuilder(object): ...@@ -179,6 +179,8 @@ class MasterBuilder(object):
optimizer: handle to the tf.train Optimizer object used to train this model. optimizer: handle to the tf.train Optimizer object used to train this model.
master_vars: dictionary of globally shared tf.Variable objects (e.g. master_vars: dictionary of globally shared tf.Variable objects (e.g.
the global training step and learning rate.) the global training step and learning rate.)
read_from_avg: Whether to use averaged params instead of normal params.
build_runtime_graph: Whether to build a graph for use by the runtime.
""" """
def __init__(self, master_spec, hyperparam_config=None, pool_scope='shared'): def __init__(self, master_spec, hyperparam_config=None, pool_scope='shared'):
...@@ -197,14 +199,15 @@ class MasterBuilder(object): ...@@ -197,14 +199,15 @@ class MasterBuilder(object):
ValueError: if a component is not found in the registry. ValueError: if a component is not found in the registry.
""" """
self.spec = master_spec self.spec = master_spec
self.hyperparams = (spec_pb2.GridPoint() self.hyperparams = (
if hyperparam_config is None else hyperparam_config) spec_pb2.GridPoint()
if hyperparam_config is None else hyperparam_config)
_validate_grid_point(self.hyperparams) _validate_grid_point(self.hyperparams)
self.pool_scope = pool_scope self.pool_scope = pool_scope
# Set the graph-level random seed before creating the Components so the ops # Set the graph-level random seed before creating the Components so the ops
# they create will use this seed. # they create will use this seed.
tf.set_random_seed(hyperparam_config.seed) tf.set_random_seed(self.hyperparams.seed)
# Construct all utility class and variables for each Component. # Construct all utility class and variables for each Component.
self.components = [] self.components = []
...@@ -219,19 +222,37 @@ class MasterBuilder(object): ...@@ -219,19 +222,37 @@ class MasterBuilder(object):
self.lookup_component[comp.name] = comp self.lookup_component[comp.name] = comp
self.components.append(comp) self.components.append(comp)
# Add global step variable.
self.master_vars = {} self.master_vars = {}
with tf.variable_scope('master', reuse=False): with tf.variable_scope('master', reuse=False):
# Add global step variable.
self.master_vars['step'] = tf.get_variable( self.master_vars['step'] = tf.get_variable(
'step', [], initializer=tf.zeros_initializer(), dtype=tf.int32) 'step', [], initializer=tf.zeros_initializer(), dtype=tf.int32)
self.master_vars['learning_rate'] = _create_learning_rate(
self.hyperparams, self.master_vars['step']) # Add learning rate. If the learning rate is optimized externally, then
# just create an assign op.
if self.hyperparams.pbt_optimize_learning_rate:
self.master_vars['learning_rate'] = tf.get_variable(
'learning_rate',
initializer=tf.constant(
self.hyperparams.learning_rate, dtype=tf.float32))
lr_assign_input = tf.placeholder(tf.float32, [],
'pbt/assign/learning_rate/Value')
tf.assign(
self.master_vars['learning_rate'],
value=lr_assign_input,
name='pbt/assign/learning_rate')
else:
self.master_vars['learning_rate'] = _create_learning_rate(
self.hyperparams, self.master_vars['step'])
# Construct optimizer. # Construct optimizer.
self.optimizer = _create_optimizer(self.hyperparams, self.optimizer = _create_optimizer(self.hyperparams,
self.master_vars['learning_rate'], self.master_vars['learning_rate'],
self.master_vars['step']) self.master_vars['step'])
self.read_from_avg = False
self.build_runtime_graph = False
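To clarify the PBT branch above, here is a minimal sketch of how an external framework might drive the learning-rate assign op. The op and placeholder names are assumptions taken from the code above and may carry a scope prefix (e.g. 'master/') depending on where the graph was constructed.
# Minimal sketch; fetch/feed names are assumptions and may be scope-prefixed.
new_learning_rate = 0.05
sess.run('pbt/assign/learning_rate',
         feed_dict={'pbt/assign/learning_rate/Value:0': new_learning_rate})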
@property @property
def component_names(self): def component_names(self):
return tuple(c.name for c in self.components) return tuple(c.name for c in self.components)
...@@ -366,8 +387,9 @@ class MasterBuilder(object): ...@@ -366,8 +387,9 @@ class MasterBuilder(object):
max_index = len(self.components) max_index = len(self.components)
else: else:
if not 0 < max_index <= len(self.components): if not 0 < max_index <= len(self.components):
raise IndexError('Invalid max_index {} for components {}; handle {}'. raise IndexError(
format(max_index, self.component_names, handle.name)) 'Invalid max_index {} for components {}; handle {}'.format(
max_index, self.component_names, handle.name))
# By default, we train every component supervised. # By default, we train every component supervised.
if not component_weights: if not component_weights:
...@@ -375,6 +397,11 @@ class MasterBuilder(object): ...@@ -375,6 +397,11 @@ class MasterBuilder(object):
if not unroll_using_oracle: if not unroll_using_oracle:
unroll_using_oracle = [True] * max_index unroll_using_oracle = [True] * max_index
if not max_index <= len(unroll_using_oracle):
raise IndexError(('Invalid max_index {} for unroll_using_oracle {}; '
'handle {}').format(max_index, unroll_using_oracle,
handle.name))
component_weights = component_weights[:max_index] component_weights = component_weights[:max_index]
total_weight = (float)(sum(component_weights)) total_weight = (float)(sum(component_weights))
component_weights = [w / total_weight for w in component_weights] component_weights = [w / total_weight for w in component_weights]
...@@ -408,10 +435,10 @@ class MasterBuilder(object): ...@@ -408,10 +435,10 @@ class MasterBuilder(object):
args = (master_state, network_states) args = (master_state, network_states)
if unroll_using_oracle[component_index]: if unroll_using_oracle[component_index]:
handle, component_cost, component_correct, component_total = (tf.cond( handle, component_cost, component_correct, component_total = (
comp.training_beam_size > 1, tf.cond(comp.training_beam_size > 1,
lambda: comp.build_structured_training(*args), lambda: comp.build_structured_training(*args),
lambda: comp.build_greedy_training(*args))) lambda: comp.build_greedy_training(*args)))
else: else:
handle = comp.build_greedy_inference(*args, during_training=True) handle = comp.build_greedy_inference(*args, during_training=True)
...@@ -445,6 +472,7 @@ class MasterBuilder(object): ...@@ -445,6 +472,7 @@ class MasterBuilder(object):
# 1. compute the gradients, # 1. compute the gradients,
# 2. add an optimizer to update the parameters using the gradients, # 2. add an optimizer to update the parameters using the gradients,
# 3. make the ComputeSession handle depend on the optimizer. # 3. make the ComputeSession handle depend on the optimizer.
gradient_norm = tf.constant(0.)
if compute_gradients: if compute_gradients:
logging.info('Creating train op with %d variables:\n\t%s', logging.info('Creating train op with %d variables:\n\t%s',
len(params_to_train), len(params_to_train),
...@@ -452,8 +480,11 @@ class MasterBuilder(object): ...@@ -452,8 +480,11 @@ class MasterBuilder(object):
grads_and_vars = self.optimizer.compute_gradients( grads_and_vars = self.optimizer.compute_gradients(
cost, var_list=params_to_train) cost, var_list=params_to_train)
clipped_gradients = [(self._clip_gradients(g), v) clipped_gradients = [
for g, v in grads_and_vars] (self._clip_gradients(g), v) for g, v in grads_and_vars
]
gradient_norm = tf.global_norm(list(zip(*clipped_gradients))[0])
minimize_op = self.optimizer.apply_gradients( minimize_op = self.optimizer.apply_gradients(
clipped_gradients, global_step=self.master_vars['step']) clipped_gradients, global_step=self.master_vars['step'])
...@@ -474,6 +505,7 @@ class MasterBuilder(object): ...@@ -474,6 +505,7 @@ class MasterBuilder(object):
# Returns named access to common outputs. # Returns named access to common outputs.
outputs = { outputs = {
'cost': cost, 'cost': cost,
'gradient_norm': gradient_norm,
'batch': effective_batch, 'batch': effective_batch,
'metrics': metrics, 'metrics': metrics,
} }
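Because the gradient norm is now returned alongside the cost, a training loop can monitor it directly. This is a hedged sketch: 'train_outputs' stands for the named-outputs dictionary built above, and 'input_batch' for whatever placeholder feeds serialized training sentences.
# Hedged sketch; 'train_outputs' and 'input_batch' are placeholders for the
# objects built elsewhere in the training graph.
cost_val, grad_norm_val = sess.run(
    [train_outputs['cost'], train_outputs['gradient_norm']],
    feed_dict={input_batch: serialized_training_sentences})
tf.logging.info('cost=%f gradient_norm=%f', cost_val, grad_norm_val)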
...@@ -520,7 +552,10 @@ class MasterBuilder(object): ...@@ -520,7 +552,10 @@ class MasterBuilder(object):
with tf.control_dependencies(control_ops): with tf.control_dependencies(control_ops):
return tf.no_op(name='post_restore_hook_master') return tf.no_op(name='post_restore_hook_master')
def build_inference(self, handle, use_moving_average=False): def build_inference(self,
handle,
use_moving_average=False,
build_runtime_graph=False):
"""Builds an inference pipeline. """Builds an inference pipeline.
This always uses the whole pipeline. This always uses the whole pipeline.
...@@ -530,25 +565,30 @@ class MasterBuilder(object): ...@@ -530,25 +565,30 @@ class MasterBuilder(object):
use_moving_average: Whether or not to read from the moving use_moving_average: Whether or not to read from the moving
average variables instead of the true parameters. Note: it is not average variables instead of the true parameters. Note: it is not
possible to make gradient updates when this is True. possible to make gradient updates when this is True.
build_runtime_graph: Whether to build a graph for use by the runtime.
Returns: Returns:
handle: Handle after annotation. handle: Handle after annotation.
""" """
self.read_from_avg = use_moving_average self.read_from_avg = use_moving_average
self.build_runtime_graph = build_runtime_graph
network_states = {} network_states = {}
for comp in self.components: for comp in self.components:
network_states[comp.name] = component.NetworkState() network_states[comp.name] = component.NetworkState()
handle = dragnn_ops.init_component_data( handle = dragnn_ops.init_component_data(
handle, beam_size=comp.inference_beam_size, component=comp.name) handle, beam_size=comp.inference_beam_size, component=comp.name)
master_state = component.MasterState(handle, if build_runtime_graph:
dragnn_ops.batch_size( batch_size = 1 # runtime uses singleton batches
handle, component=comp.name)) else:
batch_size = dragnn_ops.batch_size(handle, component=comp.name)
master_state = component.MasterState(handle, batch_size)
with tf.control_dependencies([handle]): with tf.control_dependencies([handle]):
handle = comp.build_greedy_inference(master_state, network_states) handle = comp.build_greedy_inference(master_state, network_states)
handle = dragnn_ops.write_annotations(handle, component=comp.name) handle = dragnn_ops.write_annotations(handle, component=comp.name)
self.read_from_avg = False self.read_from_avg = False
self.build_runtime_graph = False
return handle return handle
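A quick, hedged call-site sketch for the extended signature (the 'builder' and 'handle' names are placeholders): setting build_runtime_graph=True also fixes the batch size to 1, matching the runtime's singleton batches.
# Hypothetical call site; 'builder' is a MasterBuilder, 'handle' a
# ComputeSession handle obtained elsewhere.
annotated_handle = builder.build_inference(
    handle, use_moving_average=True, build_runtime_graph=True)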
def add_training_from_config(self, def add_training_from_config(self,
...@@ -625,7 +665,10 @@ class MasterBuilder(object): ...@@ -625,7 +665,10 @@ class MasterBuilder(object):
return self._outputs_with_release(handle, {'input_batch': input_batch}, return self._outputs_with_release(handle, {'input_batch': input_batch},
outputs) outputs)
def add_annotation(self, name_scope='annotation', enable_tracing=False): def add_annotation(self,
name_scope='annotation',
enable_tracing=False,
build_runtime_graph=False):
"""Adds an annotation pipeline to the graph. """Adds an annotation pipeline to the graph.
This will create the following additional named targets by default, for use This will create the following additional named targets by default, for use
...@@ -640,13 +683,17 @@ class MasterBuilder(object): ...@@ -640,13 +683,17 @@ class MasterBuilder(object):
enable_tracing: Enabling this will result in two things: enable_tracing: Enabling this will result in two things:
1. Tracing will be enabled during inference. 1. Tracing will be enabled during inference.
2. A 'traces' node will be added to the outputs. 2. A 'traces' node will be added to the outputs.
build_runtime_graph: Whether to build a graph for use by the runtime.
Returns: Returns:
A dictionary of input and output nodes. A dictionary of input and output nodes.
""" """
with tf.name_scope(name_scope): with tf.name_scope(name_scope):
handle, input_batch = self._get_session_with_reader(enable_tracing) handle, input_batch = self._get_session_with_reader(enable_tracing)
handle = self.build_inference(handle, use_moving_average=True) handle = self.build_inference(
handle,
use_moving_average=True,
build_runtime_graph=build_runtime_graph)
annotations = dragnn_ops.emit_annotations( annotations = dragnn_ops.emit_annotations(
handle, component=self.spec.component[-1].name) handle, component=self.spec.component[-1].name)
...@@ -666,7 +713,7 @@ class MasterBuilder(object): ...@@ -666,7 +713,7 @@ class MasterBuilder(object):
def add_saver(self): def add_saver(self):
"""Adds a Saver for all variables in the graph.""" """Adds a Saver for all variables in the graph."""
logging.info('Saving variables:\n\t%s', logging.info('Generating op to save variables:\n\t%s',
'\n\t'.join([x.name for x in tf.global_variables()])) '\n\t'.join([x.name for x in tf.global_variables()]))
self.saver = tf.train.Saver( self.saver = tf.train.Saver(
var_list=[x for x in tf.global_variables()], var_list=[x for x in tf.global_variables()],
......
...@@ -20,7 +20,6 @@ import os.path ...@@ -20,7 +20,6 @@ import os.path
import numpy as np import numpy as np
from six.moves import xrange
import tensorflow as tf import tensorflow as tf
from google.protobuf import text_format from google.protobuf import text_format
...@@ -30,13 +29,12 @@ from dragnn.protos import trace_pb2 ...@@ -30,13 +29,12 @@ from dragnn.protos import trace_pb2
from dragnn.python import dragnn_ops from dragnn.python import dragnn_ops
from dragnn.python import graph_builder from dragnn.python import graph_builder
from syntaxnet import sentence_pb2 from syntaxnet import sentence_pb2
from syntaxnet import test_flags
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest from tensorflow.python.platform import googletest
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
FLAGS = tf.app.flags.FLAGS
_DUMMY_GOLD_SENTENCE = """ _DUMMY_GOLD_SENTENCE = """
token { token {
...@@ -151,13 +149,6 @@ token { ...@@ -151,13 +149,6 @@ token {
] ]
def setUpModule():
if not hasattr(FLAGS, 'test_srcdir'):
FLAGS.test_srcdir = ''
if not hasattr(FLAGS, 'test_tmpdir'):
FLAGS.test_tmpdir = tf.test.get_temp_dir()
def _as_op(x): def _as_op(x):
"""Always returns the tf.Operation associated with a node.""" """Always returns the tf.Operation associated with a node."""
return x.op if isinstance(x, tf.Tensor) else x return x.op if isinstance(x, tf.Tensor) else x
...@@ -244,7 +235,7 @@ class GraphBuilderTest(test_util.TensorFlowTestCase): ...@@ -244,7 +235,7 @@ class GraphBuilderTest(test_util.TensorFlowTestCase):
def LoadSpec(self, spec_path): def LoadSpec(self, spec_path):
master_spec = spec_pb2.MasterSpec() master_spec = spec_pb2.MasterSpec()
testdata = os.path.join(FLAGS.test_srcdir, testdata = os.path.join(test_flags.source_root(),
'dragnn/core/testdata') 'dragnn/core/testdata')
with open(os.path.join(testdata, spec_path), 'r') as fin: with open(os.path.join(testdata, spec_path), 'r') as fin:
text_format.Parse(fin.read().replace('TESTDATA', testdata), master_spec) text_format.Parse(fin.read().replace('TESTDATA', testdata), master_spec)
...@@ -445,7 +436,7 @@ class GraphBuilderTest(test_util.TensorFlowTestCase): ...@@ -445,7 +436,7 @@ class GraphBuilderTest(test_util.TensorFlowTestCase):
self.assertEqual(expected_num_actions, correct_val) self.assertEqual(expected_num_actions, correct_val)
self.assertEqual(expected_num_actions, total_val) self.assertEqual(expected_num_actions, total_val)
builder.saver.save(sess, os.path.join(FLAGS.test_tmpdir, 'model')) builder.saver.save(sess, os.path.join(test_flags.temp_dir(), 'model'))
logging.info('Running test.') logging.info('Running test.')
logging.info('Printing annotations') logging.info('Printing annotations')
......
...@@ -27,8 +27,7 @@ from dragnn.python import lexicon ...@@ -27,8 +27,7 @@ from dragnn.python import lexicon
from syntaxnet import parser_trainer from syntaxnet import parser_trainer
from syntaxnet import task_spec_pb2 from syntaxnet import task_spec_pb2
from syntaxnet import test_flags
FLAGS = tf.app.flags.FLAGS
_EXPECTED_CONTEXT = r""" _EXPECTED_CONTEXT = r"""
...@@ -46,13 +45,6 @@ input { name: "known-word-map" Part { file_pattern: "/tmp/known-word-map" } } ...@@ -46,13 +45,6 @@ input { name: "known-word-map" Part { file_pattern: "/tmp/known-word-map" } }
""" """
def setUpModule():
if not hasattr(FLAGS, 'test_srcdir'):
FLAGS.test_srcdir = ''
if not hasattr(FLAGS, 'test_tmpdir'):
FLAGS.test_tmpdir = tf.test.get_temp_dir()
class LexiconTest(tf.test.TestCase): class LexiconTest(tf.test.TestCase):
def testCreateLexiconContext(self): def testCreateLexiconContext(self):
...@@ -62,8 +54,8 @@ class LexiconTest(tf.test.TestCase): ...@@ -62,8 +54,8 @@ class LexiconTest(tf.test.TestCase):
lexicon.create_lexicon_context('/tmp'), expected_context) lexicon.create_lexicon_context('/tmp'), expected_context)
def testBuildLexicon(self): def testBuildLexicon(self):
empty_input_path = os.path.join(FLAGS.test_tmpdir, 'empty-input') empty_input_path = os.path.join(test_flags.temp_dir(), 'empty-input')
lexicon_output_path = os.path.join(FLAGS.test_tmpdir, 'lexicon-output') lexicon_output_path = os.path.join(test_flags.temp_dir(), 'lexicon-output')
with open(empty_input_path, 'w'): with open(empty_input_path, 'w'):
pass pass
......
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loads mst_ops shared library."""
import os.path
import tensorflow as tf
tf.load_op_library(
os.path.join(tf.resource_loader.get_data_files_path(), 'mst_cc_impl.so'))
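One hedged note on the loader above: tf.load_op_library also returns a Python module whose attributes are the ops compiled into the shared object, so a caller needing direct access could capture it (the variable name here is illustrative).
# Illustrative only: capture the module instead of discarding the return value.
_mst_ops_module = tf.load_op_library(
    os.path.join(tf.resource_loader.get_data_files_path(), 'mst_cc_impl.so'))
# Ops defined in mst_cc_impl.so are available as attributes of _mst_ops_module.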