Unverified commit 80178fc6 authored by Mark Omernick, committed by GitHub

Merge pull request #4153 from terryykoo/master

Export @195097388.
parents a84e1ef9 edea2b67
......@@ -16,7 +16,7 @@ message MasterPerformanceSettings {
// Maximum size of the free list in the SessionStatePool. NB: The default
// value may occasionally change.
optional uint64 session_state_pool_max_free_states = 1 [default = 4];
optional uint64 session_state_pool_max_free_states = 1 [default = 16];
}
// As above, but for component-specific performance tuning settings.
......
// DRAGNN Configuration proto. See go/dragnn-design for more information.
// DRAGNN Configuration proto.
syntax = "proto2";
......@@ -93,7 +94,7 @@ message Part {
// are extracted, embedded, and then concatenated together as a group.
// Specification for a feature channel that is a *fixed* function of the input.
// NEXT_ID: 10
// NEXT_ID: 12
message FixedFeatureChannel {
// Interpretable name for this feature channel. NN builders might depend on
// this to determine how to hook different channels up internally.
......@@ -129,6 +130,19 @@ message FixedFeatureChannel {
// Vocab file, containing all vocabulary words one per line.
optional Resource vocab = 8;
// Settings for feature ID dropout:
// If non-negative, enables feature ID dropout, and dropped feature IDs will
// be replaced with this ID.
optional int64 dropout_id = 10 [default = -1];
// Probability of keeping each of the |vocabulary_size| feature IDs. Only
// used if |dropout_id| is non-negative, and must not be empty in that case.
// If this has fewer than |vocabulary_size| values, then the final value is
// tiled onto the remaining IDs. For example, specifying a single value is
// equivalent to setting all IDs to that value.
repeated float dropout_keep_probability = 11 [packed = true];
}
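The tiling behavior of dropout_keep_probability can be illustrated with a small standalone sketch (a hypothetical helper, not part of this change, that only mirrors the semantics described in the comment above):
import numpy as np

def expand_keep_probabilities(keep_probs, vocabulary_size):
  """Tiles the final keep probability onto any remaining feature IDs."""
  assert keep_probs, 'dropout_keep_probability must not be empty'
  expanded = list(keep_probs[:vocabulary_size])
  # The last explicit value covers every ID beyond those given explicitly.
  expanded += [keep_probs[-1]] * (vocabulary_size - len(expanded))
  return np.array(expanded, dtype=np.float32)

print(expand_keep_probabilities([0.9], 4))       # [0.9 0.9 0.9 0.9]
print(expand_keep_probabilities([1.0, 0.5], 4))  # [1.  0.5 0.5 0.5]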
// Specification for a feature channel that *links* to component
......@@ -173,11 +187,17 @@ message TrainingGridSpec {
}
// A hyperparameter configuration for a training run.
// NEXT ID: 22
// NEXT ID: 23
message GridPoint {
// Global learning rate initialization point.
optional double learning_rate = 1 [default = 0.1];
// Whether to use PBT (population-based training) to optimize the learning
// rate. Population-based training is not currently open-source, so this will
// just create a tf.assign op which external frameworks can use to adjust the
// learning rate.
optional bool pbt_optimize_learning_rate = 22 [default = false];
// Momentum coefficient when using MomentumOptimizer.
optional double momentum = 2 [default = 0.9];
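A minimal TF1-style sketch of the pattern the pbt_optimize_learning_rate comment above describes: a learning-rate variable paired with a tf.assign op that an external controller (for example a population-based-training tuner) can run to adjust the rate. The variable and placeholder names here are illustrative, not the ones DRAGNN itself creates.
import tensorflow as tf

learning_rate = tf.get_variable(
    'learning_rate', [], initializer=tf.constant_initializer(0.1),
    trainable=False)
new_learning_rate = tf.placeholder(tf.float32, [], name='new_learning_rate')
set_learning_rate = tf.assign(learning_rate, new_learning_rate)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # An external framework picks a new value and applies it via the assign op.
  sess.run(set_learning_rate, feed_dict={new_learning_rate: 0.05})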
......
......@@ -53,6 +53,8 @@ message ComponentStepTrace {
// Set to true once the step is finished. (This allows us to open a step after
// each transition, without having to know if it will be used.)
optional bool step_finished = 6 [default = false];
extensions 1000 to max;
}
// The traces for all steps for a single Component.
......
......@@ -16,6 +16,17 @@ cc_binary(
],
)
cc_binary(
name = "mst_cc_impl.so",
linkopts = select({
"//conditions:default": ["-lm"],
"@org_tensorflow//tensorflow:darwin": [],
}),
linkshared = 1,
linkstatic = 1,
deps = ["//dragnn/mst:mst_ops_cc"],
)
filegroup(
name = "testdata",
data = glob(["testdata/**"]),
......@@ -27,6 +38,12 @@ py_library(
data = [":dragnn_cc_impl.so"],
)
py_library(
name = "load_mst_cc_impl_py",
srcs = ["load_mst_cc_impl.py"],
data = [":mst_cc_impl.so"],
)
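For reference, a loader module like load_mst_cc_impl.py typically just loads the shared object with tf.load_op_library; the sketch below assumes it mirrors the existing load_dragnn_cc_impl.py pattern rather than quoting the actual file.
import os.path
import tensorflow as tf

# Load the custom MST kernels so the Python op wrappers can find them.
tf.load_op_library(
    os.path.join(tf.resource_loader.get_data_files_path(), 'mst_cc_impl.so'))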
py_library(
name = "bulk_component",
srcs = [
......@@ -50,6 +67,8 @@ py_library(
":bulk_component",
":dragnn_ops",
":network_units",
":runtime_support",
"//dragnn/protos:export_pb2_py",
"//syntaxnet/util:check",
"//syntaxnet/util:pyregistry",
"@org_tensorflow//tensorflow:tensorflow_py",
......@@ -85,9 +104,9 @@ py_library(
":graph_builder",
":load_dragnn_cc_impl_py",
":network_units",
"//dragnn/protos:spec_py_pb2",
"//dragnn/protos:spec_pb2_py",
"//syntaxnet:load_parser_ops_py",
"//syntaxnet:sentence_py_pb2",
"//syntaxnet:sentence_pb2_py",
"@org_tensorflow//tensorflow:tensorflow_py",
"@org_tensorflow//tensorflow/core:protos_all_py",
],
......@@ -99,7 +118,9 @@ py_test(
data = [":testdata"],
deps = [
":dragnn_model_saver_lib",
"//dragnn/protos:spec_py_pb2",
"//dragnn/protos:export_pb2_py",
"//dragnn/protos:spec_pb2_py",
"//syntaxnet:test_flags",
"@org_tensorflow//tensorflow:tensorflow_py",
],
)
......@@ -110,7 +131,9 @@ py_binary(
deps = [
":dragnn_model_saver_lib",
":spec_builder",
"//dragnn/protos:spec_py_pb2",
"//dragnn/protos:spec_pb2_py",
"@absl_py//absl:app",
"@absl_py//absl/flags",
"@org_tensorflow//tensorflow:tensorflow_py",
"@org_tensorflow//tensorflow/core:protos_all_py",
],
......@@ -127,7 +150,7 @@ py_library(
":network_units",
":transformer_units",
":wrapped_units",
"//dragnn/protos:spec_py_pb2",
"//dragnn/protos:spec_pb2_py",
"//syntaxnet/util:check",
"@org_tensorflow//tensorflow:tensorflow_py",
"@org_tensorflow//tensorflow/core:protos_all_py",
......@@ -159,7 +182,7 @@ py_test(
srcs = ["render_parse_tree_graphviz_test.py"],
deps = [
":render_parse_tree_graphviz",
"//syntaxnet:sentence_py_pb2",
"//syntaxnet:sentence_pb2_py",
"@org_tensorflow//tensorflow:tensorflow_py",
],
)
......@@ -168,7 +191,7 @@ py_library(
name = "render_spec_with_graphviz",
srcs = ["render_spec_with_graphviz.py"],
deps = [
"//dragnn/protos:spec_py_pb2",
"//dragnn/protos:spec_pb2_py",
],
)
......@@ -197,7 +220,7 @@ py_binary(
"//dragnn/viz:viz-min-js-gz",
],
deps = [
"//dragnn/protos:trace_py_pb2",
"//dragnn/protos:trace_pb2_py",
],
)
......@@ -206,8 +229,8 @@ py_test(
srcs = ["visualization_test.py"],
deps = [
":visualization",
"//dragnn/protos:spec_py_pb2",
"//dragnn/protos:trace_py_pb2",
"//dragnn/protos:spec_pb2_py",
"//dragnn/protos:trace_pb2_py",
"@org_tensorflow//tensorflow:tensorflow_py",
],
)
......@@ -225,6 +248,18 @@ py_library(
# Tests
py_test(
name = "component_test",
srcs = [
"component_test.py",
],
deps = [
":components",
"//dragnn/protos:spec_pb2_py",
"@org_tensorflow//tensorflow:tensorflow_py",
],
)
py_test(
name = "bulk_component_test",
srcs = [
......@@ -235,9 +270,9 @@ py_test(
":components",
":dragnn_ops",
":network_units",
"//dragnn/protos:spec_py_pb2",
"//dragnn/protos:spec_pb2_py",
"//syntaxnet:load_parser_ops_py",
"//syntaxnet:sentence_py_pb2",
"//syntaxnet:sentence_pb2_py",
"@org_tensorflow//tensorflow:tensorflow_py",
"@org_tensorflow//tensorflow/core:protos_all_py",
],
......@@ -270,10 +305,11 @@ py_test(
deps = [
":dragnn_ops",
":graph_builder",
"//dragnn/protos:spec_py_pb2",
"//dragnn/protos:trace_py_pb2",
"//dragnn/protos:spec_pb2_py",
"//dragnn/protos:trace_pb2_py",
"//syntaxnet:load_parser_ops_py",
"//syntaxnet:sentence_py_pb2",
"//syntaxnet:sentence_pb2_py",
"//syntaxnet:test_flags",
"@org_tensorflow//tensorflow:tensorflow_py",
"@org_tensorflow//tensorflow/core:protos_all_py",
],
......@@ -287,7 +323,7 @@ py_test(
":network_units",
"//dragnn/core:dragnn_bulk_ops",
"//dragnn/core:dragnn_ops",
"//dragnn/protos:spec_py_pb2",
"//dragnn/protos:spec_pb2_py",
"//syntaxnet:load_parser_ops_py",
"@org_tensorflow//tensorflow:tensorflow_py",
"@org_tensorflow//tensorflow/core:protos_all_py",
......@@ -303,7 +339,8 @@ py_test(
":sentence_io",
"//syntaxnet:load_parser_ops_py",
"//syntaxnet:parser_ops",
"//syntaxnet:sentence_py_pb2",
"//syntaxnet:sentence_pb2_py",
"//syntaxnet:test_flags",
"@org_tensorflow//tensorflow:tensorflow_py",
"@org_tensorflow//tensorflow/core:protos_all_py",
],
......@@ -313,21 +350,31 @@ py_library(
name = "trainer_lib",
srcs = ["trainer_lib.py"],
deps = [
"//dragnn/protos:spec_py_pb2",
"//dragnn/protos:spec_pb2_py",
"//syntaxnet:parser_ops",
"//syntaxnet:sentence_py_pb2",
"//syntaxnet:task_spec_py_pb2",
"//syntaxnet:sentence_pb2_py",
"//syntaxnet:task_spec_pb2_py",
"//syntaxnet/util:check",
"@org_tensorflow//tensorflow:tensorflow_py",
"@org_tensorflow//tensorflow/core:protos_all_py",
],
)
py_test(
name = "trainer_lib_test",
srcs = ["trainer_lib_test.py"],
deps = [
":trainer_lib",
"@org_tensorflow//tensorflow:tensorflow_py",
],
)
py_library(
name = "lexicon",
srcs = ["lexicon.py"],
deps = [
"//syntaxnet:parser_ops",
"//syntaxnet:task_spec_py_pb2",
"//syntaxnet:task_spec_pb2_py",
"@org_tensorflow//tensorflow:tensorflow_py",
],
)
......@@ -340,6 +387,7 @@ py_test(
"//syntaxnet:load_parser_ops_py",
"//syntaxnet:parser_ops",
"//syntaxnet:parser_trainer",
"//syntaxnet:test_flags",
"@org_tensorflow//tensorflow:tensorflow_py",
],
)
......@@ -348,7 +396,7 @@ py_library(
name = "evaluation",
srcs = ["evaluation.py"],
deps = [
"//syntaxnet:sentence_py_pb2",
"//syntaxnet:sentence_pb2_py",
"//syntaxnet/util:check",
"@org_tensorflow//tensorflow:tensorflow_py",
],
......@@ -359,7 +407,7 @@ py_test(
srcs = ["evaluation_test.py"],
deps = [
":evaluation",
"//syntaxnet:sentence_py_pb2",
"//syntaxnet:sentence_pb2_py",
"@org_tensorflow//tensorflow:tensorflow_py",
],
)
......@@ -369,7 +417,7 @@ py_library(
srcs = ["spec_builder.py"],
deps = [
":lexicon",
"//dragnn/protos:spec_py_pb2",
"//dragnn/protos:spec_pb2_py",
"//syntaxnet:parser_ops",
"//syntaxnet/util:check",
"@org_tensorflow//tensorflow:tensorflow_py",
......@@ -381,7 +429,7 @@ py_test(
srcs = ["spec_builder_test.py"],
deps = [
":spec_builder",
"//dragnn/protos:spec_py_pb2",
"//dragnn/protos:spec_pb2_py",
"//syntaxnet:load_parser_ops_py",
"//syntaxnet:parser_ops",
"//syntaxnet:parser_trainer",
......@@ -418,6 +466,17 @@ py_library(
],
)
py_test(
name = "biaffine_units_test",
srcs = ["biaffine_units_test.py"],
deps = [
":biaffine_units",
":network_units",
"//dragnn/protos:spec_pb2_py",
"@org_tensorflow//tensorflow:tensorflow_py",
],
)
py_library(
name = "transformer_units",
srcs = ["transformer_units.py"],
......@@ -437,10 +496,85 @@ py_test(
":transformer_units",
"//dragnn/core:dragnn_bulk_ops",
"//dragnn/core:dragnn_ops",
"//dragnn/protos:spec_py_pb2",
"//dragnn/protos:spec_pb2_py",
"//syntaxnet:load_parser_ops_py",
"@org_tensorflow//tensorflow:tensorflow_py",
"@org_tensorflow//tensorflow/core:protos_all_py",
],
)
py_library(
name = "runtime_support",
srcs = ["runtime_support.py"],
deps = [
":network_units",
"//syntaxnet/util:check",
"@org_tensorflow//tensorflow:tensorflow_py",
],
)
py_test(
name = "runtime_support_test",
srcs = ["runtime_support_test.py"],
deps = [
":network_units",
":runtime_support",
"//dragnn/protos:export_pb2_py",
"//dragnn/protos:spec_pb2_py",
"@org_tensorflow//tensorflow:tensorflow_py",
],
)
py_library(
name = "file_diff_test",
srcs = ["file_diff_test.py"],
deps = [
"@absl_py//absl/flags",
"@org_tensorflow//tensorflow:tensorflow_py",
],
)
py_library(
name = "mst_ops",
srcs = ["mst_ops.py"],
visibility = ["//visibility:public"],
deps = [
":digraph_ops",
":load_mst_cc_impl_py",
"//dragnn/mst:mst_ops",
"//syntaxnet/util:check",
"@org_tensorflow//tensorflow:tensorflow_py",
],
)
py_test(
name = "mst_ops_test",
srcs = ["mst_ops_test.py"],
deps = [
":mst_ops",
"@org_tensorflow//tensorflow:tensorflow_py",
],
)
py_library(
name = "mst_units",
srcs = ["mst_units.py"],
deps = [
":mst_ops",
":network_units",
"//syntaxnet/util:check",
"@org_tensorflow//tensorflow:tensorflow_py",
],
)
py_test(
name = "mst_units_test",
size = "small",
srcs = ["mst_units_test.py"],
deps = [
":mst_units",
":network_units",
"//dragnn/protos:spec_pb2_py",
"@org_tensorflow//tensorflow:tensorflow_py",
],
)
......@@ -79,24 +79,44 @@ class BiaffineDigraphNetwork(network_units.NetworkUnitInterface):
self._source_dim = self._linked_feature_dims['sources']
self._target_dim = self._linked_feature_dims['targets']
# TODO(googleuser): Make parameter initialization configurable.
self._weights = []
self._weights.append(tf.get_variable(
'weights_arc', [self._source_dim, self._target_dim], tf.float32,
tf.random_normal_initializer(stddev=1e-4)))
self._weights.append(tf.get_variable(
'weights_source', [self._source_dim], tf.float32,
tf.random_normal_initializer(stddev=1e-4)))
self._weights.append(tf.get_variable(
'root', [self._source_dim], tf.float32,
tf.random_normal_initializer(stddev=1e-4)))
self._weights.append(
tf.get_variable('weights_arc', [self._source_dim, self._target_dim],
tf.float32, tf.orthogonal_initializer()))
self._weights.append(
tf.get_variable('weights_source', [self._source_dim], tf.float32,
tf.zeros_initializer()))
self._weights.append(
tf.get_variable('root', [self._source_dim], tf.float32,
tf.zeros_initializer()))
self._params.extend(self._weights)
self._regularized_weights.extend(self._weights)
# Add runtime hooks for pre-computed weights.
self._derived_params.append(self._get_root_weights)
self._derived_params.append(self._get_root_bias)
# Negative Layer.dim indicates that the dimension is dynamic.
self._layers.append(network_units.Layer(component, 'adjacency', -1))
def _get_root_weights(self):
"""Pre-computes the product of the root embedding and arc weights."""
weights_arc = self._component.get_variable('weights_arc')
root = self._component.get_variable('root')
name = self._component.name + '/root_weights'
with tf.name_scope(None):
return tf.matmul(tf.expand_dims(root, 0), weights_arc, name=name)
def _get_root_bias(self):
"""Pre-computes the product of the root embedding and source weights."""
weights_source = self._component.get_variable('weights_source')
root = self._component.get_variable('root')
name = self._component.name + '/root_bias'
with tf.name_scope(None):
return tf.matmul(
tf.expand_dims(root, 0), tf.expand_dims(weights_source, 1), name=name)
def create(self,
fixed_embeddings,
linked_embeddings,
......@@ -133,12 +153,17 @@ class BiaffineDigraphNetwork(network_units.NetworkUnitInterface):
sources_bxnxn = digraph_ops.ArcSourcePotentialsFromTokens(
source_tokens_bxnxs, weights_source)
roots_bxn = digraph_ops.RootPotentialsFromTokens(
root, target_tokens_bxnxt, weights_arc)
root, target_tokens_bxnxt, weights_arc, weights_source)
# Combine them into a single matrix with the roots on the diagonal.
adjacency_bxnxn = digraph_ops.CombineArcAndRootPotentials(
arcs_bxnxn + sources_bxnxn, roots_bxn)
# The adjacency matrix currently has sources on rows and targets on columns,
# but we want targets on rows so that maximizing within a row corresponds to
# selecting sources for a given target.
adjacency_bxnxn = tf.matrix_transpose(adjacency_bxnxn)
return [tf.reshape(adjacency_bxnxn, [-1, num_tokens])]
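As a usage sketch (hypothetical post-processing, not part of this change): because create() returns the adjacency scores flattened to [batch * num_tokens, num_tokens] with targets on rows after the transpose above, reshaping and taking an argmax over the last axis picks the highest-scoring source (head) for each target token.
import tensorflow as tf

def predicted_heads(adjacency_flat, batch_size, num_tokens):
  """Recovers per-token head predictions from the flattened adjacency."""
  adjacency_bxnxn = tf.reshape(adjacency_flat,
                               [batch_size, num_tokens, num_tokens])
  # Row = target token, column = candidate source; argmax selects the head.
  return tf.argmax(adjacency_bxnxn, axis=2)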
......
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for biaffine_units."""
import tensorflow as tf
from google.protobuf import text_format
from dragnn.protos import spec_pb2
from dragnn.python import biaffine_units
from dragnn.python import network_units
_BATCH_SIZE = 11
_NUM_TOKENS = 22
_TOKEN_DIM = 33
class MockNetwork(object):
def __init__(self):
pass
def get_layer_size(self, unused_name):
return _TOKEN_DIM
class MockComponent(object):
def __init__(self, master, component_spec):
self.master = master
self.spec = component_spec
self.name = component_spec.name
self.network = MockNetwork()
self.beam_size = 1
self.num_actions = 45
self._attrs = {}
def attr(self, name):
return self._attrs[name]
def get_variable(self, name):
return tf.get_variable(name)
class MockMaster(object):
def __init__(self):
self.spec = spec_pb2.MasterSpec()
self.hyperparams = spec_pb2.GridPoint()
self.lookup_component = {
'previous': MockComponent(self, spec_pb2.ComponentSpec())
}
def _make_biaffine_spec():
"""Returns a ComponentSpec that the BiaffineDigraphNetwork works on."""
component_spec = spec_pb2.ComponentSpec()
text_format.Parse("""
name: "test_component"
backend { registered_name: "TestComponent" }
linked_feature {
name: "sources"
fml: "input.focus"
source_translator: "identity"
source_component: "previous"
source_layer: "sources"
size: 1
embedding_dim: -1
}
linked_feature {
name: "targets"
fml: "input.focus"
source_translator: "identity"
source_component: "previous"
source_layer: "targets"
size: 1
embedding_dim: -1
}
network_unit {
registered_name: "biaffine_units.BiaffineDigraphNetwork"
}
""", component_spec)
return component_spec
class BiaffineDigraphNetworkTest(tf.test.TestCase):
def setUp(self):
# Clear the graph and all existing variables. Otherwise, variables created
# in different tests may collide with each other.
tf.reset_default_graph()
def testCanCreate(self):
"""Tests that create() works on a good spec."""
with tf.Graph().as_default(), self.test_session():
master = MockMaster()
component = MockComponent(master, _make_biaffine_spec())
with tf.variable_scope(component.name, reuse=None):
component.network = biaffine_units.BiaffineDigraphNetwork(component)
with tf.variable_scope(component.name, reuse=True):
sources = network_units.NamedTensor(
tf.zeros([_BATCH_SIZE * _NUM_TOKENS, _TOKEN_DIM]), 'sources')
targets = network_units.NamedTensor(
tf.zeros([_BATCH_SIZE * _NUM_TOKENS, _TOKEN_DIM]), 'targets')
# No assertions on the result, just don't crash.
component.network.create(
fixed_embeddings=[],
linked_embeddings=[sources, targets],
context_tensor_arrays=None,
attention_tensor=None,
during_training=True,
stride=_BATCH_SIZE)
def testDerivedParametersForRuntime(self):
"""Test generation of derived parameters for the runtime."""
with tf.Graph().as_default(), self.test_session():
master = MockMaster()
component = MockComponent(master, _make_biaffine_spec())
with tf.variable_scope(component.name, reuse=None):
component.network = biaffine_units.BiaffineDigraphNetwork(component)
with tf.variable_scope(component.name, reuse=True):
self.assertEqual(len(component.network.derived_params), 2)
root_weights = component.network.derived_params[0]()
root_bias = component.network.derived_params[1]()
# Only check shape; values are random due to initialization.
self.assertAllEqual(root_weights.shape.as_list(), [1, _TOKEN_DIM])
self.assertAllEqual(root_bias.shape.as_list(), [1, 1])
if __name__ == '__main__':
tf.test.main()
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Component builders for non-recurrent networks in DRAGNN."""
......@@ -51,10 +50,8 @@ def fetch_linked_embedding(comp, network_states, feature_spec):
feature_spec.name)
source = comp.master.lookup_component[feature_spec.source_component]
return network_units.NamedTensor(
network_states[source.name].activations[
feature_spec.source_layer].bulk_tensor,
feature_spec.name)
return network_units.NamedTensor(network_states[source.name].activations[
feature_spec.source_layer].bulk_tensor, feature_spec.name)
def _validate_embedded_fixed_features(comp):
......@@ -63,17 +60,20 @@ def _validate_embedded_fixed_features(comp):
check.Gt(feature.embedding_dim, 0,
'Embeddings requested for non-embedded feature: %s' % feature)
if feature.is_constant:
check.IsTrue(feature.HasField('pretrained_embedding_matrix'),
'Constant embeddings must be pretrained: %s' % feature)
check.IsTrue(
feature.HasField('pretrained_embedding_matrix'),
'Constant embeddings must be pretrained: %s' % feature)
def fetch_differentiable_fixed_embeddings(comp, state, stride):
def fetch_differentiable_fixed_embeddings(comp, state, stride, during_training):
"""Looks up fixed features with separate, differentiable, embedding lookup.
Args:
comp: Component whose fixed features we wish to look up.
state: live MasterState object for the component.
stride: Tensor containing current batch * beam size.
during_training: True if this is being called from a training code path.
This controls, e.g., the use of feature ID dropout.
Returns:
state handle: updated state handle to be used after this call
......@@ -93,6 +93,11 @@ def fetch_differentiable_fixed_embeddings(comp, state, stride):
'differentiable')
tf.logging.info('[%s] Adding %s fixed feature "%s"', comp.name,
differentiable_or_constant, feature_spec.name)
if during_training and feature_spec.dropout_id >= 0:
ids[channel], weights[channel] = network_units.apply_feature_id_dropout(
ids[channel], weights[channel], feature_spec)
size = stride * num_steps * feature_spec.size
fixed_embedding = network_units.embedding_lookup(
comp.get_variable(network_units.fixed_embeddings_name(channel)),
......@@ -105,16 +110,22 @@ def fetch_differentiable_fixed_embeddings(comp, state, stride):
return state.handle, fixed_embeddings
def fetch_fast_fixed_embeddings(comp, state):
def fetch_fast_fixed_embeddings(comp,
state,
pad_to_batch=None,
pad_to_steps=None):
"""Looks up fixed features with fast, non-differentiable, op.
Since BulkFixedEmbeddings is non-differentiable with respect to the
embeddings, the idea is to call this function only when the graph is
not being used for training.
not being used for training. If the function is being called with fixed step
and batch sizes, it will use the most efficient possible extractor.
Args:
comp: Component whose fixed features we wish to look up.
state: live MasterState object for the component.
pad_to_batch: Optional; the number of batch elements to pad to.
pad_to_steps: Optional; the number of steps to pad to.
Returns:
state handle: updated state handle to be used after this call
......@@ -126,19 +137,50 @@ def fetch_fast_fixed_embeddings(comp, state):
return state.handle, []
tf.logging.info('[%s] Adding %d fast fixed features', comp.name, num_channels)
state.handle, bulk_embeddings, _ = dragnn_ops.bulk_fixed_embeddings(
state.handle, [
comp.get_variable(network_units.fixed_embeddings_name(c))
for c in range(num_channels)
],
component=comp.name)
bulk_embeddings = network_units.NamedTensor(bulk_embeddings,
'bulk-%s-fixed-features' %
comp.name)
features = [
comp.get_variable(network_units.fixed_embeddings_name(c))
for c in range(num_channels)
]
if pad_to_batch is not None and pad_to_steps is not None:
# If we have fixed padding numbers, we can use 'bulk_embed_fixed_features',
# which is the fastest embedding extractor.
state.handle, bulk_embeddings, _ = dragnn_ops.bulk_embed_fixed_features(
state.handle,
features,
component=comp.name,
pad_to_batch=pad_to_batch,
pad_to_steps=pad_to_steps)
else:
state.handle, bulk_embeddings, _ = dragnn_ops.bulk_fixed_embeddings(
state.handle, features, component=comp.name)
bulk_embeddings = network_units.NamedTensor(
bulk_embeddings, 'bulk-%s-fixed-features' % comp.name)
return state.handle, [bulk_embeddings]
def fetch_dense_ragged_embeddings(comp, state):
"""Gets embeddings in RaggedTensor format."""
_validate_embedded_fixed_features(comp)
num_channels = len(comp.spec.fixed_feature)
if not num_channels:
return state.handle, []
tf.logging.info('[%s] Adding %d fast fixed features', comp.name, num_channels)
features = [
comp.get_variable(network_units.fixed_embeddings_name(c))
for c in range(num_channels)
]
state.handle, data, offsets = dragnn_ops.bulk_embed_dense_fixed_features(
state.handle, features, component=comp.name)
data = network_units.NamedTensor(data, 'dense-%s-data' % comp.name)
offsets = network_units.NamedTensor(offsets, 'dense-%s-offsets' % comp.name)
return state.handle, [data, offsets]
def extract_fixed_feature_ids(comp, state, stride):
"""Extracts fixed feature IDs.
......@@ -194,8 +236,10 @@ def update_network_states(comp, tensors, network_states, stride):
with tf.name_scope(comp.name + '/stored_act'):
for index, network_tensor in enumerate(tensors):
network_state.activations[comp.network.layers[index].name] = (
network_units.StoredActivations(tensor=network_tensor, stride=stride,
dim=comp.network.layers[index].dim))
network_units.StoredActivations(
tensor=network_tensor,
stride=stride,
dim=comp.network.layers[index].dim))
def build_cross_entropy_loss(logits, gold):
......@@ -205,7 +249,7 @@ def build_cross_entropy_loss(logits, gold):
Args:
logits: float Tensor of scores.
gold: int Tensor of one-hot labels.
gold: int Tensor of gold label ids.
Returns:
cost, correct, total: the total cost, the total number of correctly
......@@ -251,9 +295,10 @@ class BulkFeatureExtractorComponentBuilder(component.ComponentBuilderBase):
"""
logging.info('Building component: %s', self.spec.name)
stride = state.current_batch_size * self.training_beam_size
self.network.pre_create(stride)
with tf.variable_scope(self.name, reuse=True):
state.handle, fixed_embeddings = fetch_differentiable_fixed_embeddings(
self, state, stride)
self, state, stride, True)
linked_embeddings = [
fetch_linked_embedding(self, network_states, spec)
......@@ -307,14 +352,29 @@ class BulkFeatureExtractorComponentBuilder(component.ComponentBuilderBase):
stride = state.current_batch_size * self.training_beam_size
else:
stride = state.current_batch_size * self.inference_beam_size
self.network.pre_create(stride)
with tf.variable_scope(self.name, reuse=True):
if during_training:
state.handle, fixed_embeddings = fetch_differentiable_fixed_embeddings(
self, state, stride)
self, state, stride, during_training)
else:
state.handle, fixed_embeddings = fetch_fast_fixed_embeddings(self,
state)
if 'use_densors' in self.spec.network_unit.parameters:
state.handle, fixed_embeddings = fetch_dense_ragged_embeddings(
self, state)
else:
if ('padded_batch_size' in self.spec.network_unit.parameters and
'padded_sentence_length' in self.spec.network_unit.parameters):
state.handle, fixed_embeddings = fetch_fast_fixed_embeddings(
self,
state,
pad_to_batch=-1,
pad_to_steps=int(self.spec.network_unit.parameters[
'padded_sentence_length']))
else:
state.handle, fixed_embeddings = fetch_fast_fixed_embeddings(
self, state)
linked_embeddings = [
fetch_linked_embedding(self, network_states, spec)
......@@ -331,6 +391,7 @@ class BulkFeatureExtractorComponentBuilder(component.ComponentBuilderBase):
stride=stride)
update_network_states(self, tensors, network_states, stride)
self._add_runtime_hooks()
return state.handle
......@@ -367,7 +428,9 @@ class BulkFeatureIdExtractorComponentBuilder(component.ComponentBuilderBase):
def build_greedy_inference(self, state, network_states,
during_training=False):
"""See base class."""
return self._extract_feature_ids(state, network_states, during_training)
handle = self._extract_feature_ids(state, network_states, during_training)
self._add_runtime_hooks()
return handle
def _extract_feature_ids(self, state, network_states, during_training):
"""Extracts feature IDs and advances a batch using the oracle path.
......@@ -387,6 +450,7 @@ class BulkFeatureIdExtractorComponentBuilder(component.ComponentBuilderBase):
stride = state.current_batch_size * self.training_beam_size
else:
stride = state.current_batch_size * self.inference_beam_size
self.network.pre_create(stride)
with tf.variable_scope(self.name, reuse=True):
state.handle, ids = extract_fixed_feature_ids(self, state, stride)
......@@ -438,17 +502,21 @@ class BulkAnnotatorComponentBuilder(component.ComponentBuilderBase):
]
stride = state.current_batch_size * self.training_beam_size
self.network.pre_create(stride)
with tf.variable_scope(self.name, reuse=True):
network_tensors = self.network.create([], linked_embeddings, None, None,
True, stride)
update_network_states(self, network_tensors, network_states, stride)
logits = self.network.get_logits(network_tensors)
state.handle, gold = dragnn_ops.bulk_advance_from_oracle(
state.handle, component=self.name)
cost, correct, total = build_cross_entropy_loss(logits, gold)
cost, correct, total = self.network.compute_bulk_loss(
stride, network_tensors, gold)
if cost is None:
# The network does not have a custom bulk loss; default to softmax.
logits = self.network.get_logits(network_tensors)
cost, correct, total = build_cross_entropy_loss(logits, gold)
cost = self.add_regularizer(cost)
return state.handle, cost, correct, total
......@@ -483,13 +551,24 @@ class BulkAnnotatorComponentBuilder(component.ComponentBuilderBase):
stride = state.current_batch_size * self.training_beam_size
else:
stride = state.current_batch_size * self.inference_beam_size
self.network.pre_create(stride)
with tf.variable_scope(self.name, reuse=True):
network_tensors = self.network.create(
[], linked_embeddings, None, None, during_training, stride)
network_tensors = self.network.create([], linked_embeddings, None, None,
during_training, stride)
update_network_states(self, network_tensors, network_states, stride)
logits = self.network.get_logits(network_tensors)
return dragnn_ops.bulk_advance_from_prediction(
logits = self.network.get_bulk_predictions(stride, network_tensors)
if logits is None:
# The network does not produce custom bulk predictions; default to logits.
logits = self.network.get_logits(network_tensors)
logits = tf.cond(self.locally_normalize,
lambda: tf.nn.log_softmax(logits), lambda: logits)
if self._output_as_probabilities:
logits = tf.nn.softmax(logits)
handle = dragnn_ops.bulk_advance_from_prediction(
state.handle, logits, component=self.name)
self._add_runtime_hooks()
return handle
......@@ -41,8 +41,6 @@ from dragnn.python import dragnn_ops
from dragnn.python import network_units
from syntaxnet import sentence_pb2
FLAGS = tf.app.flags.FLAGS
class MockNetworkUnit(object):
......@@ -63,6 +61,7 @@ class MockMaster(object):
self.spec = spec_pb2.MasterSpec()
self.hyperparams = spec_pb2.GridPoint()
self.lookup_component = {'mock': MockComponent()}
self.build_runtime_graph = False
def _create_fake_corpus():
......@@ -84,9 +83,12 @@ def _create_fake_corpus():
class BulkComponentTest(test_util.TensorFlowTestCase):
def setUp(self):
# Clear the graph and all existing variables. Otherwise, variables created
# in different tests may collide with each other.
tf.reset_default_graph()
self.master = MockMaster()
self.master_state = component.MasterState(
handle='handle', current_batch_size=2)
handle=tf.constant(['foo', 'bar']), current_batch_size=2)
self.network_states = {
'mock': component.NetworkState(),
'test': component.NetworkState(),
......@@ -107,22 +109,21 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
""", component_spec)
# For feature extraction:
with tf.Graph().as_default():
comp = bulk_component.BulkFeatureExtractorComponentBuilder(
self.master, component_spec)
comp = bulk_component.BulkFeatureExtractorComponentBuilder(
self.master, component_spec)
# Expect feature extraction to generate an error due to the "history"
# translator.
with self.assertRaises(NotImplementedError):
comp.build_greedy_training(self.master_state, self.network_states)
# Expect feature extraction to generate an error due to the "history"
# translator.
with self.assertRaises(NotImplementedError):
comp.build_greedy_training(self.master_state, self.network_states)
# As well as annotation:
with tf.Graph().as_default():
comp = bulk_component.BulkAnnotatorComponentBuilder(
self.master, component_spec)
self.setUp()
comp = bulk_component.BulkAnnotatorComponentBuilder(self.master,
component_spec)
with self.assertRaises(NotImplementedError):
comp.build_greedy_training(self.master_state, self.network_states)
with self.assertRaises(NotImplementedError):
comp.build_greedy_training(self.master_state, self.network_states)
def testFailsOnRecurrentLinkedFeature(self):
component_spec = spec_pb2.ComponentSpec()
......@@ -143,22 +144,21 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
""", component_spec)
# For feature extraction:
with tf.Graph().as_default():
comp = bulk_component.BulkFeatureExtractorComponentBuilder(
self.master, component_spec)
comp = bulk_component.BulkFeatureExtractorComponentBuilder(
self.master, component_spec)
# Expect feature extraction to generate an error due to the "history"
# translator.
with self.assertRaises(RuntimeError):
comp.build_greedy_training(self.master_state, self.network_states)
# Expect feature extraction to generate an error due to the "history"
# translator.
with self.assertRaises(RuntimeError):
comp.build_greedy_training(self.master_state, self.network_states)
# As well as annotation:
with tf.Graph().as_default():
comp = bulk_component.BulkAnnotatorComponentBuilder(
self.master, component_spec)
self.setUp()
comp = bulk_component.BulkAnnotatorComponentBuilder(self.master,
component_spec)
with self.assertRaises(RuntimeError):
comp.build_greedy_training(self.master_state, self.network_states)
with self.assertRaises(RuntimeError):
comp.build_greedy_training(self.master_state, self.network_states)
def testConstantFixedFeatureFailsIfNotPretrained(self):
component_spec = spec_pb2.ComponentSpec()
......@@ -175,21 +175,20 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
registered_name: "bulk_component.BulkFeatureExtractorComponentBuilder"
}
""", component_spec)
with tf.Graph().as_default():
comp = bulk_component.BulkFeatureExtractorComponentBuilder(
self.master, component_spec)
comp = bulk_component.BulkFeatureExtractorComponentBuilder(
self.master, component_spec)
with self.assertRaisesRegexp(ValueError,
'Constant embeddings must be pretrained'):
comp.build_greedy_training(self.master_state, self.network_states)
with self.assertRaisesRegexp(ValueError,
'Constant embeddings must be pretrained'):
comp.build_greedy_inference(
self.master_state, self.network_states, during_training=True)
with self.assertRaisesRegexp(ValueError,
'Constant embeddings must be pretrained'):
comp.build_greedy_inference(
self.master_state, self.network_states, during_training=False)
with self.assertRaisesRegexp(ValueError,
'Constant embeddings must be pretrained'):
comp.build_greedy_training(self.master_state, self.network_states)
with self.assertRaisesRegexp(ValueError,
'Constant embeddings must be pretrained'):
comp.build_greedy_inference(
self.master_state, self.network_states, during_training=True)
with self.assertRaisesRegexp(ValueError,
'Constant embeddings must be pretrained'):
comp.build_greedy_inference(
self.master_state, self.network_states, during_training=False)
def testNormalFixedFeaturesAreDifferentiable(self):
component_spec = spec_pb2.ComponentSpec()
......@@ -207,25 +206,24 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
registered_name: "bulk_component.BulkFeatureExtractorComponentBuilder"
}
""", component_spec)
with tf.Graph().as_default():
comp = bulk_component.BulkFeatureExtractorComponentBuilder(
self.master, component_spec)
comp = bulk_component.BulkFeatureExtractorComponentBuilder(
self.master, component_spec)
# Get embedding matrix variables.
with tf.variable_scope(comp.name, reuse=True):
fixed_embedding_matrix = tf.get_variable(
network_units.fixed_embeddings_name(0))
# Get embedding matrix variables.
with tf.variable_scope(comp.name, reuse=True):
fixed_embedding_matrix = tf.get_variable(
network_units.fixed_embeddings_name(0))
# Get output layer.
comp.build_greedy_training(self.master_state, self.network_states)
activations = self.network_states[comp.name].activations
outputs = activations[comp.network.layers[0].name].bulk_tensor
# Get output layer.
comp.build_greedy_training(self.master_state, self.network_states)
activations = self.network_states[comp.name].activations
outputs = activations[comp.network.layers[0].name].bulk_tensor
# Compute the gradient of the output layer w.r.t. the embedding matrix.
# This should be well-defined in the normal case.
gradients = tf.gradients(outputs, fixed_embedding_matrix)
self.assertEqual(len(gradients), 1)
self.assertFalse(gradients[0] is None)
# Compute the gradient of the output layer w.r.t. the embedding matrix.
# This should be well-defined in the normal case.
gradients = tf.gradients(outputs, fixed_embedding_matrix)
self.assertEqual(len(gradients), 1)
self.assertFalse(gradients[0] is None)
def testConstantFixedFeaturesAreNotDifferentiableButOthersAre(self):
component_spec = spec_pb2.ComponentSpec()
......@@ -249,31 +247,30 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
registered_name: "bulk_component.BulkFeatureExtractorComponentBuilder"
}
""", component_spec)
with tf.Graph().as_default():
comp = bulk_component.BulkFeatureExtractorComponentBuilder(
self.master, component_spec)
# Get embedding matrix variables.
with tf.variable_scope(comp.name, reuse=True):
constant_embedding_matrix = tf.get_variable(
network_units.fixed_embeddings_name(0))
trainable_embedding_matrix = tf.get_variable(
network_units.fixed_embeddings_name(1))
# Get output layer.
comp.build_greedy_training(self.master_state, self.network_states)
activations = self.network_states[comp.name].activations
outputs = activations[comp.network.layers[0].name].bulk_tensor
# The constant embeddings are non-differentiable.
constant_gradients = tf.gradients(outputs, constant_embedding_matrix)
self.assertEqual(len(constant_gradients), 1)
self.assertTrue(constant_gradients[0] is None)
# The trainable embeddings are differentiable.
trainable_gradients = tf.gradients(outputs, trainable_embedding_matrix)
self.assertEqual(len(trainable_gradients), 1)
self.assertFalse(trainable_gradients[0] is None)
comp = bulk_component.BulkFeatureExtractorComponentBuilder(
self.master, component_spec)
# Get embedding matrix variables.
with tf.variable_scope(comp.name, reuse=True):
constant_embedding_matrix = tf.get_variable(
network_units.fixed_embeddings_name(0))
trainable_embedding_matrix = tf.get_variable(
network_units.fixed_embeddings_name(1))
# Get output layer.
comp.build_greedy_training(self.master_state, self.network_states)
activations = self.network_states[comp.name].activations
outputs = activations[comp.network.layers[0].name].bulk_tensor
# The constant embeddings are non-differentiable.
constant_gradients = tf.gradients(outputs, constant_embedding_matrix)
self.assertEqual(len(constant_gradients), 1)
self.assertTrue(constant_gradients[0] is None)
# The trainable embeddings are differentiable.
trainable_gradients = tf.gradients(outputs, trainable_embedding_matrix)
self.assertEqual(len(trainable_gradients), 1)
self.assertFalse(trainable_gradients[0] is None)
def testFailsOnFixedFeature(self):
component_spec = spec_pb2.ComponentSpec()
......@@ -306,15 +303,14 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
name: "fixed" embedding_dim: -1 size: 1
}
""", component_spec)
with tf.Graph().as_default():
comp = bulk_component.BulkFeatureIdExtractorComponentBuilder(
self.master, component_spec)
comp = bulk_component.BulkFeatureIdExtractorComponentBuilder(
self.master, component_spec)
# Should not raise errors.
self.network_states[component_spec.name] = component.NetworkState()
comp.build_greedy_training(self.master_state, self.network_states)
self.network_states[component_spec.name] = component.NetworkState()
comp.build_greedy_inference(self.master_state, self.network_states)
# Should not raise errors.
self.network_states[component_spec.name] = component.NetworkState()
comp.build_greedy_training(self.master_state, self.network_states)
self.network_states[component_spec.name] = component.NetworkState()
comp.build_greedy_inference(self.master_state, self.network_states)
def testBulkFeatureIdExtractorFailsOnLinkedFeature(self):
component_spec = spec_pb2.ComponentSpec()
......@@ -332,10 +328,9 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
source_component: "mock"
}
""", component_spec)
with tf.Graph().as_default():
with self.assertRaises(ValueError):
unused_comp = bulk_component.BulkFeatureIdExtractorComponentBuilder(
self.master, component_spec)
with self.assertRaises(ValueError):
unused_comp = bulk_component.BulkFeatureIdExtractorComponentBuilder(
self.master, component_spec)
def testBulkFeatureIdExtractorOkWithMultipleFixedFeatures(self):
component_spec = spec_pb2.ComponentSpec()
......@@ -354,15 +349,14 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
name: "fixed3" embedding_dim: -1 size: 1
}
""", component_spec)
with tf.Graph().as_default():
comp = bulk_component.BulkFeatureIdExtractorComponentBuilder(
self.master, component_spec)
comp = bulk_component.BulkFeatureIdExtractorComponentBuilder(
self.master, component_spec)
# Should not raise errors.
self.network_states[component_spec.name] = component.NetworkState()
comp.build_greedy_training(self.master_state, self.network_states)
self.network_states[component_spec.name] = component.NetworkState()
comp.build_greedy_inference(self.master_state, self.network_states)
# Should not raise errors.
self.network_states[component_spec.name] = component.NetworkState()
comp.build_greedy_training(self.master_state, self.network_states)
self.network_states[component_spec.name] = component.NetworkState()
comp.build_greedy_inference(self.master_state, self.network_states)
def testBulkFeatureIdExtractorFailsOnEmbeddedFixedFeature(self):
component_spec = spec_pb2.ComponentSpec()
......@@ -375,10 +369,9 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
name: "fixed" embedding_dim: 2 size: 1
}
""", component_spec)
with tf.Graph().as_default():
with self.assertRaises(ValueError):
unused_comp = bulk_component.BulkFeatureIdExtractorComponentBuilder(
self.master, component_spec)
with self.assertRaises(ValueError):
unused_comp = bulk_component.BulkFeatureIdExtractorComponentBuilder(
self.master, component_spec)
def testBulkFeatureIdExtractorExtractFocusWithOffset(self):
path = os.path.join(tf.test.get_temp_dir(), 'label-map')
......@@ -420,67 +413,131 @@ class BulkComponentTest(test_util.TensorFlowTestCase):
}
""" % path, master_spec)
with tf.Graph().as_default():
corpus = _create_fake_corpus()
corpus = tf.constant(corpus, shape=[len(corpus)])
handle = dragnn_ops.get_session(
container='test',
master_spec=master_spec.SerializeToString(),
grid_point='')
handle = dragnn_ops.attach_data_reader(handle, corpus)
handle = dragnn_ops.init_component_data(
handle, beam_size=1, component='test')
batch_size = dragnn_ops.batch_size(handle, component='test')
master_state = component.MasterState(handle, batch_size)
extractor = bulk_component.BulkFeatureIdExtractorComponentBuilder(
self.master, master_spec.component[0])
network_state = component.NetworkState()
self.network_states['test'] = network_state
handle = extractor.build_greedy_inference(master_state,
self.network_states)
focus1 = network_state.activations['focus1'].bulk_tensor
focus2 = network_state.activations['focus2'].bulk_tensor
focus3 = network_state.activations['focus3'].bulk_tensor
with self.test_session() as sess:
focus1, focus2, focus3 = sess.run([focus1, focus2, focus3])
tf.logging.info('focus1=\n%s', focus1)
tf.logging.info('focus2=\n%s', focus2)
tf.logging.info('focus3=\n%s', focus3)
self.assertAllEqual(
focus1,
[[0], [-1], [-1], [-1],
[0], [1], [-1], [-1],
[0], [1], [2], [-1],
[0], [1], [2], [3]])
self.assertAllEqual(
focus2,
[[-1], [-1], [-1], [-1],
[1], [-1], [-1], [-1],
[1], [2], [-1], [-1],
[1], [2], [3], [-1]])
self.assertAllEqual(
focus3,
[[-1], [-1], [-1], [-1],
[-1], [-1], [-1], [-1],
[2], [-1], [-1], [-1],
[2], [3], [-1], [-1]])
corpus = _create_fake_corpus()
corpus = tf.constant(corpus, shape=[len(corpus)])
handle = dragnn_ops.get_session(
container='test',
master_spec=master_spec.SerializeToString(),
grid_point='')
handle = dragnn_ops.attach_data_reader(handle, corpus)
handle = dragnn_ops.init_component_data(
handle, beam_size=1, component='test')
batch_size = dragnn_ops.batch_size(handle, component='test')
master_state = component.MasterState(handle, batch_size)
extractor = bulk_component.BulkFeatureIdExtractorComponentBuilder(
self.master, master_spec.component[0])
network_state = component.NetworkState()
self.network_states['test'] = network_state
handle = extractor.build_greedy_inference(master_state, self.network_states)
focus1 = network_state.activations['focus1'].bulk_tensor
focus2 = network_state.activations['focus2'].bulk_tensor
focus3 = network_state.activations['focus3'].bulk_tensor
with self.test_session() as sess:
focus1, focus2, focus3 = sess.run([focus1, focus2, focus3])
tf.logging.info('focus1=\n%s', focus1)
tf.logging.info('focus2=\n%s', focus2)
tf.logging.info('focus3=\n%s', focus3)
self.assertAllEqual(focus1,
[[0], [-1], [-1], [-1],
[0], [1], [-1], [-1],
[0], [1], [2], [-1],
[0], [1], [2], [3]]) # pyformat: disable
self.assertAllEqual(focus2,
[[-1], [-1], [-1], [-1],
[1], [-1], [-1], [-1],
[1], [2], [-1], [-1],
[1], [2], [3], [-1]]) # pyformat: disable
self.assertAllEqual(focus3,
[[-1], [-1], [-1], [-1],
[-1], [-1], [-1], [-1],
[2], [-1], [-1], [-1],
[2], [3], [-1], [-1]]) # pyformat: disable
def testBuildLossFailsOnNoExamples(self):
with tf.Graph().as_default():
logits = tf.constant([[0.5], [-0.5], [0.5], [-0.5]])
gold = tf.constant([-1, -1, -1, -1])
result = bulk_component.build_cross_entropy_loss(logits, gold)
# Expect loss computation to generate a runtime error due to the gold
# tensor containing no valid examples.
with self.test_session() as sess:
with self.assertRaises(tf.errors.InvalidArgumentError):
sess.run(result)
logits = tf.constant([[0.5], [-0.5], [0.5], [-0.5]])
gold = tf.constant([-1, -1, -1, -1])
result = bulk_component.build_cross_entropy_loss(logits, gold)
# Expect loss computation to generate a runtime error due to the gold
# tensor containing no valid examples.
with self.test_session() as sess:
with self.assertRaises(tf.errors.InvalidArgumentError):
sess.run(result)
def testPreCreateCalledBeforeCreate(self):
component_spec = spec_pb2.ComponentSpec()
text_format.Parse("""
name: "test"
network_unit {
registered_name: "IdentityNetwork"
}
""", component_spec)
class AssertPreCreateBeforeCreateNetwork(
network_units.NetworkUnitInterface):
"""Mock that asserts that .create() is called before .pre_create()."""
def __init__(self, comp, test_fixture):
super(AssertPreCreateBeforeCreateNetwork, self).__init__(comp)
self._test_fixture = test_fixture
self._pre_create_called = False
def get_logits(self, network_tensors):
return tf.zeros([2, 1], dtype=tf.float32)
def pre_create(self, *unused_args):
self._pre_create_called = True
def create(self, *unused_args, **unused_kwargs):
self._test_fixture.assertTrue(self._pre_create_called)
return []
builder = bulk_component.BulkFeatureExtractorComponentBuilder(
self.master, component_spec)
builder.network = AssertPreCreateBeforeCreateNetwork(builder, self)
builder.build_greedy_training(
component.MasterState(['foo', 'bar'], 2), self.network_states)
self.setUp()
builder = bulk_component.BulkFeatureExtractorComponentBuilder(
self.master, component_spec)
builder.network = AssertPreCreateBeforeCreateNetwork(builder, self)
builder.build_greedy_inference(
component.MasterState(['foo', 'bar'], 2), self.network_states)
self.setUp()
builder = bulk_component.BulkFeatureIdExtractorComponentBuilder(
self.master, component_spec)
builder.network = AssertPreCreateBeforeCreateNetwork(builder, self)
builder.build_greedy_training(
component.MasterState(['foo', 'bar'], 2), self.network_states)
self.setUp()
builder = bulk_component.BulkFeatureIdExtractorComponentBuilder(
self.master, component_spec)
builder.network = AssertPreCreateBeforeCreateNetwork(builder, self)
builder.build_greedy_inference(
component.MasterState(['foo', 'bar'], 2), self.network_states)
self.setUp()
builder = bulk_component.BulkAnnotatorComponentBuilder(
self.master, component_spec)
builder.network = AssertPreCreateBeforeCreateNetwork(builder, self)
builder.build_greedy_training(
component.MasterState(['foo', 'bar'], 2), self.network_states)
self.setUp()
builder = bulk_component.BulkAnnotatorComponentBuilder(
self.master, component_spec)
builder.network = AssertPreCreateBeforeCreateNetwork(builder, self)
builder.build_greedy_inference(
component.MasterState(['foo', 'bar'], 2), self.network_states)
if __name__ == '__main__':
googletest.main()
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Builds a DRAGNN graph for local training."""
from abc import ABCMeta
......@@ -21,12 +20,79 @@ from abc import abstractmethod
import tensorflow as tf
from tensorflow.python.platform import tf_logging as logging
from dragnn.protos import export_pb2
from dragnn.python import dragnn_ops
from dragnn.python import network_units
from dragnn.python import runtime_support
from syntaxnet.util import check
from syntaxnet.util import registry
def build_softmax_cross_entropy_loss(logits, gold):
"""Builds softmax cross entropy loss."""
# A gold label > -1 determines that the sentence is still
# in a valid state. Otherwise, the sentence has ended.
#
# We add only the valid sentences to the loss, in the following way:
# 1. We compute 'valid_ix', the indices in gold that contain
# valid oracle actions.
# 2. We compute the cost function by comparing logits and gold
# only for the valid indices.
valid = tf.greater(gold, -1)
valid_ix = tf.reshape(tf.where(valid), [-1])
valid_gold = tf.gather(gold, valid_ix)
valid_logits = tf.gather(logits, valid_ix)
cost = tf.reduce_sum(
tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=tf.cast(valid_gold, tf.int64),
logits=valid_logits,
name='sparse_softmax_cross_entropy_with_logits'))
correct = tf.reduce_sum(
tf.to_int32(tf.nn.in_top_k(valid_logits, valid_gold, 1)))
total = tf.size(valid_gold)
return cost, correct, total, valid_logits, valid_gold
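A small illustrative use of build_softmax_cross_entropy_loss (values are arbitrary): entries with gold == -1 are excluded, so only the two valid rows contribute to the counts.
import tensorflow as tf

logits = tf.constant([[2.0, 0.0], [0.0, 0.0], [0.0, 3.0]])
gold = tf.constant([0, -1, 1])  # The middle entry is padding and is ignored.
cost, correct, total, _, _ = build_softmax_cross_entropy_loss(logits, gold)

with tf.Session() as sess:
  # Both valid rows are predicted correctly, so this prints [2, 2].
  print(sess.run([correct, total]))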
def build_sigmoid_cross_entropy_loss(logits, gold, indices, probs):
"""Builds sigmoid cross entropy loss."""
# Filter out entries where gold <= -1, which are batch padding entries.
valid = tf.greater(gold, -1)
valid_ix = tf.reshape(tf.where(valid), [-1])
valid_gold = tf.gather(gold, valid_ix)
valid_indices = tf.gather(indices, valid_ix)
valid_probs = tf.gather(probs, valid_ix)
# NB: tf.gather_nd() raises an error on CPU for out-of-bounds indices. That's
# why we need to filter out the gold=-1 batch padding above.
valid_pairs = tf.stack([valid_indices, valid_gold], axis=1)
valid_logits = tf.gather_nd(logits, valid_pairs)
cost = tf.reduce_sum(
tf.nn.sigmoid_cross_entropy_with_logits(
labels=valid_probs,
logits=valid_logits,
name='sigmoid_cross_entropy_with_logits'))
gold_bool = valid_probs > 0.5
predicted_bool = valid_logits > 0.0
total = tf.size(gold_bool)
with tf.control_dependencies([
tf.assert_equal(
total, tf.size(predicted_bool), name='num_predicted_gold_mismatch')
]):
agreement_bool = tf.logical_not(tf.logical_xor(gold_bool, predicted_bool))
correct = tf.reduce_sum(tf.to_int32(agreement_bool))
cost.set_shape([])
correct.set_shape([])
total.set_shape([])
return cost, correct, total, gold
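And a similarly arbitrary illustration for build_sigmoid_cross_entropy_loss: the gold == -1 entry is batch padding and is filtered out before the gather_nd, leaving a single (index, label) pair to score.
import tensorflow as tf

logits = tf.constant([[2.0, -1.0], [0.5, 0.5]])
gold = tf.constant([1, -1])
indices = tf.constant([0, 0])
probs = tf.constant([0.0, 0.75])
cost, correct, total, _ = build_sigmoid_cross_entropy_loss(
    logits, gold, indices, probs)

with tf.Session() as sess:
  # One valid entry; its negative logit agrees with a gold probability of 0.0,
  # so this prints [1, 1].
  print(sess.run([correct, total]))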
class NetworkState(object):
"""Simple utility to manage the state of a DRAGNN network.
......@@ -69,6 +135,13 @@ class ComponentBuilderBase(object):
As part of the specification, ComponentBuilder will wrap an underlying
NetworkUnit which generates the actual network layout.
Attributes:
master: dragnn.MasterBuilder that owns this component.
num_actions: Number of actions in the transition system.
name: Name of this component.
spec: dragnn.ComponentSpec that configures this component.
moving_average: True if moving-average parameters should be used.
"""
__metaclass__ = ABCMeta # required for @abstractmethod
......@@ -96,16 +169,23 @@ class ComponentBuilderBase(object):
# Extract component attributes before make_network(), so the network unit
# can access them.
self._attrs = {}
global_attr_defaults = {
'locally_normalize': False,
'output_as_probabilities': False
}
if attr_defaults:
self._attrs = network_units.get_attrs_with_defaults(
self.spec.component_builder.parameters, attr_defaults)
global_attr_defaults.update(attr_defaults)
self._attrs = network_units.get_attrs_with_defaults(
self.spec.component_builder.parameters, global_attr_defaults)
do_local_norm = self._attrs['locally_normalize']
self._output_as_probabilities = self._attrs['output_as_probabilities']
with tf.variable_scope(self.name):
self.training_beam_size = tf.constant(
self.spec.training_beam_size, name='TrainingBeamSize')
self.inference_beam_size = tf.constant(
self.spec.inference_beam_size, name='InferenceBeamSize')
self.locally_normalize = tf.constant(False, name='LocallyNormalize')
self.locally_normalize = tf.constant(
do_local_norm, name='LocallyNormalize')
self._step = tf.get_variable(
'step', [], initializer=tf.zeros_initializer(), dtype=tf.int32)
self._total = tf.get_variable(
......@@ -120,6 +200,9 @@ class ComponentBuilderBase(object):
decay=self.master.hyperparams.average_weight, num_updates=self._step)
self.avg_ops = [self.moving_average.apply(self.network.params)]
# Used to export the cell; see add_cell_input() and add_cell_output().
self._cell_subgraph_spec = export_pb2.CellSubgraphSpec()
def make_network(self, network_unit):
"""Makes a NetworkUnitInterface object based on the network_unit spec.
......@@ -276,7 +359,7 @@ class ComponentBuilderBase(object):
Returns:
tf.Variable object corresponding to original or averaged version.
"""
if var_params:
if var_params is not None:
var_name = var_params.name
else:
check.NotNone(var_name, 'specify at least one of var_name or var_params')
......@@ -341,6 +424,79 @@ class ComponentBuilderBase(object):
"""Returns the value of the component attribute with the |name|."""
return self._attrs[name]
def has_attr(self, name):
"""Checks whether a component attribute with the given |name| exists.
Arguments:
name: attribute name
Returns:
True if an attribute with the given |name| exists, False otherwise.
"""
return name in self._attrs
def _add_runtime_hooks(self):
"""Adds "hook" nodes to the graph for use by the runtime, if enabled.
Does nothing if master.build_runtime_graph is False. Subclasses should call
this at the end of build_*_inference(). For details on the runtime hooks,
see runtime_support.py.
"""
if self.master.build_runtime_graph:
with tf.variable_scope(self.name, reuse=True):
runtime_support.add_hooks(self, self._cell_subgraph_spec)
self._cell_subgraph_spec = None # prevent further exports
def add_cell_input(self, dtype, shape, name, input_type='TYPE_FEATURE'):
"""Adds an input to the current CellSubgraphSpec.
Creates a tf.placeholder() with the given |dtype| and |shape|, adds it as a
cell input with the |name| and |input_type|, and returns the placeholder to
be used in place of the actual input tensor.
Args:
dtype: DType of the cell input.
shape: Static shape of the cell input.
name: Logical name of the cell input.
input_type: Name of the appropriate CellSubgraphSpec.Input.Type enum.
Returns:
A tensor to use in place of the actual input tensor.
Raises:
TypeError: If the |shape| is the wrong type.
RuntimeError: If the cell has already been exported.
"""
if not (isinstance(shape, list) and
all(isinstance(dim, int) for dim in shape)):
raise TypeError('shape must be a list of int')
if not self._cell_subgraph_spec:
raise RuntimeError('already exported a CellSubgraphSpec')
with tf.name_scope(None):
tensor = tf.placeholder(
dtype, shape, name='{}/INPUT/{}'.format(self.name, name))
self._cell_subgraph_spec.input.add(
name=name,
tensor=tensor.name,
type=export_pb2.CellSubgraphSpec.Input.Type.Value(input_type))
return tensor
def add_cell_output(self, tensor, name):
"""Adds an output to the current CellSubgraphSpec.
Args:
tensor: Tensor to add as a cell output.
name: Logical name of the cell output.
Raises:
RuntimeError: If the cell has already been exported.
"""
if not self._cell_subgraph_spec:
raise RuntimeError('already exported a CellSubgraphSpec')
self._cell_subgraph_spec.output.add(name=name, tensor=tensor.name)
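A hypothetical sketch of how a network unit might use add_cell_input() and add_cell_output() to describe its cell to the runtime; 'comp' stands for a ComponentBuilderBase subclass as defined above, and the layer names and shapes are made up for illustration.
import tensorflow as tf

def register_cell(comp, weights):
  """Declares one cell input and one cell output on |comp|.

  Assumes |weights| has shape [32, hidden_dim].
  """
  # The runtime feeds this placeholder directly instead of the real input.
  features = comp.add_cell_input(tf.float32, [1, 32], 'features')
  hidden = tf.nn.relu(tf.matmul(features, weights), name='hidden')
  # The runtime reads this tensor back as the cell's output.
  comp.add_cell_output(hidden, 'hidden')
  return hidden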
def update_tensor_arrays(network_tensors, arrays):
"""Updates a list of tensor arrays from the network's output tensors.
......@@ -370,6 +526,18 @@ class DynamicComponentBuilder(ComponentBuilderBase):
so fixed and linked features can be recurrent.
"""
def __init__(self, master, component_spec):
"""Initializes the DynamicComponentBuilder from specifications.
Args:
master: dragnn.MasterBuilder object.
component_spec: dragnn.ComponentSpec proto to be built.
"""
super(DynamicComponentBuilder, self).__init__(
master,
component_spec,
attr_defaults={'loss_function': 'softmax_cross_entropy'})
def build_greedy_training(self, state, network_states):
"""Builds a training loop for this component.
......@@ -392,9 +560,10 @@ class DynamicComponentBuilder(ComponentBuilderBase):
# Add 0 to training_beam_size to disable eager static evaluation.
# This is possible because tensorflow's constant_value does not
# propagate arithmetic operations.
with tf.control_dependencies([
tf.assert_equal(self.training_beam_size + 0, 1)]):
with tf.control_dependencies(
[tf.assert_equal(self.training_beam_size + 0, 1)]):
stride = state.current_batch_size * self.training_beam_size
self.network.pre_create(stride)
cost = tf.constant(0.)
correct = tf.constant(0)
......@@ -416,40 +585,35 @@ class DynamicComponentBuilder(ComponentBuilderBase):
# Every layer is written to a TensorArray, so that it can be backprop'd.
next_arrays = update_tensor_arrays(network_tensors, arrays)
loss_function = self.attr('loss_function')
with tf.control_dependencies([x.flow for x in next_arrays]):
with tf.name_scope('compute_loss'):
# A gold label > -1 determines that the sentence is still
# in a valid state. Otherwise, the sentence has ended.
#
# We add only the valid sentences to the loss, in the following way:
# 1. We compute 'valid_ix', the indices in gold that contain
# valid oracle actions.
# 2. We compute the cost function by comparing logits and gold
# only for the valid indices.
gold = dragnn_ops.emit_oracle_labels(handle, component=self.name)
gold.set_shape([None])
valid = tf.greater(gold, -1)
valid_ix = tf.reshape(tf.where(valid), [-1])
gold = tf.gather(gold, valid_ix)
logits = self.network.get_logits(network_tensors)
logits = tf.gather(logits, valid_ix)
cost += tf.reduce_sum(
tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=tf.cast(gold, tf.int64), logits=logits))
if (self.eligible_for_self_norm and
self.master.hyperparams.self_norm_alpha > 0):
log_z = tf.reduce_logsumexp(logits, [1])
cost += (self.master.hyperparams.self_norm_alpha *
tf.nn.l2_loss(log_z))
correct += tf.reduce_sum(
tf.to_int32(tf.nn.in_top_k(logits, gold, 1)))
total += tf.size(gold)
with tf.control_dependencies([cost, correct, total, gold]):
if loss_function == 'softmax_cross_entropy':
gold = dragnn_ops.emit_oracle_labels(handle, component=self.name)
new_cost, new_correct, new_total, valid_logits, valid_gold = (
build_softmax_cross_entropy_loss(logits, gold))
if (self.eligible_for_self_norm and
self.master.hyperparams.self_norm_alpha > 0):
log_z = tf.reduce_logsumexp(valid_logits, [1])
new_cost += (self.master.hyperparams.self_norm_alpha *
tf.nn.l2_loss(log_z))
elif loss_function == 'sigmoid_cross_entropy':
indices, gold, probs = (
dragnn_ops.emit_oracle_labels_and_probabilities(
handle, component=self.name))
new_cost, new_correct, new_total, valid_gold = (
build_sigmoid_cross_entropy_loss(logits, gold, indices,
probs))
else:
raise RuntimeError("Unknown loss function '%s'" % loss_function)
cost += new_cost
correct += new_correct
total += new_total
with tf.control_dependencies([cost, correct, total, valid_gold]):
handle = dragnn_ops.advance_from_oracle(handle, component=self.name)
return [handle, cost, correct, total] + next_arrays
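# Editorial sketch (not the library implementation): what
# build_softmax_cross_entropy_loss is expected to do in the branch above,
# based on the surrounding code and the unit tests later in this diff.
# Entries whose gold label is -1 are dropped before computing the loss.
import tensorflow as tf

def sketch_softmax_cross_entropy_loss(logits, gold):
  """Returns (cost, correct, total, valid_logits, valid_gold)."""
  valid = tf.greater(gold, -1)
  valid_ix = tf.reshape(tf.where(valid), [-1])
  valid_gold = tf.gather(gold, valid_ix)
  valid_logits = tf.gather(logits, valid_ix)
  cost = tf.reduce_sum(
      tf.nn.sparse_softmax_cross_entropy_with_logits(
          labels=tf.cast(valid_gold, tf.int64), logits=valid_logits))
  correct = tf.reduce_sum(
      tf.to_int32(tf.nn.in_top_k(valid_logits, valid_gold, 1)))
  total = tf.size(valid_gold)
  return cost, correct, total, valid_logits, valid_gold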
......@@ -480,6 +644,7 @@ class DynamicComponentBuilder(ComponentBuilderBase):
# Normalize the objective by the total # of steps taken.
# Note: Total could be zero by a number of reasons, including:
# * Oracle labels not being emitted.
# * All oracle labels for a batch are unknown (-1).
# * No steps being taken if component is terminal at the start of a batch.
with tf.control_dependencies([tf.assert_greater(total, 0)]):
cost /= tf.to_float(total)
......@@ -511,6 +676,7 @@ class DynamicComponentBuilder(ComponentBuilderBase):
stride = state.current_batch_size * self.training_beam_size
else:
stride = state.current_batch_size * self.inference_beam_size
self.network.pre_create(stride)
def cond(handle, *_):
all_final = dragnn_ops.emit_all_final(handle, component=self.name)
......@@ -559,6 +725,7 @@ class DynamicComponentBuilder(ComponentBuilderBase):
for index, layer in enumerate(self.network.layers):
network_state.activations[layer.name] = network_units.StoredActivations(
array=arrays[index])
self._add_runtime_hooks()
with tf.control_dependencies([x.flow for x in arrays]):
return tf.identity(state.handle)
......@@ -587,7 +754,7 @@ class DynamicComponentBuilder(ComponentBuilderBase):
fixed_embeddings = []
for channel_id, feature_spec in enumerate(self.spec.fixed_feature):
fixed_embedding = network_units.fixed_feature_lookup(
self, state, channel_id, stride)
self, state, channel_id, stride, during_training)
if feature_spec.is_constant:
fixed_embedding.tensor = tf.stop_gradient(fixed_embedding.tensor)
fixed_embeddings.append(fixed_embedding)
......@@ -633,6 +800,12 @@ class DynamicComponentBuilder(ComponentBuilderBase):
else:
attention_tensor = None
return self.network.create(fixed_embeddings, linked_embeddings,
context_tensor_arrays, attention_tensor,
during_training)
tensors = self.network.create(fixed_embeddings, linked_embeddings,
context_tensor_arrays, attention_tensor,
during_training)
if self.master.build_runtime_graph:
for index, layer in enumerate(self.network.layers):
self.add_cell_output(tensors[index], layer.name)
return tensors
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for component.py.
"""
import tensorflow as tf
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
from google.protobuf import text_format
from dragnn.protos import spec_pb2
from dragnn.python import component
class MockNetworkUnit(object):
def get_layer_size(self, unused_layer_name):
return 64
class MockComponent(object):
def __init__(self):
self.name = 'mock'
self.network = MockNetworkUnit()
class MockMaster(object):
def __init__(self):
self.spec = spec_pb2.MasterSpec()
self.hyperparams = spec_pb2.GridPoint()
self.lookup_component = {'mock': MockComponent()}
self.build_runtime_graph = False
class ComponentTest(test_util.TensorFlowTestCase):
def setUp(self):
# Clear the graph and all existing variables. Otherwise, variables created
# in different tests may collide with each other.
tf.reset_default_graph()
self.master = MockMaster()
self.master_state = component.MasterState(
handle=tf.constant(['foo', 'bar']), current_batch_size=2)
self.network_states = {
'mock': component.NetworkState(),
'test': component.NetworkState(),
}
def testSoftmaxCrossEntropyLoss(self):
logits = tf.constant([[0.0, 2.0, -1.0],
[-5.0, 1.0, -1.0],
[3.0, 1.0, -2.0]]) # pyformat: disable
gold_labels = tf.constant([1, -1, 1])
cost, correct, total, logits, gold_labels = (
component.build_softmax_cross_entropy_loss(logits, gold_labels))
with self.test_session() as sess:
cost, correct, total, logits, gold_labels = (
sess.run([cost, correct, total, logits, gold_labels]))
# Cost = [-2 + ln(1 + exp(2) + exp(-1))] +
#        [-1 + ln(exp(3) + exp(1) + exp(-2))]
self.assertAlmostEqual(cost, 2.3027, 4)
self.assertEqual(correct, 1)
self.assertEqual(total, 2)
# Entries corresponding to gold labels equal to -1 are skipped.
self.assertAllEqual(logits, [[0.0, 2.0, -1.0], [3.0, 1.0, -2.0]])
self.assertAllEqual(gold_labels, [1, 1])
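# Editorial check of the expected cost above, using plain numpy. The second
# row of logits is skipped because its gold label is -1.
import numpy as np

row0 = np.array([0.0, 2.0, -1.0])  # gold = 1
row2 = np.array([3.0, 1.0, -2.0])  # gold = 1
cost0 = -row0[1] + np.log(np.sum(np.exp(row0)))  # ~0.1698
cost2 = -row2[1] + np.log(np.sum(np.exp(row2)))  # ~2.1328
print(round(cost0 + cost2, 4))  # ~2.3027, matching the assertion.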
def testSigmoidCrossEntropyLoss(self):
indices = tf.constant([0, 0, 1])
gold_labels = tf.constant([0, 1, 2])
probs = tf.constant([0.6, 0.7, 0.2])
logits = tf.constant([[0.9, -0.3, 0.1], [-0.5, 0.4, 2.0]])
cost, correct, total, gold_labels = (
component.build_sigmoid_cross_entropy_loss(logits, gold_labels, indices,
probs))
with self.test_session() as sess:
cost, correct, total, gold_labels = (
sess.run([cost, correct, total, gold_labels]))
# The cost corresponding to the three entries is, respectively,
# 0.7012, 0.7644, and 1.7269. Each of them is computed using the formula
# -prob_i * log(sigmoid(logit_i)) - (1-prob_i) * log(1-sigmoid(logit_i))
self.assertAlmostEqual(cost, 3.1924, 4)
self.assertEqual(correct, 1)
self.assertEqual(total, 3)
self.assertAllEqual(gold_labels, [0, 1, 2])
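# Editorial check of the per-entry costs above, using plain numpy. Entry i
# pairs logits[indices[i], gold_labels[i]] with target probability probs[i].
import numpy as np

def entry_cost(logit, prob):
  s = 1.0 / (1.0 + np.exp(-logit))
  return -prob * np.log(s) - (1.0 - prob) * np.log(1.0 - s)

costs = [entry_cost(0.9, 0.6),   # ~0.7012
         entry_cost(-0.3, 0.7),  # ~0.7644
         entry_cost(2.0, 0.2)]   # ~1.7269
print(round(sum(costs), 4))  # ~3.1924, matching the assertion.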
def testGraphConstruction(self):
component_spec = spec_pb2.ComponentSpec()
text_format.Parse("""
name: "test"
network_unit {
registered_name: "IdentityNetwork"
}
fixed_feature {
name: "fixed" embedding_dim: 32 size: 1
}
component_builder {
registered_name: "component.DynamicComponentBuilder"
}
""", component_spec)
comp = component.DynamicComponentBuilder(self.master, component_spec)
comp.build_greedy_training(self.master_state, self.network_states)
def testGraphConstructionWithSigmoidLoss(self):
component_spec = spec_pb2.ComponentSpec()
text_format.Parse("""
name: "test"
network_unit {
registered_name: "IdentityNetwork"
}
fixed_feature {
name: "fixed" embedding_dim: 32 size: 1
}
component_builder {
registered_name: "component.DynamicComponentBuilder"
parameters {
key: "loss_function"
value: "sigmoid_cross_entropy"
}
}
""", component_spec)
comp = component.DynamicComponentBuilder(self.master, component_spec)
comp.build_greedy_training(self.master_state, self.network_states)
# Check that the loss op is present.
op_names = [op.name for op in tf.get_default_graph().get_operations()]
self.assertTrue('train_test/compute_loss/'
'sigmoid_cross_entropy_with_logits' in op_names)
if __name__ == '__main__':
googletest.main()
......@@ -15,6 +15,10 @@
"""TensorFlow ops for directed graphs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from syntaxnet.util import check
......@@ -150,7 +154,7 @@ def ArcSourcePotentialsFromTokens(tokens, weights):
return sources_bxnxn
def RootPotentialsFromTokens(root, tokens, weights):
def RootPotentialsFromTokens(root, tokens, weights_arc, weights_source):
r"""Returns root selection potentials computed from tokens and weights.
For each batch of token activations, computes a scalar potential for each root
......@@ -162,7 +166,8 @@ def RootPotentialsFromTokens(root, tokens, weights):
Args:
root: [S] vector of activations for the artificial root token.
tokens: [B,N,T] tensor of batched activations for root tokens.
weights: [S,T] matrix of weights.
weights_arc: [S,T] matrix of weights.
weights_source: [S] vector of weights.
B,N may be statically-unknown, but S,T must be statically-known. The dtype
of all arguments must be compatible.
......@@ -174,25 +179,30 @@ def RootPotentialsFromTokens(root, tokens, weights):
# All arguments must have statically-known rank.
check.Eq(root.get_shape().ndims, 1, 'root must be a vector')
check.Eq(tokens.get_shape().ndims, 3, 'tokens must be rank 3')
check.Eq(weights.get_shape().ndims, 2, 'weights must be a matrix')
check.Eq(weights_arc.get_shape().ndims, 2, 'weights_arc must be a matrix')
check.Eq(weights_source.get_shape().ndims, 1,
'weights_source must be a vector')
# All activation dimensions must be statically-known.
num_source_activations = weights.get_shape().as_list()[0]
num_target_activations = weights.get_shape().as_list()[1]
num_source_activations = weights_arc.get_shape().as_list()[0]
num_target_activations = weights_arc.get_shape().as_list()[1]
check.NotNone(num_source_activations, 'unknown source activation dimension')
check.NotNone(num_target_activations, 'unknown target activation dimension')
check.Eq(root.get_shape().as_list()[0], num_source_activations,
'dimension mismatch between weights and root')
'dimension mismatch between weights_arc and root')
check.Eq(tokens.get_shape().as_list()[2], num_target_activations,
'dimension mismatch between weights and tokens')
'dimension mismatch between weights_arc and tokens')
check.Eq(weights_source.get_shape().as_list()[0], num_source_activations,
'dimension mismatch between weights_arc and weights_source')
# All arguments must share the same type.
check.Same([weights.dtype.base_dtype,
root.dtype.base_dtype,
tokens.dtype.base_dtype],
'dtype mismatch')
check.Same([
weights_arc.dtype.base_dtype, weights_source.dtype.base_dtype,
root.dtype.base_dtype, tokens.dtype.base_dtype
], 'dtype mismatch')
root_1xs = tf.expand_dims(root, 0)
weights_source_sx1 = tf.expand_dims(weights_source, 1)
tokens_shape = tf.shape(tokens)
batch_size = tokens_shape[0]
......@@ -200,9 +210,12 @@ def RootPotentialsFromTokens(root, tokens, weights):
# Flatten out the batch dimension so we can use a couple big matmuls.
tokens_bnxt = tf.reshape(tokens, [-1, num_target_activations])
weights_targets_bnxs = tf.matmul(tokens_bnxt, weights, transpose_b=True)
weights_targets_bnxs = tf.matmul(tokens_bnxt, weights_arc, transpose_b=True)
roots_1xbn = tf.matmul(root_1xs, weights_targets_bnxs, transpose_b=True)
# Add in the score for selecting the root as a source.
roots_1xbn += tf.matmul(root_1xs, weights_source_sx1)
# Restore the batch dimension in the output.
roots_bxn = tf.reshape(roots_1xbn, [batch_size, num_tokens])
return roots_bxn
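# Editorial sketch: the root potential for token n in batch b is
#   root . weights_arc . tokens[b, n]  +  root . weights_source.
# Checked with numpy against the values in testRootPotentialsFromTokens later
# in this diff; the root vector [1, 2] is an assumption inferred from the
# expected outputs there.
import numpy as np

root = np.array([1., 2.])
tokens_b0 = np.array([[4., 5., 6.], [5., 6., 7.], [6., 7., 8.]])
weights_arc = np.array([[2., 3., 5.], [7., 11., 13.]])
weights_source = np.array([11., 10.])

arc_part = tokens_b0.dot(weights_arc.T).dot(root)  # [375, 447, 519]
source_part = root.dot(weights_source)             # 31
print(arc_part + source_part)                      # [406, 478, 550]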
......@@ -354,3 +367,110 @@ def LabelPotentialsFromTokenPairs(sources, targets, weights):
transpose_b=True)
labels_bxnxl = tf.squeeze(labels_bxnxlx1, [3])
return labels_bxnxl
def ValidArcAndTokenMasks(lengths, max_length, dtype=tf.float32):
r"""Returns 0/1 masks for valid arcs and tokens.
Args:
lengths: [B] vector of input sequence lengths.
max_length: Scalar maximum input sequence length, aka M.
dtype: Data type for output mask.
Returns:
[B,M,M] tensor A with 0/1 indicators of valid arcs. Specifically,
A_{b,t,s} = t,s < lengths[b] ? 1 : 0
[B,M] matrix T with 0/1 indicators of valid tokens. Specifically,
T_{b,t} = t < lengths[b] ? 1 : 0
"""
lengths_bx1 = tf.expand_dims(lengths, 1)
sequence_m = tf.range(tf.cast(max_length, lengths.dtype.base_dtype))
sequence_1xm = tf.expand_dims(sequence_m, 0)
# Create vectors of 0/1 indicators for valid tokens. Note that the comparison
# operator will broadcast from [1,M] and [B,1] to [B,M].
valid_token_bxm = tf.cast(sequence_1xm < lengths_bx1, dtype)
# Compute matrices of 0/1 indicators for valid arcs as the outer product of
# the valid token indicator vector with itself.
valid_arc_bxmxm = tf.matmul(
tf.expand_dims(valid_token_bxm, 2), tf.expand_dims(valid_token_bxm, 1))
return valid_arc_bxmxm, valid_token_bxm
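# Editorial sketch of the same masks in plain numpy, for clarity: a token is
# valid iff its index is below the sequence length, and an arc is valid iff
# both of its endpoints are valid tokens.
import numpy as np

lengths = np.array([1, 2, 3])
max_length = 4
positions = np.arange(max_length)  # [0, 1, 2, 3]
valid_token = (positions[None, :] < lengths[:, None]).astype(np.float32)
valid_arc = valid_token[:, :, None] * valid_token[:, None, :]
print(valid_token)  # e.g. row 1 is [1, 1, 0, 0]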
def LaplacianMatrix(lengths, arcs, forest=False):
r"""Returns the (root-augmented) Laplacian matrix for a batch of digraphs.
Args:
lengths: [B] vector of input sequence lengths.
arcs: [B,M,M] tensor of arc potentials where entry b,t,s is the potential of
the arc from s to t in the b'th digraph, while b,t,t is the potential of t
as a root. Entries b,t,s where t or s >= lengths[b] are ignored.
forest: Whether to produce a Laplacian for trees or forests.
Returns:
[B,M,M] tensor L with the Laplacian of each digraph, padded with an identity
matrix. More concretely, the padding entries (t or s >= lengths[b]) are:
L_{b,t,t} = 1.0
L_{b,t,s} = 0.0
Note that this "identity matrix padding" ensures that the determinant of
each padded matrix equals the determinant of the unpadded matrix. The
non-padding entries (t,s < lengths[b]) depend on whether the Laplacian is
constructed for trees or forests. For trees:
L_{b,t,0} = arcs[b,t,t]
L_{b,t,t} = \sum_{s < lengths[b], t != s} arcs[b,t,s]
L_{b,t,s} = -arcs[b,t,s]
For forests:
L_{b,t,t} = \sum_{s < lengths[b]} arcs[b,t,s]
L_{b,t,s} = -arcs[b,t,s]
See http://www.aclweb.org/anthology/D/D07/D07-1015.pdf for details, though
note that our matrices are transposed from their notation.
"""
check.Eq(arcs.get_shape().ndims, 3, 'arcs must be rank 3')
dtype = arcs.dtype.base_dtype
arcs_shape = tf.shape(arcs)
batch_size = arcs_shape[0]
max_length = arcs_shape[1]
with tf.control_dependencies([tf.assert_equal(max_length, arcs_shape[2])]):
valid_arc_bxmxm, valid_token_bxm = ValidArcAndTokenMasks(
lengths, max_length, dtype=dtype)
invalid_token_bxm = tf.constant(1, dtype=dtype) - valid_token_bxm
# Zero out all invalid arcs, to avoid polluting bulk summations.
arcs_bxmxm = arcs * valid_arc_bxmxm
zeros_bxm = tf.zeros([batch_size, max_length], dtype)
if not forest:
# For trees, extract the root potentials and exclude them from the sums
# computed below.
roots_bxm = tf.matrix_diag_part(arcs_bxmxm) # only defined for trees
arcs_bxmxm = tf.matrix_set_diag(arcs_bxmxm, zeros_bxm)
# Sum inbound arc potentials for each target token. These sums will form
# the diagonal of the Laplacian matrix. Note that these sums are zero for
# invalid tokens, since their arc potentials were masked out above.
sums_bxm = tf.reduce_sum(arcs_bxmxm, 2)
if forest:
# For forests, zero out the root potentials after computing the sums above
# so we don't cancel them out when we subtract the arc potentials.
arcs_bxmxm = tf.matrix_set_diag(arcs_bxmxm, zeros_bxm)
# The diagonal of the result is the combination of the arc sums, which are
# non-zero only on valid tokens, and the invalid token indicators, which are
# non-zero only on invalid tokens. Note that the latter form the diagonal
# of the identity matrix padding.
diagonal_bxm = sums_bxm + invalid_token_bxm
# Combine sums and negative arc potentials. Note that the off-diagonal
# padding entries will be zero thanks to the arc mask.
laplacian_bxmxm = tf.matrix_diag(diagonal_bxm) - arcs_bxmxm
if not forest:
# For trees, replace the first column with the root potentials.
roots_bxmx1 = tf.expand_dims(roots_bxm, 2)
laplacian_bxmxm = tf.concat([roots_bxmx1, laplacian_bxmxm[:, :, 1:]], 2)
return laplacian_bxmxm
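# Editorial sketch: tree-Laplacian assembly for a single unpadded digraph in
# plain numpy, mirroring the steps above. Values are taken from the length-3
# case in testLaplacianMatrixTree later in this diff.
import numpy as np

arcs = np.array([[2., 3., 5.],
                 [7., 11., 13.],
                 [17., 19., 23.]])
roots = np.diag(arcs).copy()      # [2, 11, 23]: root potentials
off_diag = arcs - np.diag(roots)  # zero the diagonal
laplacian = np.diag(off_diag.sum(axis=1)) - off_diag
laplacian[:, 0] = roots           # first column holds the root potentials
print(laplacian)  # [[2, -3, -5], [11, 20, -13], [23, -19, 36]]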
......@@ -31,16 +31,18 @@ class DigraphOpsTest(tf.test.TestCase):
[3, 4]],
[[3, 4],
[2, 3],
[1, 2]]], tf.float32)
[1, 2]]],
tf.float32) # pyformat: disable
target_tokens = tf.constant([[[4, 5, 6],
[5, 6, 7],
[6, 7, 8]],
[[6, 7, 8],
[5, 6, 7],
[4, 5, 6]]], tf.float32)
[4, 5, 6]]],
tf.float32) # pyformat: disable
weights = tf.constant([[2, 3, 5],
[7, 11, 13]],
tf.float32)
tf.float32) # pyformat: disable
arcs = digraph_ops.ArcPotentialsFromTokens(source_tokens, target_tokens,
weights)
......@@ -54,7 +56,7 @@ class DigraphOpsTest(tf.test.TestCase):
[803, 957, 1111]],
[[1111, 957, 803], # reflected through the center
[815, 702, 589],
[519, 447, 375]]])
[519, 447, 375]]]) # pyformat: disable
def testArcSourcePotentialsFromTokens(self):
with self.test_session():
......@@ -63,7 +65,7 @@ class DigraphOpsTest(tf.test.TestCase):
[6, 7, 8]],
[[6, 7, 8],
[5, 6, 7],
[4, 5, 6]]], tf.float32)
[4, 5, 6]]], tf.float32) # pyformat: disable
weights = tf.constant([2, 3, 5], tf.float32)
arcs = digraph_ops.ArcSourcePotentialsFromTokens(tokens, weights)
......@@ -73,7 +75,7 @@ class DigraphOpsTest(tf.test.TestCase):
[73, 73, 73]],
[[73, 73, 73],
[63, 63, 63],
[53, 53, 53]]])
[53, 53, 53]]]) # pyformat: disable
def testRootPotentialsFromTokens(self):
with self.test_session():
......@@ -83,15 +85,17 @@ class DigraphOpsTest(tf.test.TestCase):
[6, 7, 8]],
[[6, 7, 8],
[5, 6, 7],
[4, 5, 6]]], tf.float32)
weights = tf.constant([[2, 3, 5],
[7, 11, 13]],
tf.float32)
[4, 5, 6]]], tf.float32) # pyformat: disable
weights_arc = tf.constant([[2, 3, 5],
[7, 11, 13]],
tf.float32) # pyformat: disable
weights_source = tf.constant([11, 10], tf.float32)
roots = digraph_ops.RootPotentialsFromTokens(root, tokens, weights)
roots = digraph_ops.RootPotentialsFromTokens(root, tokens, weights_arc,
weights_source)
self.assertAllEqual(roots.eval(), [[375, 447, 519],
[519, 447, 375]])
self.assertAllEqual(roots.eval(), [[406, 478, 550],
[550, 478, 406]]) # pyformat: disable
def testCombineArcAndRootPotentials(self):
with self.test_session():
......@@ -100,9 +104,9 @@ class DigraphOpsTest(tf.test.TestCase):
[3, 4, 5]],
[[3, 4, 5],
[2, 3, 4],
[1, 2, 3]]], tf.float32)
[1, 2, 3]]], tf.float32) # pyformat: disable
roots = tf.constant([[6, 7, 8],
[8, 7, 6]], tf.float32)
[8, 7, 6]], tf.float32) # pyformat: disable
potentials = digraph_ops.CombineArcAndRootPotentials(arcs, roots)
......@@ -111,7 +115,7 @@ class DigraphOpsTest(tf.test.TestCase):
[3, 4, 8]],
[[8, 4, 5],
[2, 7, 4],
[1, 2, 6]]])
[1, 2, 6]]]) # pyformat: disable
def testLabelPotentialsFromTokens(self):
with self.test_session():
......@@ -120,12 +124,12 @@ class DigraphOpsTest(tf.test.TestCase):
[5, 6]],
[[6, 5],
[4, 3],
[2, 1]]], tf.float32)
[2, 1]]], tf.float32) # pyformat: disable
weights = tf.constant([[ 2, 3],
[ 5, 7],
[11, 13]], tf.float32)
[11, 13]], tf.float32) # pyformat: disable
labels = digraph_ops.LabelPotentialsFromTokens(tokens, weights)
......@@ -136,7 +140,7 @@ class DigraphOpsTest(tf.test.TestCase):
[ 28, 67, 133]],
[[ 27, 65, 131],
[ 17, 41, 83],
[ 7, 17, 35]]])
[ 7, 17, 35]]]) # pyformat: disable
def testLabelPotentialsFromTokenPairs(self):
with self.test_session():
......@@ -145,13 +149,13 @@ class DigraphOpsTest(tf.test.TestCase):
[5, 6]],
[[6, 5],
[4, 3],
[2, 1]]], tf.float32)
[2, 1]]], tf.float32) # pyformat: disable
targets = tf.constant([[[3, 4],
[5, 6],
[7, 8]],
[[8, 7],
[6, 5],
[4, 3]]], tf.float32)
[4, 3]]], tf.float32) # pyformat: disable
weights = tf.constant([[[ 2, 3],
......@@ -159,7 +163,7 @@ class DigraphOpsTest(tf.test.TestCase):
[[11, 13],
[17, 19]],
[[23, 29],
[31, 37]]], tf.float32)
[31, 37]]], tf.float32) # pyformat: disable
labels = digraph_ops.LabelPotentialsFromTokenPairs(sources, targets,
weights)
......@@ -171,7 +175,114 @@ class DigraphOpsTest(tf.test.TestCase):
[ 736, 2531, 5043]],
[[ 667, 2419, 4857],
[ 303, 1115, 2245],
[ 75, 291, 593]]])
[ 75, 291, 593]]]) # pyformat: disable
def testValidArcAndTokenMasks(self):
with self.test_session():
lengths = tf.constant([1, 2, 3], tf.int64)
max_length = 4
valid_arcs, valid_tokens = digraph_ops.ValidArcAndTokenMasks(
lengths, max_length)
self.assertAllEqual(valid_arcs.eval(),
[[[1, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]],
[[1, 1, 0, 0],
[1, 1, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]],
[[1, 1, 1, 0],
[1, 1, 1, 0],
[1, 1, 1, 0],
[0, 0, 0, 0]]]) # pyformat: disable
self.assertAllEqual(valid_tokens.eval(),
[[1, 0, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 0]]) # pyformat: disable
def testLaplacianMatrixTree(self):
with self.test_session():
pad = 12345.6
arcs = tf.constant([[[ 2, pad, pad, pad],
[pad, pad, pad, pad],
[pad, pad, pad, pad],
[pad, pad, pad, pad]],
[[ 2, 3, pad, pad],
[ 5, 7, pad, pad],
[pad, pad, pad, pad],
[pad, pad, pad, pad]],
[[ 2, 3, 5, pad],
[ 7, 11, 13, pad],
[ 17, 19, 23, pad],
[pad, pad, pad, pad]],
[[ 2, 3, 5, 7],
[ 11, 13, 17, 19],
[ 23, 29, 31, 37],
[ 41, 43, 47, 53]]],
tf.float32) # pyformat: disable
lengths = tf.constant([1, 2, 3, 4], tf.int64)
laplacian = digraph_ops.LaplacianMatrix(lengths, arcs)
self.assertAllEqual(laplacian.eval(),
[[[ 2, 0, 0, 0],
[ 0, 1, 0, 0],
[ 0, 0, 1, 0],
[ 0, 0, 0, 1]],
[[ 2, -3, 0, 0],
[ 7, 5, 0, 0],
[ 0, 0, 1, 0],
[ 0, 0, 0, 1]],
[[ 2, -3, -5, 0],
[ 11, 20, -13, 0],
[ 23, -19, 36, 0],
[ 0, 0, 0, 1]],
[[ 2, -3, -5, -7],
[ 13, 47, -17, -19],
[ 31, -29, 89, -37],
[ 53, -43, -47, 131]]]) # pyformat: disable
def testLaplacianMatrixForest(self):
with self.test_session():
pad = 12345.6
arcs = tf.constant([[[ 2, pad, pad, pad],
[pad, pad, pad, pad],
[pad, pad, pad, pad],
[pad, pad, pad, pad]],
[[ 2, 3, pad, pad],
[ 5, 7, pad, pad],
[pad, pad, pad, pad],
[pad, pad, pad, pad]],
[[ 2, 3, 5, pad],
[ 7, 11, 13, pad],
[ 17, 19, 23, pad],
[pad, pad, pad, pad]],
[[ 2, 3, 5, 7],
[ 11, 13, 17, 19],
[ 23, 29, 31, 37],
[ 41, 43, 47, 53]]],
tf.float32) # pyformat: disable
lengths = tf.constant([1, 2, 3, 4], tf.int64)
laplacian = digraph_ops.LaplacianMatrix(lengths, arcs, forest=True)
self.assertAllEqual(laplacian.eval(),
[[[ 2, 0, 0, 0],
[ 0, 1, 0, 0],
[ 0, 0, 1, 0],
[ 0, 0, 0, 1]],
[[ 5, -3, 0, 0],
[ -5, 12, 0, 0],
[ 0, 0, 1, 0],
[ 0, 0, 0, 1]],
[[ 10, -3, -5, 0],
[ -7, 31, -13, 0],
[-17, -19, 59, 0],
[ 0, 0, 0, 1]],
[[ 17, -3, -5, -7],
[-11, 60, -17, -19],
[-23, -29, 120, -37],
[-41, -43, -47, 184]]]) # pyformat: disable
if __name__ == "__main__":
......
......@@ -25,13 +25,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import app
from absl import flags
import tensorflow as tf
from google.protobuf import text_format
from dragnn.protos import spec_pb2
from dragnn.python import dragnn_model_saver_lib as saver_lib
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('master_spec', None, 'Path to task context with '
......@@ -40,10 +41,12 @@ flags.DEFINE_string('params_path', None, 'Path to trained model parameters.')
flags.DEFINE_string('export_path', '', 'Output path for exported servo model.')
flags.DEFINE_bool('export_moving_averages', False,
'Whether to export the moving average parameters.')
flags.DEFINE_bool('build_runtime_graph', False,
'Whether to build a graph for use by the runtime.')
def export(master_spec_path, params_path, export_path,
export_moving_averages):
def export(master_spec_path, params_path, export_path, export_moving_averages,
build_runtime_graph):
"""Restores a model and exports it in SavedModel form.
This method loads a graph specified by the spec at master_spec_path and the
......@@ -55,6 +58,7 @@ def export(master_spec_path, params_path, export_path,
params_path: Path to the parameters file to export.
export_path: Path to export the SavedModel to.
export_moving_averages: Whether to export the moving average parameters.
build_runtime_graph: Whether to build a graph for use by the runtime.
"""
graph = tf.Graph()
......@@ -70,16 +74,16 @@ def export(master_spec_path, params_path, export_path,
short_to_original = saver_lib.shorten_resource_paths(master_spec)
saver_lib.export_master_spec(master_spec, graph)
saver_lib.export_to_graph(master_spec, params_path, stripped_path, graph,
export_moving_averages)
export_moving_averages, build_runtime_graph)
saver_lib.export_assets(master_spec, short_to_original, stripped_path)
def main(unused_argv):
# Run the exporter.
export(FLAGS.master_spec, FLAGS.params_path,
FLAGS.export_path, FLAGS.export_moving_averages)
export(FLAGS.master_spec, FLAGS.params_path, FLAGS.export_path,
FLAGS.export_moving_averages, FLAGS.build_runtime_graph)
tf.logging.info('Export complete.')
if __name__ == '__main__':
tf.app.run()
app.run(main)
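# Editorial usage sketch: calling the exporter above directly. The module
# import path and all file paths are hypothetical placeholders and assume a
# trained DRAGNN model is already on disk.
from dragnn.python import dragnn_model_saver  # assumed module path

dragnn_model_saver.export(
    master_spec_path='/tmp/my_model/master_spec.textproto',  # hypothetical
    params_path='/tmp/my_model/checkpoint',                  # hypothetical
    export_path='/tmp/my_model/export',                      # hypothetical
    export_moving_averages=True,
    build_runtime_graph=True)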
......@@ -164,6 +164,7 @@ def export_to_graph(master_spec,
export_path,
external_graph,
export_moving_averages,
build_runtime_graph,
signature_name='model'):
"""Restores a model and exports it in SavedModel form.
......@@ -177,6 +178,7 @@ def export_to_graph(master_spec,
export_path: Path to export the SavedModel to.
external_graph: A tf.Graph() object to build the graph inside.
export_moving_averages: Whether to export the moving average parameters.
build_runtime_graph: Whether to build a graph for use by the runtime.
signature_name: Name of the signature to insert.
"""
tf.logging.info(
......@@ -189,7 +191,7 @@ def export_to_graph(master_spec,
hyperparam_config.use_moving_average = export_moving_averages
builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
post_restore_hook = builder.build_post_restore_hook()
annotation = builder.add_annotation()
annotation = builder.add_annotation(build_runtime_graph=build_runtime_graph)
builder.add_saver()
# Resets session.
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test for dragnn.python.dragnn_model_saver_lib."""
from __future__ import absolute_import
......@@ -26,24 +25,30 @@ import tensorflow as tf
from google.protobuf import text_format
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
from dragnn.protos import export_pb2
from dragnn.protos import spec_pb2
from dragnn.python import dragnn_model_saver_lib
from syntaxnet import sentence_pb2
from syntaxnet import test_flags
FLAGS = tf.app.flags.FLAGS
def setUpModule():
if not hasattr(FLAGS, 'test_srcdir'):
FLAGS.test_srcdir = ''
if not hasattr(FLAGS, 'test_tmpdir'):
FLAGS.test_tmpdir = tf.test.get_temp_dir()
_DUMMY_TEST_SENTENCE = """
token {
word: "sentence" start: 0 end: 7 break_level: NO_BREAK
}
token {
word: "0" start: 9 end: 9 break_level: SPACE_BREAK
}
token {
word: "." start: 10 end: 10 break_level: NO_BREAK
}
"""
class DragnnModelSaverLibTest(test_util.TensorFlowTestCase):
def LoadSpec(self, spec_path):
master_spec = spec_pb2.MasterSpec()
root_dir = os.path.join(FLAGS.test_srcdir,
root_dir = os.path.join(test_flags.source_root(),
'dragnn/python')
with open(os.path.join(root_dir, 'testdata', spec_path), 'r') as fin:
text_format.Parse(fin.read().replace('TOPDIR', root_dir), master_spec)
......@@ -52,7 +57,7 @@ class DragnnModelSaverLibTest(test_util.TensorFlowTestCase):
def CreateLocalSpec(self, spec_path):
master_spec = self.LoadSpec(spec_path)
master_spec_name = os.path.basename(spec_path)
outfile = os.path.join(FLAGS.test_tmpdir, master_spec_name)
outfile = os.path.join(test_flags.temp_dir(), master_spec_name)
fout = open(outfile, 'w')
fout.write(text_format.MessageToString(master_spec))
return outfile
......@@ -80,16 +85,50 @@ class DragnnModelSaverLibTest(test_util.TensorFlowTestCase):
# Return a set of all unique paths.
return set(path_list)
def GetHookNodeNames(self, master_spec):
"""Returns hook node names to use in tests.
Args:
master_spec: MasterSpec proto from which to infer hook node names.
Returns:
Tuple of (averaged hook node name, non-averaged hook node name, cell
subgraph hook node name).
Raises:
ValueError: If hook nodes cannot be inferred from the |master_spec|.
"""
# Find an op name we can use for testing runtime hooks. Assume that at
# least one component has a fixed feature (else what is the model doing?).
component_name = None
for component_spec in master_spec.component:
if component_spec.fixed_feature:
component_name = component_spec.name
break
if not component_name:
raise ValueError('Cannot infer hook node names')
non_averaged_hook_name = '{}/fixed_embedding_matrix_0/trimmed'.format(
component_name)
averaged_hook_name = '{}/ExponentialMovingAverage'.format(
non_averaged_hook_name)
cell_subgraph_hook_name = '{}/EXPORT/CellSubgraphSpec'.format(
component_name)
return averaged_hook_name, non_averaged_hook_name, cell_subgraph_hook_name
def testModelExport(self):
# Get the master spec and params for this graph.
master_spec = self.LoadSpec('ud-hungarian.master-spec')
params_path = os.path.join(
FLAGS.test_srcdir, 'dragnn/python/testdata'
test_flags.source_root(),
'dragnn/python/testdata'
'/ud-hungarian.params')
# Export the graph via SavedModel. (Here, we maintain a handle to the graph
# for comparison, but that's usually not necessary.)
export_path = os.path.join(FLAGS.test_tmpdir, 'export')
export_path = os.path.join(test_flags.temp_dir(), 'export')
dragnn_model_saver_lib.clean_output_paths(export_path)
saver_graph = tf.Graph()
shortened_to_original = dragnn_model_saver_lib.shorten_resource_paths(
......@@ -102,7 +141,8 @@ class DragnnModelSaverLibTest(test_util.TensorFlowTestCase):
params_path,
export_path,
saver_graph,
export_moving_averages=False)
export_moving_averages=False,
build_runtime_graph=False)
# Export the assets as well.
dragnn_model_saver_lib.export_assets(master_spec, shortened_to_original,
......@@ -126,6 +166,165 @@ class DragnnModelSaverLibTest(test_util.TensorFlowTestCase):
tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING],
export_path)
averaged_hook_name, non_averaged_hook_name, _ = self.GetHookNodeNames(
master_spec)
# Check that the averaged runtime hook node does not exist.
with self.assertRaises(KeyError):
restored_graph.get_operation_by_name(averaged_hook_name)
# Check that the non-averaged version also does not exist.
with self.assertRaises(KeyError):
restored_graph.get_operation_by_name(non_averaged_hook_name)
def testModelExportWithAveragesAndHooks(self):
# Get the master spec and params for this graph.
master_spec = self.LoadSpec('ud-hungarian.master-spec')
params_path = os.path.join(
test_flags.source_root(),
'dragnn/python/testdata'
'/ud-hungarian.params')
# Export the graph via SavedModel. (Here, we maintain a handle to the graph
# for comparison, but that's usually not necessary.) Note that the export
# path must not already exist.
export_path = os.path.join(test_flags.temp_dir(), 'export2')
dragnn_model_saver_lib.clean_output_paths(export_path)
saver_graph = tf.Graph()
shortened_to_original = dragnn_model_saver_lib.shorten_resource_paths(
master_spec)
dragnn_model_saver_lib.export_master_spec(master_spec, saver_graph)
dragnn_model_saver_lib.export_to_graph(
master_spec,
params_path,
export_path,
saver_graph,
export_moving_averages=True,
build_runtime_graph=True)
# Export the assets as well.
dragnn_model_saver_lib.export_assets(master_spec, shortened_to_original,
export_path)
# Validate that the assets are all in the exported directory.
path_set = self.ValidateAssetExistence(master_spec, export_path)
# This master-spec has 4 unique assets. If there are more, we have not
# uniquified the assets properly.
self.assertEqual(len(path_set), 4)
# Restore the graph from the checkpoint into a new Graph object.
restored_graph = tf.Graph()
restoration_config = tf.ConfigProto(
log_device_placement=False,
intra_op_parallelism_threads=10,
inter_op_parallelism_threads=10)
with tf.Session(graph=restored_graph, config=restoration_config) as sess:
tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING],
export_path)
averaged_hook_name, non_averaged_hook_name, cell_subgraph_hook_name = (
self.GetHookNodeNames(master_spec))
# Check that an averaged runtime hook node exists.
restored_graph.get_operation_by_name(averaged_hook_name)
# Check that the non-averaged version does not exist.
with self.assertRaises(KeyError):
restored_graph.get_operation_by_name(non_averaged_hook_name)
# Load the cell subgraph.
cell_subgraph_bytes = restored_graph.get_tensor_by_name(
cell_subgraph_hook_name + ':0')
cell_subgraph_bytes = cell_subgraph_bytes.eval(
feed_dict={'annotation/ComputeSession/InputBatch:0': []})
cell_subgraph_spec = export_pb2.CellSubgraphSpec()
cell_subgraph_spec.ParseFromString(cell_subgraph_bytes)
tf.logging.info('cell_subgraph_spec = %s', cell_subgraph_spec)
# Sanity check inputs.
for cell_input in cell_subgraph_spec.input:
self.assertGreater(len(cell_input.name), 0)
self.assertGreater(len(cell_input.tensor), 0)
self.assertNotEqual(cell_input.type,
export_pb2.CellSubgraphSpec.Input.TYPE_UNKNOWN)
restored_graph.get_tensor_by_name(cell_input.tensor) # shouldn't raise
# Sanity check outputs.
for cell_output in cell_subgraph_spec.output:
self.assertGreater(len(cell_output.name), 0)
self.assertGreater(len(cell_output.tensor), 0)
restored_graph.get_tensor_by_name(cell_output.tensor) # shouldn't raise
# GetHookNames() finds a component with a fixed feature, so at least the
# first feature ID should exist.
self.assertTrue(
any(cell_input.name == 'fixed_channel_0_index_0_ids'
for cell_input in cell_subgraph_spec.input))
# Most dynamic components produce a logits layer.
self.assertTrue(
any(cell_output.name == 'logits'
for cell_output in cell_subgraph_spec.output))
def testModelExportProducesRunnableModel(self):
# Get the master spec and params for this graph.
master_spec = self.LoadSpec('ud-hungarian.master-spec')
params_path = os.path.join(
test_flags.source_root(),
'dragnn/python/testdata'
'/ud-hungarian.params')
# Export the graph via SavedModel. (Here, we maintain a handle to the graph
# for comparison, but that's usually not necessary.)
export_path = os.path.join(test_flags.temp_dir(), 'export')
dragnn_model_saver_lib.clean_output_paths(export_path)
saver_graph = tf.Graph()
shortened_to_original = dragnn_model_saver_lib.shorten_resource_paths(
master_spec)
dragnn_model_saver_lib.export_master_spec(master_spec, saver_graph)
dragnn_model_saver_lib.export_to_graph(
master_spec,
params_path,
export_path,
saver_graph,
export_moving_averages=False,
build_runtime_graph=False)
# Export the assets as well.
dragnn_model_saver_lib.export_assets(master_spec, shortened_to_original,
export_path)
# Restore the graph from the checkpoint into a new Graph object.
restored_graph = tf.Graph()
restoration_config = tf.ConfigProto(
log_device_placement=False,
intra_op_parallelism_threads=10,
inter_op_parallelism_threads=10)
with tf.Session(graph=restored_graph, config=restoration_config) as sess:
tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING],
export_path)
test_doc = sentence_pb2.Sentence()
text_format.Parse(_DUMMY_TEST_SENTENCE, test_doc)
test_reader_string = test_doc.SerializeToString()
test_inputs = [test_reader_string]
tf_out = sess.run(
'annotation/annotations:0',
feed_dict={'annotation/ComputeSession/InputBatch:0': test_inputs})
# We don't care about accuracy, only that the run sessions don't crash.
del tf_out
if __name__ == '__main__':
googletest.main()
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Diff test that compares two files are identical."""
from absl import flags
import tensorflow as tf
FLAGS = flags.FLAGS
flags.DEFINE_string('actual_file', None, 'File to test.')
flags.DEFINE_string('expected_file', None, 'File with expected contents.')
class DiffTest(tf.test.TestCase):
def testEqualFiles(self):
content_actual = None
content_expected = None
try:
with open(FLAGS.actual_file) as actual:
content_actual = actual.read()
except IOError as e:
self.fail("Error opening '%s': %s" % (FLAGS.actual_file, e.strerror))
try:
with open(FLAGS.expected_file) as expected:
content_expected = expected.read()
except IOError as e:
self.fail("Error opening '%s': %s" % (FLAGS.expected_file, e.strerror))
self.assertEqual(content_actual, content_expected)
if __name__ == '__main__':
tf.test.main()
......@@ -28,7 +28,7 @@ from syntaxnet.util import check
try:
tf.NotDifferentiable('ExtractFixedFeatures')
except KeyError as e:
except KeyError, e:
logging.info(str(e))
......@@ -179,6 +179,8 @@ class MasterBuilder(object):
optimizer: handle to the tf.train Optimizer object used to train this model.
master_vars: dictionary of globally shared tf.Variable objects (e.g.
the global training step and learning rate.)
read_from_avg: Whether to use averaged params instead of normal params.
build_runtime_graph: Whether to build a graph for use by the runtime.
"""
def __init__(self, master_spec, hyperparam_config=None, pool_scope='shared'):
......@@ -197,14 +199,15 @@ class MasterBuilder(object):
ValueError: if a component is not found in the registry.
"""
self.spec = master_spec
self.hyperparams = (spec_pb2.GridPoint()
if hyperparam_config is None else hyperparam_config)
self.hyperparams = (
spec_pb2.GridPoint()
if hyperparam_config is None else hyperparam_config)
_validate_grid_point(self.hyperparams)
self.pool_scope = pool_scope
# Set the graph-level random seed before creating the Components so the ops
# they create will use this seed.
tf.set_random_seed(hyperparam_config.seed)
tf.set_random_seed(self.hyperparams.seed)
# Construct all utility class and variables for each Component.
self.components = []
......@@ -219,19 +222,37 @@ class MasterBuilder(object):
self.lookup_component[comp.name] = comp
self.components.append(comp)
# Add global step variable.
self.master_vars = {}
with tf.variable_scope('master', reuse=False):
# Add global step variable.
self.master_vars['step'] = tf.get_variable(
'step', [], initializer=tf.zeros_initializer(), dtype=tf.int32)
self.master_vars['learning_rate'] = _create_learning_rate(
self.hyperparams, self.master_vars['step'])
# Add learning rate. If the learning rate is optimized externally, then
# just create an assign op.
if self.hyperparams.pbt_optimize_learning_rate:
self.master_vars['learning_rate'] = tf.get_variable(
'learning_rate',
initializer=tf.constant(
self.hyperparams.learning_rate, dtype=tf.float32))
lr_assign_input = tf.placeholder(tf.float32, [],
'pbt/assign/learning_rate/Value')
tf.assign(
self.master_vars['learning_rate'],
value=lr_assign_input,
name='pbt/assign/learning_rate')
else:
self.master_vars['learning_rate'] = _create_learning_rate(
self.hyperparams, self.master_vars['step'])
# Construct optimizer.
self.optimizer = _create_optimizer(self.hyperparams,
self.master_vars['learning_rate'],
self.master_vars['step'])
self.read_from_avg = False
self.build_runtime_graph = False
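# Editorial sketch (self-contained, not the exact DRAGNN node names) of the
# externally-driven learning rate pattern above: the graph exposes an assign
# op fed from a placeholder, and an outside process (e.g. a population-based
# training controller) runs that op to change the rate.
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
  learning_rate = tf.get_variable(
      'learning_rate', initializer=tf.constant(0.1, dtype=tf.float32))
  lr_value = tf.placeholder(tf.float32, [], name='assign/learning_rate/Value')
  assign_lr = tf.assign(learning_rate, lr_value, name='assign/learning_rate')

with tf.Session(graph=graph) as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(assign_lr, feed_dict={lr_value: 0.05})
  print(sess.run(learning_rate))  # 0.05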
@property
def component_names(self):
return tuple(c.name for c in self.components)
......@@ -366,8 +387,9 @@ class MasterBuilder(object):
max_index = len(self.components)
else:
if not 0 < max_index <= len(self.components):
raise IndexError('Invalid max_index {} for components {}; handle {}'.
format(max_index, self.component_names, handle.name))
raise IndexError(
'Invalid max_index {} for components {}; handle {}'.format(
max_index, self.component_names, handle.name))
# By default, we train every component supervised.
if not component_weights:
......@@ -375,6 +397,11 @@ class MasterBuilder(object):
if not unroll_using_oracle:
unroll_using_oracle = [True] * max_index
if not max_index <= len(unroll_using_oracle):
raise IndexError(('Invalid max_index {} for unroll_using_oracle {}; '
'handle {}').format(max_index, unroll_using_oracle,
handle.name))
component_weights = component_weights[:max_index]
total_weight = float(sum(component_weights))
component_weights = [w / total_weight for w in component_weights]
......@@ -408,10 +435,10 @@ class MasterBuilder(object):
args = (master_state, network_states)
if unroll_using_oracle[component_index]:
handle, component_cost, component_correct, component_total = (tf.cond(
comp.training_beam_size > 1,
lambda: comp.build_structured_training(*args),
lambda: comp.build_greedy_training(*args)))
handle, component_cost, component_correct, component_total = (
tf.cond(comp.training_beam_size > 1,
lambda: comp.build_structured_training(*args),
lambda: comp.build_greedy_training(*args)))
else:
handle = comp.build_greedy_inference(*args, during_training=True)
......@@ -445,6 +472,7 @@ class MasterBuilder(object):
# 1. compute the gradients,
# 2. add an optimizer to update the parameters using the gradients,
# 3. make the ComputeSession handle depend on the optimizer.
gradient_norm = tf.constant(0.)
if compute_gradients:
logging.info('Creating train op with %d variables:\n\t%s',
len(params_to_train),
......@@ -452,8 +480,11 @@ class MasterBuilder(object):
grads_and_vars = self.optimizer.compute_gradients(
cost, var_list=params_to_train)
clipped_gradients = [(self._clip_gradients(g), v)
for g, v in grads_and_vars]
clipped_gradients = [
(self._clip_gradients(g), v) for g, v in grads_and_vars
]
gradient_norm = tf.global_norm(list(zip(*clipped_gradients))[0])
minimize_op = self.optimizer.apply_gradients(
clipped_gradients, global_step=self.master_vars['step'])
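# Editorial sketch of the clip-then-apply pattern above, using tf.clip_by_norm
# as a stand-in for self._clip_gradients (whose exact behavior is configured
# via hyperparameters elsewhere). Self-contained and illustrative only.
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
  step = tf.get_variable('step', [], dtype=tf.int32,
                         initializer=tf.zeros_initializer())
  weights = tf.get_variable('weights', [4], dtype=tf.float32,
                            initializer=tf.ones_initializer())
  cost = tf.reduce_sum(tf.square(weights))
  optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)

  grads_and_vars = optimizer.compute_gradients(cost, var_list=[weights])
  clipped = [(tf.clip_by_norm(g, 1.0), v) for g, v in grads_and_vars]
  gradient_norm = tf.global_norm([g for g, _ in clipped])
  minimize_op = optimizer.apply_gradients(clipped, global_step=step)

with tf.Session(graph=graph) as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run([gradient_norm, minimize_op])[0])  # 1.0 after clipping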
......@@ -474,6 +505,7 @@ class MasterBuilder(object):
# Returns named access to common outputs.
outputs = {
'cost': cost,
'gradient_norm': gradient_norm,
'batch': effective_batch,
'metrics': metrics,
}
......@@ -520,7 +552,10 @@ class MasterBuilder(object):
with tf.control_dependencies(control_ops):
return tf.no_op(name='post_restore_hook_master')
def build_inference(self, handle, use_moving_average=False):
def build_inference(self,
handle,
use_moving_average=False,
build_runtime_graph=False):
"""Builds an inference pipeline.
This always uses the whole pipeline.
......@@ -530,25 +565,30 @@ class MasterBuilder(object):
use_moving_average: Whether or not to read from the moving
average variables instead of the true parameters. Note: it is not
possible to make gradient updates when this is True.
build_runtime_graph: Whether to build a graph for use by the runtime.
Returns:
handle: Handle after annotation.
"""
self.read_from_avg = use_moving_average
self.build_runtime_graph = build_runtime_graph
network_states = {}
for comp in self.components:
network_states[comp.name] = component.NetworkState()
handle = dragnn_ops.init_component_data(
handle, beam_size=comp.inference_beam_size, component=comp.name)
master_state = component.MasterState(handle,
dragnn_ops.batch_size(
handle, component=comp.name))
if build_runtime_graph:
batch_size = 1 # runtime uses singleton batches
else:
batch_size = dragnn_ops.batch_size(handle, component=comp.name)
master_state = component.MasterState(handle, batch_size)
with tf.control_dependencies([handle]):
handle = comp.build_greedy_inference(master_state, network_states)
handle = dragnn_ops.write_annotations(handle, component=comp.name)
self.read_from_avg = False
self.build_runtime_graph = False
return handle
def add_training_from_config(self,
......@@ -625,7 +665,10 @@ class MasterBuilder(object):
return self._outputs_with_release(handle, {'input_batch': input_batch},
outputs)
def add_annotation(self, name_scope='annotation', enable_tracing=False):
def add_annotation(self,
name_scope='annotation',
enable_tracing=False,
build_runtime_graph=False):
"""Adds an annotation pipeline to the graph.
This will create the following additional named targets by default, for use
......@@ -640,13 +683,17 @@ class MasterBuilder(object):
enable_tracing: Enabling this will result in two things:
1. Tracing will be enabled during inference.
2. A 'traces' node will be added to the outputs.
build_runtime_graph: Whether to build a graph for use by the runtime.
Returns:
A dictionary of input and output nodes.
"""
with tf.name_scope(name_scope):
handle, input_batch = self._get_session_with_reader(enable_tracing)
handle = self.build_inference(handle, use_moving_average=True)
handle = self.build_inference(
handle,
use_moving_average=True,
build_runtime_graph=build_runtime_graph)
annotations = dragnn_ops.emit_annotations(
handle, component=self.spec.component[-1].name)
......@@ -666,7 +713,7 @@ class MasterBuilder(object):
def add_saver(self):
"""Adds a Saver for all variables in the graph."""
logging.info('Saving variables:\n\t%s',
logging.info('Generating op to save variables:\n\t%s',
'\n\t'.join([x.name for x in tf.global_variables()]))
self.saver = tf.train.Saver(
var_list=[x for x in tf.global_variables()],
......
......@@ -20,7 +20,6 @@ import os.path
import numpy as np
from six.moves import xrange
import tensorflow as tf
from google.protobuf import text_format
......@@ -30,13 +29,12 @@ from dragnn.protos import trace_pb2
from dragnn.python import dragnn_ops
from dragnn.python import graph_builder
from syntaxnet import sentence_pb2
from syntaxnet import test_flags
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
from tensorflow.python.platform import tf_logging as logging
FLAGS = tf.app.flags.FLAGS
_DUMMY_GOLD_SENTENCE = """
token {
......@@ -151,13 +149,6 @@ token {
]
def setUpModule():
if not hasattr(FLAGS, 'test_srcdir'):
FLAGS.test_srcdir = ''
if not hasattr(FLAGS, 'test_tmpdir'):
FLAGS.test_tmpdir = tf.test.get_temp_dir()
def _as_op(x):
"""Always returns the tf.Operation associated with a node."""
return x.op if isinstance(x, tf.Tensor) else x
......@@ -244,7 +235,7 @@ class GraphBuilderTest(test_util.TensorFlowTestCase):
def LoadSpec(self, spec_path):
master_spec = spec_pb2.MasterSpec()
testdata = os.path.join(FLAGS.test_srcdir,
testdata = os.path.join(test_flags.source_root(),
'dragnn/core/testdata')
with open(os.path.join(testdata, spec_path), 'r') as fin:
text_format.Parse(fin.read().replace('TESTDATA', testdata), master_spec)
......@@ -445,7 +436,7 @@ class GraphBuilderTest(test_util.TensorFlowTestCase):
self.assertEqual(expected_num_actions, correct_val)
self.assertEqual(expected_num_actions, total_val)
builder.saver.save(sess, os.path.join(FLAGS.test_tmpdir, 'model'))
builder.saver.save(sess, os.path.join(test_flags.temp_dir(), 'model'))
logging.info('Running test.')
logging.info('Printing annotations')
......
......@@ -27,8 +27,7 @@ from dragnn.python import lexicon
from syntaxnet import parser_trainer
from syntaxnet import task_spec_pb2
FLAGS = tf.app.flags.FLAGS
from syntaxnet import test_flags
_EXPECTED_CONTEXT = r"""
......@@ -46,13 +45,6 @@ input { name: "known-word-map" Part { file_pattern: "/tmp/known-word-map" } }
"""
def setUpModule():
if not hasattr(FLAGS, 'test_srcdir'):
FLAGS.test_srcdir = ''
if not hasattr(FLAGS, 'test_tmpdir'):
FLAGS.test_tmpdir = tf.test.get_temp_dir()
class LexiconTest(tf.test.TestCase):
def testCreateLexiconContext(self):
......@@ -62,8 +54,8 @@ class LexiconTest(tf.test.TestCase):
lexicon.create_lexicon_context('/tmp'), expected_context)
def testBuildLexicon(self):
empty_input_path = os.path.join(FLAGS.test_tmpdir, 'empty-input')
lexicon_output_path = os.path.join(FLAGS.test_tmpdir, 'lexicon-output')
empty_input_path = os.path.join(test_flags.temp_dir(), 'empty-input')
lexicon_output_path = os.path.join(test_flags.temp_dir(), 'lexicon-output')
with open(empty_input_path, 'w'):
pass
......
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loads mst_ops shared library."""
import os.path
import tensorflow as tf
tf.load_op_library(
os.path.join(tf.resource_loader.get_data_files_path(), 'mst_cc_impl.so'))