Commit a4bb31d0 authored by Terry Koo's avatar Terry Koo
Browse files

Export @195097388.

parent dea7ecf6
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow ops for maximum spanning tree problems."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import dragnn.python.load_mst_cc_impl
from dragnn.mst.ops import gen_mst_ops
from dragnn.python import digraph_ops
from syntaxnet.util import check
# Re-export the generated MST op so callers can use
# mst_ops.maximum_spanning_tree without importing the generated wrapper module.
maximum_spanning_tree = gen_mst_ops.maximum_spanning_tree
@tf.RegisterGradient("MaximumSpanningTree")
def maximum_spanning_tree_gradient(mst_op, d_loss_d_max_scores, *_):
  """Returns a subgradient of the MaximumSpanningTree op.

  Note that MaximumSpanningTree is only differentiable w.r.t. its |scores| input
  and its |max_scores| output.

  Args:
    mst_op: The MaximumSpanningTree op being differentiated.
    d_loss_d_max_scores: [B] vector where entry b is the gradient of the network
                         loss w.r.t. entry b of the |max_scores| output of the
                         |mst_op|.
    *_: The gradients w.r.t. the other outputs; ignored.

  Returns:
    1. None, since the op is not differentiable w.r.t. its |num_nodes| input.
    2. [B,M,M] tensor where entry b,t,s is a subgradient of the network loss
       w.r.t. entry b,t,s of the |scores| input, with the same dtype as
       |d_loss_d_max_scores|.
  """
  dtype = d_loss_d_max_scores.dtype.base_dtype
  check.NotNone(dtype)

  # The second output of the op gives, for each target node, the source node
  # selected in the maximal structure.
  sources_bxm = mst_op.outputs[1]
  num_targets = tf.shape(sources_bxm)[1]  # M in the docstring

  # The one-hot argmax is a subgradient of max.  Convert the batch of maximal
  # spanning trees into 0/1 indicators, then scale them by the relevant output
  # gradients from |d_loss_d_max_scores|.  Reshaping the output gradients to
  # [B,1,1] lets them broadcast across each [M,M] indicator sub-matrix.
  onehot_bxmxm = tf.one_hot(sources_bxm, num_targets, dtype=dtype)
  grad_bx1x1 = tf.reshape(d_loss_d_max_scores, [-1, 1, 1])
  d_loss_d_scores_bxmxm = onehot_bxmxm * grad_bx1x1

  # None for the non-differentiable |num_nodes| input.
  return None, d_loss_d_scores_bxmxm
def log_partition_function(num_nodes,
                           scores,
                           forest=False,
                           max_dynamic_range=None):
  r"""Returns the log of the sum-of-product of spanning trees or forests.

  Computing the sum-of-product in the log domain reduces the chance of overflow
  or underflow, and ML techniques (e.g., CRF loss functions) typically require
  the log partition function anyways.  For similar reasons, the scores input is
  assumed to be specified in the log domain.

  The partition function is calculated via application of the Matrix-Tree
  theorem; see the following for details:
    https://en.wikipedia.org/wiki/Kirchhoff%27s_theorem
    http://www.aclweb.org/anthology/D/D07/D07-1015.pdf

  Computing the gradient of the log partition function requires inverting the
  Laplacian matrix.  Numerical issues may occur if the Laplacian is singular or
  nearly-so.  (Intuitively, the Laplacian will be close to singular when the
  input scores strongly favor invalid structures such as cycles).  In the EMNLP
  paper, we alleviated the numerical issues by clipping the difference between
  the minimum and maximum score for each node to 20 (in the log domain).  The
  |max_dynamic_range| argument can be used for this purpose.

  TODO(googleuser): Try improving the condition number of the Laplacian matrix
  directly, instead of using the indirect approach above.  For example, one
  could add c*I to the Laplacian (i.e., Tikhonov regularization).

  Args:
    num_nodes: [B] vector of graph sizes per batch item.
    scores: [B,M,M] tensor of padded batched arc and root scores, in the format
        used by the maximum_spanning_tree() op.  Padding values must be finite.
    forest: If true, sum over spanning forests instead of trees.
    max_dynamic_range: If specified, incoming scores for each node are clipped
        to at most this far from the maximum such score (in the log domain).

  Returns:
    [B] vector Z of log partition function values, where
      Z[b] = log(
          \sum_{tree spanning batch item b}
              score(root_of(tree)) \prod_{arc in tree} score(arc))
  """
  orig_dtype = scores.dtype.base_dtype
  scores_bxmxm = tf.to_double(scores)  # use doubles to reduce under/overflow
  shape_bxmxm = tf.shape(scores_bxmxm)
  batch_size = shape_bxmxm[0]
  max_nodes = shape_bxmxm[1]
  total_nodes = batch_size * max_nodes

  # To eliminate overflow, we locally normalize the scores.  Specifically, for
  # each node we divide its incoming arc scores and root selection score by the
  # maximum such score.  Since each node in a tree must select exactly one of
  # these scores (i.e., it is either a root or has exactly one incoming arc),
  # the local normalization factors are identical for all trees and can thus be
  # factored out of the sum over trees.
  #
  # More concretely, we find the maximum per node, divide all scores for that
  # node by the maximum, and then find the partition function of the normalized
  # scores.  Then we recover the un-normalized partition function by multiplying
  # the per-node maxima back in.  This final step is performed in the log domain
  # to avoid overflow.
  #
  # Note that underflow is still possible, but unlikely as long as the scores
  # are close to feasible (i.e., there is not too much mass on non-trees).  The
  # |max_dynamic_range| argument can be used to mitigate this.

  # Finding the maximum incoming score is difficult, because the batch padding
  # may contain arbitrary values.  We restrict the maximization to valid arcs
  # using tf.unsorted_segment_max() with a specially-constructed set of IDs.
  _, valid_tokens_bxm = digraph_ops.ValidArcAndTokenMasks(
      num_nodes, max_nodes, dtype=tf.int32)

  # Create a tensor of "target IDs".  In each row of each sub-matrix, the
  # positions of valid source tokens are filled with the 1-origin index of that
  # row in the entire batch, and zero elsewhere.  For example, given a batch
  # with num_nodes=[2, 3] we might have
  #   [[[1, 1, 0],
  #     [2, 2, 0],
  #     [3, 3, 0]],
  #    [[4, 4, 4],
  #     [5, 5, 5],
  #     [6, 6, 6]]]
  #
  # TODO(googleuser): The dynamic masking is pretty awkward.  Find an op that
  # does this (I looked, but maybe not hard enough), or write a custom op.
  valid_tokens_bx1xm = tf.expand_dims(valid_tokens_bxm, 1)
  valid_sources_bxmxm = tf.tile(valid_tokens_bx1xm, [1, max_nodes, 1])
  sequence_bm = 1 + tf.range(total_nodes, dtype=tf.int32)
  sequence_bxmx1 = tf.reshape(sequence_bm, [batch_size, max_nodes, 1])
  target_ids_bxmxm = valid_sources_bxmxm * sequence_bxmx1

  # [B*M] per-node maxima over valid incoming scores; segment 0 collects only
  # padding entries and is dropped below.
  max_scores_bm1 = tf.unsorted_segment_max(scores_bxmxm, target_ids_bxmxm,
                                           total_nodes + 1)
  max_scores_bm = max_scores_bm1[1:]  # ID 0 corresponds to padding

  # Similar to above, we need to sum over the valid tokens.  We analogously use
  # tf.unsorted_segment_sum() with a specially-constructed set of "batch IDs".
  sequence_b = 1 + tf.range(batch_size, dtype=tf.int32)
  sequence_bx1 = tf.expand_dims(sequence_b, 1)
  batch_ids_bxm = valid_tokens_bxm * sequence_bx1
  batch_ids_bm = tf.reshape(batch_ids_bxm, [-1])
  log_normalization_factor_b1 = tf.unsorted_segment_sum(
      max_scores_bm, batch_ids_bm, batch_size + 1)
  log_normalization_factor_b = log_normalization_factor_b1[1:]

  # Locally-normalize and optionally clip the scores.  After subtraction, each
  # node's best incoming score is exactly zero.
  max_scores_bxmx1 = tf.reshape(max_scores_bm, [batch_size, max_nodes, 1])
  scores_bxmxm -= max_scores_bxmx1
  if max_dynamic_range is not None:
    # After normalization, the scores are non-positive with max=0, so the
    # |max_dynamic_range| can be applied directly.
    #
    # PyLint thinks "-max_dynamic_range" is invalid because it defaults to None.
    scores_bxmxm = tf.maximum(scores_bxmxm, -max_dynamic_range)
  scores_bxmxm = tf.exp(scores_bxmxm)

  # Apply the Matrix-Tree theorem: the partition function is the determinant of
  # the Laplacian of the (exponentiated, normalized) score matrix.
  exp_normalized_laplacian_bxmxm = digraph_ops.LaplacianMatrix(
      num_nodes, scores_bxmxm, forest=forest)
  log_normalized_partition_function_b = tf.log(
      tf.matrix_determinant(exp_normalized_laplacian_bxmxm))

  # Reapply the normalization factor that was divided out.
  log_partition_function_b = (
      log_normalized_partition_function_b + log_normalization_factor_b)
  return tf.cast(log_partition_function_b, orig_dtype)
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for maximum spanning tree ops."""
import math
import numpy as np
import tensorflow as tf
from dragnn.python import mst_ops
class MstOpsTest(tf.test.TestCase):
  """Testing rig for the maximum spanning tree ops."""

  def testMaximumSpanningTree(self):
    """Tests that the MST op can recover a simple tree."""
    with self.test_session() as session:
      # The first batch element prefers 3 as root, then 3->0->1->2, for a total
      # score of 4+2+1=7.  The second batch element is smaller and has reversed
      # scores, so 0 is root and 0->2->1.
      num_nodes = tf.constant([4, 3], tf.int32)
      scores = tf.constant([[[0, 0, 0, 0],
                             [1, 0, 0, 0],
                             [1, 2, 0, 0],
                             [1, 2, 3, 4]],
                            [[4, 3, 2, 9],
                             [0, 0, 2, 9],
                             [0, 0, 0, 9],
                             [9, 9, 9, 9]]], tf.int32)  # pyformat: disable

      mst_outputs = mst_ops.maximum_spanning_tree(
          num_nodes, scores, forest=False)
      max_scores, argmax_sources = session.run(mst_outputs)
      tf.logging.info('\nmax_scores=%s\nargmax_sources=\n%s', max_scores,
                      argmax_sources)

      self.assertAllEqual(max_scores, [7, 6])
      self.assertAllEqual(argmax_sources, [[3, 0, 1, 3],
                                           [0, 2, 0, -1]])  # pyformat: disable

  def testMaximumSpanningTreeGradient(self):
    """Tests the MST max score gradient."""
    with self.test_session() as session:
      num_nodes = tf.constant([4, 3], tf.int32)
      scores = tf.constant([[[0, 0, 0, 0],
                             [1, 0, 0, 0],
                             [1, 2, 0, 0],
                             [1, 2, 3, 4]],
                            [[4, 3, 2, 9],
                             [0, 0, 2, 9],
                             [0, 0, 0, 9],
                             [9, 9, 9, 9]]], tf.int32)  # pyformat: disable

      mst_ops.maximum_spanning_tree(num_nodes, scores, forest=False, name='MST')
      mst_op = session.graph.get_operation_by_name('MST')
      d_loss_d_max_scores = tf.constant([3, 7], tf.float32)
      d_loss_d_num_nodes, d_loss_d_scores = (
          mst_ops.maximum_spanning_tree_gradient(mst_op, d_loss_d_max_scores))

      # The num_nodes input is non-differentiable.
      self.assertTrue(d_loss_d_num_nodes is None)
      tf.logging.info('\nd_loss_d_scores=\n%s', d_loss_d_scores.eval())

      # The gradient is the one-hot argmax tree, scaled per batch item by the
      # incoming output gradients (3 and 7, respectively).
      self.assertAllEqual(d_loss_d_scores.eval(),
                          [[[0, 0, 0, 3],
                            [3, 0, 0, 0],
                            [0, 3, 0, 0],
                            [0, 0, 0, 3]],
                           [[7, 0, 0, 0],
                            [0, 0, 7, 0],
                            [7, 0, 0, 0],
                            [0, 0, 0, 0]]])  # pyformat: disable

  def testMaximumSpanningTreeGradientError(self):
    """Numerically validates the max score gradient."""
    with self.test_session():
      # The maximum-spanning-tree-score function, as a max of linear functions,
      # is piecewise-linear (i.e., faceted).  The numerical gradient estimate
      # may be inaccurate if the epsilon ball used for the estimate crosses an
      # edge from one facet to another.  To avoid spurious errors, we manually
      # set the sample point so the epsilon ball fits in a facet.  Or in other
      # words, we set the scores so there is a non-trivial margin between the
      # best and second-best trees.
      scores_raw = [[[0, 0, 0, 0],
                     [1, 0, 0, 0],
                     [1, 2, 0, 0],
                     [1, 2, 3, 4]],
                    [[4, 3, 2, 9],
                     [0, 0, 2, 9],
                     [0, 0, 0, 9],
                     [9, 9, 9, 9]]]  # pyformat: disable

      # Use 64-bit floats to reduce numerical error.
      scores = tf.constant(scores_raw, tf.float64)
      init_scores = np.array(scores_raw)

      num_nodes = tf.constant([4, 3], tf.int32)
      max_scores = mst_ops.maximum_spanning_tree(
          num_nodes, scores, forest=False)[0]

      gradient_error = tf.test.compute_gradient_error(
          scores, [2, 4, 4], max_scores, [2], init_scores)
      tf.logging.info('gradient_error=%s', gradient_error)

      # Fix: the gradient error was previously computed and logged but never
      # asserted, so this test could not fail.  Bound it with the same
      # tolerance used for the log partition function gradient test.
      self.assertLessEqual(gradient_error, 1e-7)

  def testLogPartitionFunctionOneTree(self):
    """Tests the log partition function with one feasible tree with score 1."""
    with self.test_session():
      for forest in [False, True]:
        # Each score matrix supports exactly one tree with score=1*1*1, and
        # the rest with score=0.  Thus the log partition function will be 1.0
        # in each case.
        pad = 12345.6
        scores = tf.constant([[[  1, pad, pad],
                               [pad, pad, pad],
                               [pad, pad, pad]],
                              [[  1,   0, pad],
                               [  1,   0, pad],
                               [pad, pad, pad]],
                              [[  1,   0,   0],
                               [  1,   0,   0],
                               [  0,   1,   0]]],
                             tf.float64)  # pyformat: disable
        scores = tf.log(scores)

        num_nodes = tf.constant([1, 2, 3], tf.int32)
        log_partition_functions = mst_ops.log_partition_function(
            num_nodes, scores, forest=forest)

        self.assertAlmostEqual(tf.exp(log_partition_functions[0]).eval(), 1.0)
        self.assertAlmostEqual(tf.exp(log_partition_functions[1]).eval(), 1.0)
        self.assertAlmostEqual(tf.exp(log_partition_functions[2]).eval(), 1.0)

  def testLogPartitionFunctionOneTreeScaled(self):
    """Tests the log partition function with one feasible tree."""
    with self.test_session():
      for forest in [False, True]:
        # Each score matrix supports exactly one tree with varying score, and
        # the rest with score=0.  Thus the log partition function will equal
        # the score of that single tree in each case.
        pad = 12345.6
        scores = tf.constant([[[  2, pad, pad],
                               [pad, pad, pad],
                               [pad, pad, pad]],
                              [[  3,   0, pad],
                               [  5,   0, pad],
                               [pad, pad, pad]],
                              [[  7,   0,   0],
                               [ 11,   0,   0],
                               [  0,  13,   0]]],
                             tf.float64)  # pyformat: disable
        scores = tf.log(scores)

        num_nodes = tf.constant([1, 2, 3], tf.int32)
        log_partition_functions = mst_ops.log_partition_function(
            num_nodes, scores, forest=forest)

        self.assertAlmostEqual(tf.exp(log_partition_functions[0]).eval(), 2.0)
        self.assertAlmostEqual(
            tf.exp(log_partition_functions[1]).eval(), 3.0 * 5.0)
        self.assertAlmostEqual(
            tf.exp(log_partition_functions[2]).eval(), 7.0 * 11.0 * 13.0)

  def testLogPartitionFunctionTwoTreesScaled(self):
    """Tests the log partition function with two feasible trees."""
    with self.test_session():
      for forest in [False, True]:
        # Each score matrix supports exactly two trees with varying score, and
        # the rest with score=0.  Thus the log partition function will equal
        # the sum of scores of those two trees in each case.
        pad = 12345.6
        scores = tf.constant([[[  2,   0,   0, pad],
                               [  3,   0,   0, pad],
                               [  5,   7,   0, pad],
                               [pad, pad, pad, pad]],
                              [[  0,  11,   0,  13],
                               [  0,  17,   0,   0],
                               [  0,  19,   0,   0],
                               [  0,  23,   0,   0]]],
                             tf.float64)  # pyformat: disable
        scores = tf.log(scores)

        num_nodes = tf.constant([3, 4], tf.int32)
        log_partition_functions = mst_ops.log_partition_function(
            num_nodes, scores, forest=forest)

        self.assertAlmostEqual(
            tf.exp(log_partition_functions[0]).eval(),
            2.0 * 3.0 * 5.0 + 2.0 * 3.0 * 7.0)
        self.assertAlmostEqual(
            tf.exp(log_partition_functions[1]).eval(),
            11.0 * 17.0 * 19.0 * 23.0 + 13.0 * 17.0 * 19.0 * 23.0)

  def testLogPartitionFunctionInfeasible(self):
    """Tests the log partition function on infeasible scores."""
    with self.test_session():
      for forest in [False, True]:
        # The scores form cycles of various sizes.  Note that one can compute
        # the partition function for infeasible scores---it's the gradient that
        # may be impacted by numerical error.
        pad = 12345.6
        scores = tf.constant([[[  0,   1, pad, pad],
                               [  1,   0, pad, pad],
                               [pad, pad, pad, pad],
                               [pad, pad, pad, pad]],
                              [[  0,   1,   0, pad],
                               [  0,   0,   1, pad],
                               [  1,   0,   0, pad],
                               [pad, pad, pad, pad]],
                              [[  0,   1,   0,   0],
                               [  0,   0,   1,   0],
                               [  0,   0,   0,   1],
                               [  1,   0,   0,   0]]],
                             tf.float64)  # pyformat: disable
        scores = tf.log(scores)

        num_nodes = tf.constant([2, 3, 4], tf.int32)
        log_partition_functions = mst_ops.log_partition_function(
            num_nodes, scores, forest=forest)

        self.assertAlmostEqual(tf.exp(log_partition_functions[0]).eval(), 0.0)
        self.assertAlmostEqual(tf.exp(log_partition_functions[1]).eval(), 0.0)
        self.assertAlmostEqual(tf.exp(log_partition_functions[2]).eval(), 0.0)

  def testLogPartitionFunctionAllTrees(self):
    """Tests the log partition function with all trees feasible."""
    with self.test_session():
      for forest in [False, True]:
        # The scores allow all trees.  Using Cayley's formula, the
        # number of directed spanning trees and forests in a complete
        # digraph of n nodes is n^{n-1} and (n+1)^{n-1}, respectively.
        # https://en.wikipedia.org/wiki/Cayley%27s_formula
        scores = tf.zeros([10, 10, 10], tf.float64)  # = 1 in log domain
        num_nodes = tf.range(1, 11, dtype=tf.int32)
        log_partition_functions = mst_ops.log_partition_function(
            num_nodes, scores, forest=forest)

        base_offset = 1 if forest else 0  # n+1 for forest, n for tree
        for size in range(1, 11):
          self.assertAlmostEqual(log_partition_functions[size - 1].eval(),
                                 (size - 1) * math.log(size + base_offset))

  def testLogPartitionFunctionWithVeryHighValues(self):
    """Tests the overflow protection in the log partition function."""
    with self.test_session():
      for forest in [False, True]:
        # Set the scores to very high values to test overflow protection.
        scores = 1000 * tf.ones([10, 10, 10], tf.float64)
        num_nodes = tf.range(1, 11, dtype=tf.int32)
        log_partition_functions = mst_ops.log_partition_function(
            num_nodes, scores, forest=forest)

        # As in testLogPartitionFunctionAllTrees, but each tree's log score is
        # boosted by 1000 per node.
        base_offset = 1 if forest else 0  # n+1 for forest, n for tree
        for size in range(1, 11):
          self.assertAlmostEqual(
              log_partition_functions[size - 1].eval(),
              (size - 1) * math.log(size + base_offset) + size * 1000)

  def testLogPartitionFunctionWithVeryLowValues(self):
    """Tests the underflow protection in the log partition function."""
    with self.test_session():
      for forest in [False, True]:
        # Set the scores to very low values to test underflow protection.
        scores = -1000 * tf.ones([10, 10, 10], tf.float64)
        num_nodes = tf.range(1, 11, dtype=tf.int32)
        log_partition_functions = mst_ops.log_partition_function(
            num_nodes, scores, forest=forest)

        base_offset = 1 if forest else 0  # n+1 for forest, n for tree
        for size in range(1, 11):
          self.assertAlmostEqual(
              log_partition_functions[size - 1].eval(),
              (size - 1) * math.log(size + base_offset) - size * 1000)

  def testLogPartitionFunctionGradientError(self):
    """Validates the log partition function gradient."""
    with self.test_session():
      for forest in [False, True]:
        # To avoid numerical issues, provide score matrices that are weighted
        # towards feasible trees or forests.
        scores_raw = [[[0, 0, 0, 0],
                       [1, 0, 0, 0],
                       [1, 2, 0, 0],
                       [1, 2, 3, 4]],
                      [[4, 3, 2, 9],
                       [0, 0, 2, 9],
                       [0, 0, 0, 9],
                       [9, 9, 9, 9]]]  # pyformat: disable

        scores = tf.constant(scores_raw, tf.float64)
        init_scores = np.array(scores_raw)

        num_nodes = tf.constant([4, 3], tf.int32)
        log_partition_functions = mst_ops.log_partition_function(
            num_nodes, scores, forest=forest)

        gradient_error = tf.test.compute_gradient_error(
            scores, [2, 4, 4], log_partition_functions, [2], init_scores)
        tf.logging.info('forest=%s gradient_error=%s', forest, gradient_error)
        self.assertLessEqual(gradient_error, 1e-7)

  def testLogPartitionFunctionGradientErrorFailsIfInfeasible(self):
    """Tests that the partition function gradient fails on infeasible scores."""
    with self.test_session():
      for forest in [False, True]:
        # The scores form cycles of various sizes.
        pad = 12345.6
        scores_raw = [[[  0,   1, pad, pad],
                       [  1,   0, pad, pad],
                       [pad, pad, pad, pad],
                       [pad, pad, pad, pad]],
                      [[  0,   1,   0, pad],
                       [  0,   0,   1, pad],
                       [  1,   0,   0, pad],
                       [pad, pad, pad, pad]],
                      [[  0,   1,   0,   0],
                       [  0,   0,   1,   0],
                       [  0,   0,   0,   1],
                       [  1,   0,   0,   0]]]  # pyformat: disable

        scores = tf.log(scores_raw)
        init_scores = np.log(np.array(scores_raw))

        num_nodes = tf.constant([2, 3, 4], tf.int32)
        log_partition_functions = mst_ops.log_partition_function(
            num_nodes, scores, forest=forest)

        with self.assertRaises(Exception):
          tf.test.compute_gradient_error(
              scores, [3, 4, 4], log_partition_functions, [3], init_scores)

  def testLogPartitionFunctionGradientErrorOkIfInfeasibleWithClipping(self):
    """Tests that the log partition function gradient is OK after clipping."""
    with self.test_session():
      for forest in [False, True]:
        # The scores form cycles of various sizes.
        pad = 12345.6
        scores_raw = [[[  0,   1, pad, pad],
                       [  1,   0, pad, pad],
                       [pad, pad, pad, pad],
                       [pad, pad, pad, pad]],
                      [[  0,   1,   0, pad],
                       [  0,   0,   1, pad],
                       [  1,   0,   0, pad],
                       [pad, pad, pad, pad]],
                      [[  0,   1,   0,   0],
                       [  0,   0,   1,   0],
                       [  0,   0,   0,   1],
                       [  1,   0,   0,   0]]]  # pyformat: disable

        scores = tf.log(scores_raw)
        init_scores = np.log(np.array(scores_raw))

        num_nodes = tf.constant([2, 3, 4], tf.int32)
        log_partition_functions = mst_ops.log_partition_function(
            num_nodes, scores, forest=forest, max_dynamic_range=10)

        gradient_error = tf.test.compute_gradient_error(
            scores, [3, 4, 4], log_partition_functions, [3], init_scores)
        tf.logging.info('forest=%s gradient_error=%s', forest, gradient_error)

        # There's still a lot of error.
        self.assertLessEqual(gradient_error, 1e-3)
# Command-line entry point: run all tests in this module.
if __name__ == '__main__':
  tf.test.main()
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""DRAGNN wrappers for the MST solver."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from dragnn.python import mst_ops
from dragnn.python import network_units
from syntaxnet.util import check
class MstSolverNetwork(network_units.NetworkUnitInterface):
  """Network unit that performs MST prediction with structured loss.

  Parameters:
    forest: If true, solve for a spanning forest instead of a spanning tree.
    loss: The loss function for training.  Select from
      softmax: Default unstructured softmax (prediction is still structured).
      m3n: Max-Margin Markov Networks loss.
      crf: Negative CRF log-probability loss.
    crf_max_dynamic_range: Max dynamic range for the log partition function.

  Links:
    lengths: [B, 1] sequence lengths per batch item.
    scores: [B * N, N] matrix of padded batched arc scores.

  Layers:
    lengths: [B] sequence lengths per batch item.
    scores: [B, N, N] tensor of padded batched arc scores.
    logits: [B * N, N] matrix of padded batched arc scores.
    arcs: [B * N, N] matrix of padded batched 0/1 indicators for MST arcs.
  """

  def __init__(self, component):
    """Initializes layers.

    Args:
      component: Parent ComponentBuilderBase object.
    """
    # All layer dims are -1, since they depend on the runtime score dimension.
    layers = [
        network_units.Layer(self, 'lengths', -1),
        network_units.Layer(self, 'scores', -1),
        network_units.Layer(self, 'logits', -1),
        network_units.Layer(self, 'arcs', -1),
    ]
    super(MstSolverNetwork, self).__init__(component, init_layers=layers)

    self._attrs = network_units.get_attrs_with_defaults(
        component.spec.network_unit.parameters,
        defaults={
            'forest': False,
            'loss': 'softmax',
            'crf_max_dynamic_range': 20,
        })

    # This unit consumes exactly the 'lengths' and 'scores' linked features
    # and no fixed features.
    check.Eq(
        len(self._fixed_feature_dims.items()), 0, 'Expected no fixed features')
    check.Eq(
        len(self._linked_feature_dims.items()), 2,
        'Expected two linked features')
    check.In('lengths', self._linked_feature_dims,
             'Missing required linked feature')
    check.In('scores', self._linked_feature_dims,
             'Missing required linked feature')

  def create(self,
             fixed_embeddings,
             linked_embeddings,
             context_tensor_arrays,
             attention_tensor,
             during_training,
             stride=None):
    """Forwards the lengths and scores.

    Returns the 'lengths', 'scores', 'logits', and 'arcs' layer tensors, in
    that order.
    """
    check.NotNone(stride, 'MstSolverNetwork requires stride')

    # [B] lengths, converted from the [B, 1] linked feature.
    lengths = network_units.lookup_named_tensor('lengths', linked_embeddings)
    lengths_b = tf.to_int32(tf.squeeze(lengths.tensor, [1]))

    # Reshape the [B * N, N] linked scores into a [B, N, N] batch of matrices.
    scores = network_units.lookup_named_tensor('scores', linked_embeddings)
    scores_bnxn = scores.tensor
    max_length = tf.shape(scores_bnxn)[1]
    scores_bxnxn = tf.reshape(scores_bnxn, [stride, max_length, max_length])

    # Solve for the maximal tree or forest, then convert the per-node argmax
    # sources into a [B * N, N] matrix of 0/1 arc indicators.
    _, argmax_sources_bxn = mst_ops.maximum_spanning_tree(
        forest=self._attrs['forest'], num_nodes=lengths_b, scores=scores_bxnxn)
    argmax_sources_bn = tf.reshape(argmax_sources_bxn, [-1])
    arcs_bnxn = tf.one_hot(argmax_sources_bn, max_length, dtype=tf.float32)

    # Note that the raw [B * N, N] scores double as the 'logits' layer.
    return [lengths_b, scores_bxnxn, scores_bnxn, arcs_bnxn]

  def get_logits(self, network_tensors):
    """See base class."""
    return network_tensors[self.get_layer_index('logits')]

  def get_bulk_predictions(self, stride, network_tensors):
    """See base class."""
    return network_tensors[self.get_layer_index('arcs')]

  def compute_bulk_loss(self, stride, network_tensors, gold):
    """See base class."""
    if self._attrs['loss'] == 'softmax':
      return (None, None, None)  # fall back to default bulk softmax

    lengths_b, scores_bxnxn, _, arcs_bnxn = network_tensors
    max_length = tf.shape(scores_bxnxn)[2]

    # Convert predicted arcs and gold heads into [B, N, N] 0/1 indicators.
    # Out-of-range gold entries (e.g., -1 padding) yield all-zero rows in the
    # one-hot encoding.
    arcs_bxnxn = tf.reshape(arcs_bnxn, [stride, max_length, max_length])
    gold_bxn = tf.reshape(gold, [stride, max_length])
    gold_bxnxn = tf.one_hot(gold_bxn, max_length, dtype=tf.float32)

    loss = self._compute_loss(lengths_b, scores_bxnxn, gold_bxnxn)
    # An arc counts as correct iff it is both predicted and gold.
    correct = tf.reduce_sum(tf.to_int32(arcs_bxnxn * gold_bxnxn))
    total = tf.reduce_sum(lengths_b)
    return loss, correct, total

  def _compute_loss(self, lengths, scores, gold):
    """Computes the configured structured loss for a batch.

    Args:
      lengths: [B] sequence lengths per batch item.
      scores: [B, N, N] tensor of padded batched arc scores.
      gold: [B, N, N] tensor of 0/1 indicators for gold arcs.

    Returns:
      Scalar sum of losses across the batch.
    """
    # Dispatch to one of the _compute_*_loss() methods.
    method_name = '_compute_%s_loss' % self._attrs['loss']
    loss_b = getattr(self, method_name)(lengths, scores, gold)
    return tf.reduce_sum(loss_b)

  def _compute_m3n_loss(self, lengths, scores, gold):
    """Computes the M3N-style structured hinge loss for a batch."""
    # Perform hamming-loss-augmented inference: every non-gold arc gets +1
    # added to its score before solving for the max-scoring structure.
    gold_scores_b = tf.reduce_sum(scores * gold, axis=[1, 2])
    hamming_loss_bxnxn = 1 - gold
    scores_bxnxn = scores + hamming_loss_bxnxn
    max_scores_b, _ = mst_ops.maximum_spanning_tree(
        num_nodes=lengths, scores=scores_bxnxn, forest=self._attrs['forest'])
    # Structured hinge: augmented max score minus gold score.
    return max_scores_b - gold_scores_b

  def _compute_crf_loss(self, lengths, scores, gold):
    """Computes the negative CRF log-probability for a batch."""
    # The |scores| are assumed to be in the log domain.
    log_gold_scores_b = tf.reduce_sum(scores * gold, axis=[1, 2])
    log_partition_functions_b = mst_ops.log_partition_function(
        num_nodes=lengths,
        scores=scores,
        forest=self._attrs['forest'],
        max_dynamic_range=self._attrs['crf_max_dynamic_range'])
    return log_partition_functions_b - log_gold_scores_b  # negative log-prob
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for DRAGNN wrappers for the MST solver."""
import math
import tensorflow as tf
from google.protobuf import text_format
from dragnn.protos import spec_pb2
from dragnn.python import mst_units
from dragnn.python import network_units
_MASTER_SPEC = r"""
component {
name: 'test'
linked_feature {
name: 'lengths'
size: 1
embedding_dim: -1
fml: 'input.focus'
source_translator: 'identity'
source_component: 'previous'
source_layer: 'lengths'
}
linked_feature {
name: 'scores'
size: 1
embedding_dim: -1
fml: 'input.focus'
source_translator: 'identity'
source_component: 'previous'
source_layer: 'scores'
}
}
"""
class MockNetwork(object):
  """Stub network whose layers all report an unknown (-1) size."""

  def get_layer_size(self, unused_name):
    """Returns -1 regardless of the requested layer name."""
    return -1
class MockComponent(object):
  """Minimal stand-in for a DRAGNN component, backed by a MockNetwork."""

  def __init__(self, master, component_spec):
    """Captures the master and spec, and fills in fixed mock attributes."""
    self.master = master
    self.spec = component_spec
    self.name = component_spec.name
    self.num_actions = -1
    self.beam_size = 1
    self.network = MockNetwork()
class MockMaster(object):
  """Minimal stand-in for a DRAGNN master, parsed from _MASTER_SPEC."""

  def __init__(self, build_runtime_graph=False):
    self.spec = spec_pb2.MasterSpec()
    text_format.Parse(_MASTER_SPEC, self.spec)
    self.hyperparams = spec_pb2.GridPoint()
    self.build_runtime_graph = build_runtime_graph
    # The 'previous' component supplies the linked 'lengths'/'scores' layers.
    self.lookup_component = {
        'previous': MockComponent(self, spec_pb2.ComponentSpec())
    }
class MstSolverNetworkTest(tf.test.TestCase):
  """Tests for the MstSolverNetwork unit."""

  def setUp(self):
    # Clear the graph and all existing variables.  Otherwise, variables created
    # in different tests may collide with each other.
    tf.reset_default_graph()

  def testCreate(self):
    """Tests the four layer tensors produced by create()."""
    with self.test_session():
      master = MockMaster()
      component = MockComponent(master, master.spec.component[0])
      component.network = mst_units.MstSolverNetwork(component)

      stride = 1
      lengths = tf.constant([[3]], dtype=tf.int64)
      scores = tf.constant([[1.0, 0.5, 0.5],
                            [2.0, 0.5, 0.5],
                            [0.5, 3.0, 0.5]],
                           dtype=tf.float32)  # pyformat: disable
      linked_embeddings = [
          network_units.NamedTensor(lengths, 'lengths'),
          network_units.NamedTensor(scores, 'scores')
      ]
      network_tensors = component.network.create([], linked_embeddings, [],
                                                 None, False, stride)

      # Layer 0 ('lengths'): [B] lengths.
      self.assertAllEqual(network_tensors[0].eval(), [3])
      # Layer 1 ('scores'): [B, N, N] scores.
      self.assertAllEqual(network_tensors[1].eval(),
                          [[[1.0, 0.5, 0.5],
                            [2.0, 0.5, 0.5],
                            [0.5, 3.0, 0.5]]])  # pyformat: disable
      # Layer 2 ('logits'): the raw [B * N, N] scores.
      self.assertAllEqual(network_tensors[2].eval(),
                          [[1.0, 0.5, 0.5],
                           [2.0, 0.5, 0.5],
                           [0.5, 3.0, 0.5]])  # pyformat: disable
      # Layer 3 ('arcs'): 0/1 indicators for the maximal tree.
      self.assertAllEqual(network_tensors[3].eval(),
                          [[1.0, 0.0, 0.0],
                           [1.0, 0.0, 0.0],
                           [0.0, 1.0, 0.0]])  # pyformat: disable

  def testGetBulkPredictions(self):
    """Tests that bulk predictions return the 'arcs' indicators."""
    with self.test_session():
      master = MockMaster()
      component = MockComponent(master, master.spec.component[0])
      component.network = mst_units.MstSolverNetwork(component)

      stride = 2
      lengths = tf.constant([[2], [3]], dtype=tf.int64)
      pad = -12345.6
      scores = tf.constant([[1.0, 2.0, pad],
                            [1.8, 2.0, pad],
                            [pad, pad, pad],
                            [3.8, 4.0, 3.9],
                            [3.9, 3.8, 4.0],
                            [3.8, 0.9, 4.0]],
                           dtype=tf.float32)  # pyformat: disable
      linked_embeddings = [
          network_units.NamedTensor(lengths, 'lengths'),
          network_units.NamedTensor(scores, 'scores')
      ]
      network_tensors = component.network.create([], linked_embeddings, [],
                                                 None, False, stride)

      predictions = component.network.get_bulk_predictions(
          stride, network_tensors)
      # Note that the padding row (third row of the first item) is all zeros.
      self.assertAllEqual(predictions.eval(),
                          [[0.0, 1.0, 0.0],
                           [0.0, 1.0, 0.0],
                           [0.0, 0.0, 0.0],
                           [0.0, 1.0, 0.0],
                           [0.0, 0.0, 1.0],
                           [0.0, 0.0, 1.0]])  # pyformat: disable

  def testComputeBulkLossM3n(self):
    """Tests the M3N loss, correct count, and total count."""
    with self.test_session():
      master = MockMaster()
      component = MockComponent(master, master.spec.component[0])
      component.spec.network_unit.parameters['loss'] = 'm3n'
      component.network = mst_units.MstSolverNetwork(component)

      stride = 2
      lengths = tf.constant([[2], [3]], dtype=tf.int64)
      # Note that these scores are large enough to overcome the +1 hamming loss
      # terms in the M3N loss.  Therefore, the score matrix determines the tree
      # that is used to compute the M3N loss.
      pad = -12345.6
      scores = tf.constant([[0.5, 2.0, pad],
                            [0.5, 2.0, pad],
                            [pad, pad, pad],
                            [2.5, 4.0, 2.5],
                            [2.5, 2.5, 4.0],
                            [2.5, 2.5, 4.0]],
                           dtype=tf.float32)  # pyformat: disable

      # For the first tree, the gold and scores agree on one arc (that index 1
      # is a root), and for the second tree, the gold and scores agree on none
      # of the arcs.  Therefore, we expect +1 and +3 for the first and second
      # trees in the M3N loss.
      gold = tf.constant([0, 1, -1, 0, 0, 1], tf.int32)
      first_gold_score = 0.5 + 2.0
      second_gold_score = 2.5 + 2.5 + 2.5
      first_tree_correct = 1
      second_tree_correct = 0
      # Loss = (augmented max score) - (gold score), where the augmented max
      # score is the argmax tree's score plus its hamming distance from gold.
      first_tree_loss = 2 * 2.0 + 2 - first_tree_correct - first_gold_score
      second_tree_loss = 3 * 4.0 + 3 - second_tree_correct - second_gold_score

      linked_embeddings = [
          network_units.NamedTensor(lengths, 'lengths'),
          network_units.NamedTensor(scores, 'scores')
      ]
      network_tensors = component.network.create([], linked_embeddings, [],
                                                 None, False, stride)

      cost, correct, total = component.network.compute_bulk_loss(
          stride, network_tensors, gold)
      self.assertEqual(cost.eval(), first_tree_loss + second_tree_loss)
      self.assertEqual(correct.eval(), first_tree_correct + second_tree_correct)
      self.assertEqual(total.eval(), 2 + 3)

  def testComputeBulkLossCrf(self):
    """Tests the CRF loss, correct count, and total count."""
    with self.test_session():
      master = MockMaster()
      component = MockComponent(master, master.spec.component[0])
      component.spec.network_unit.parameters['loss'] = 'crf'
      component.network = mst_units.MstSolverNetwork(component)

      stride = 2
      lengths = tf.constant([[2], [3]], dtype=tf.int64)
      # These scores have 2.0 (in the log domain) on the gold arcs and 1.0
      # elsewhere.
      pad = -12345.6
      one = math.log(1.0)
      two = math.log(2.0)
      scores = tf.constant([[one, two, pad],
                            [one, two, pad],
                            [pad, pad, pad],
                            [one, two, one],
                            [one, one, two],
                            [one, one, two]],
                           dtype=tf.float32)  # pyformat: disable
      gold = tf.constant([1, 1, -1, 1, 2, 2], tf.int32)

      # Expected partition functions, enumerated by hand over all trees.
      first_partition_function = (
          2.0 * 2.0 +  # 0 -> 1 (gold)
          1.0 * 1.0)  # 1 -> 0
      first_loss = -math.log(2.0 * 2.0 / first_partition_function)
      second_partition_function = (
          2.0 * 2.0 * 2.0 +  # 0 -> 1 -> 2 (gold)
          1.0 * 1.0 * 1.0 +  # 2 -> 1 -> 0
          1.0 * 1.0 * 1.0 +  # 0 -> 2 -> 1
          2.0 * 1.0 * 1.0 +  # 1 -> 2 -> 0
          2.0 * 1.0 * 1.0 +  # 1 -> 0 -> 2
          2.0 * 1.0 * 1.0 +  # 2 -> 0 -> 1
          2.0 * 2.0 * 1.0 +  # {0, 1} -> 2
          2.0 * 1.0 * 1.0 +  # {0, 2} -> 1
          1.0 * 1.0 * 1.0)  # {1, 2} -> 0
      second_loss = -math.log(2.0 * 2.0 * 2.0 / second_partition_function)

      linked_embeddings = [
          network_units.NamedTensor(lengths, 'lengths'),
          network_units.NamedTensor(scores, 'scores')
      ]
      network_tensors = component.network.create([], linked_embeddings, [],
                                                 None, False, stride)

      cost, correct, total = component.network.compute_bulk_loss(
          stride, network_tensors, gold)
      self.assertAlmostEqual(cost.eval(), first_loss + second_loss)
      self.assertEqual(correct.eval(), 2 + 3)
      self.assertEqual(total.eval(), 2 + 3)
if __name__ == '__main__':
  # Discovers and runs the test cases above under the TensorFlow test runner.
  tf.test.main()
......@@ -22,7 +22,6 @@ import abc
import numpy as np
from six.moves import xrange
import tensorflow as tf
from tensorflow.python.ops import nn
from tensorflow.python.ops import tensor_array_ops as ta
......@@ -76,11 +75,13 @@ class StoredActivations(object):
check.NotNone(dim, 'Dim is required for bulk tensor')
self._bulk_tensor = tensor
with tf.name_scope('convert_to_dyn'):
tensor = tf.reshape(tensor, [stride, -1, dim])
tensor = tf.transpose(tensor, perm=[1, 0, 2])
pad = tf.zeros([1, stride, dim], dtype=tensor.dtype)
self._array_tensor = tf.concat([pad, tensor], 0)
if dim >= 0:
# These operations will fail if |dim| is negative.
with tf.name_scope('convert_to_dyn'):
tensor = tf.reshape(tensor, [stride, -1, dim])
tensor = tf.transpose(tensor, perm=[1, 0, 2])
pad = tf.zeros([1, stride, dim], dtype=tensor.dtype)
self._array_tensor = tf.concat([pad, tensor], 0)
if array is not None:
check.IsNone(tensor, 'Cannot initialize from both tensor and array')
......@@ -130,7 +131,8 @@ def add_embeddings(channel_id, feature_spec, seed=None):
check.Gt(feature_spec.embedding_dim, 0,
'Embeddings requested for non-embedded feature: %s' % feature_spec)
name = fixed_embeddings_name(channel_id)
shape = [feature_spec.vocabulary_size + 1, feature_spec.embedding_dim]
row_num = feature_spec.vocabulary_size + 1
shape = [row_num, feature_spec.embedding_dim]
if feature_spec.HasField('pretrained_embedding_matrix'):
if len(feature_spec.pretrained_embedding_matrix.part) > 1:
raise RuntimeError('pretrained_embedding_matrix resource contains '
......@@ -143,9 +145,9 @@ def add_embeddings(channel_id, feature_spec, seed=None):
embeddings = syntaxnet_ops.word_embedding_initializer(
vectors=feature_spec.pretrained_embedding_matrix.part[0].file_pattern,
vocabulary=feature_spec.vocab.part[0].file_pattern,
override_num_embeddings=row_num,
num_special_embeddings=1,
embedding_init=1.0,
embedding_init=0.0, # zero out rows with no pretrained values
seed=seed1,
seed2=seed2)
return tf.get_variable(
......@@ -183,7 +185,57 @@ def embedding_lookup(embedding_matrix, indices, ids, weights, size):
return embeddings
def fixed_feature_lookup(component, state, channel_id, stride):
def apply_feature_id_dropout(ids, weights, channel):
  """Randomly replaces feature IDs with the channel's dropout ID.

  Args:
    ids: Vector of feature IDs.
    weights: Vector of feature weights.
    channel: FixedFeatureChannel that extracted the |ids|.

  Returns:
    Copy of |ids| and |weights| where each ID is randomly replaced with
    |channel.dropout_id|, according to the probabilities in
    |channel.dropout_keep_probabilities|.  If |channel.dropout_id| equals
    |channel.vocabulary_size|, the weights of dropped features are set to zero.
  """
  check.Gt(
      len(channel.dropout_keep_probability), 0,
      'Channel {} dropout_keep_probability is empty'.format(channel.name))
  check.Le(
      len(channel.dropout_keep_probability), channel.vocabulary_size,
      'Channel {} dropout_keep_probability is too long'.format(channel.name))

  # Constant tensors built from the channel's proto fields.
  replacement_id = tf.constant(
      channel.dropout_id, name='dropout_id', dtype=tf.int64)
  keep_probability_table = tf.constant(
      list(channel.dropout_keep_probability),
      name='dropout_keep_probability',
      dtype=tf.float32,
      shape=[channel.vocabulary_size])

  # Per-feature keep probabilities for the current batch of IDs.
  per_id_keep_probability = tf.gather(keep_probability_table, ids)

  # Sample uniform noise in [0,1) and keep each ID whose noise value falls
  # below its keep probability.
  ids_shape = tf.shape(ids)
  keep_mask = tf.random_uniform(ids_shape) < per_id_keep_probability

  # Substitute the replacement ID wherever the mask says "drop".
  new_ids = tf.where(keep_mask, ids, tf.fill(ids_shape, replacement_id))

  if channel.dropout_id == channel.vocabulary_size:
    # The replacement ID is out-of-vocabulary; zero the weights of dropped IDs.
    new_weights = tf.where(keep_mask, weights,
                           tf.zeros(ids_shape, dtype=tf.float32))
  else:
    new_weights = weights

  return new_ids, new_weights
def fixed_feature_lookup(component, state, channel_id, stride, during_training):
"""Looks up fixed features and passes them through embeddings.
Embedding vectors may be scaled by weights if the features specify it.
......@@ -193,6 +245,8 @@ def fixed_feature_lookup(component, state, channel_id, stride):
state: MasterState object for the live ComputeSession.
channel_id: int id of the fixed feature to look up.
stride: int Tensor of current batch * beam size.
during_training: True if this is being called from a training code path.
This controls, e.g., the use of feature ID dropout.
Returns:
NamedTensor object containing the embedding vectors.
......@@ -200,13 +254,35 @@ def fixed_feature_lookup(component, state, channel_id, stride):
feature_spec = component.spec.fixed_feature[channel_id]
check.Gt(feature_spec.embedding_dim, 0,
'Embeddings requested for non-embedded feature: %s' % feature_spec)
embedding_matrix = component.get_variable(fixed_embeddings_name(channel_id))
if feature_spec.is_constant:
embedding_matrix = tf.get_variable(fixed_embeddings_name(channel_id))
else:
embedding_matrix = component.get_variable(fixed_embeddings_name(channel_id))
with tf.op_scope([embedding_matrix], 'fixed_embedding_' + feature_spec.name):
indices, ids, weights = dragnn_ops.extract_fixed_features(
state.handle, component=component.name, channel_id=channel_id)
size = stride * feature_spec.size
embeddings = embedding_lookup(embedding_matrix, indices, ids, weights, size)
if during_training and feature_spec.dropout_id >= 0:
ids, weights = apply_feature_id_dropout(ids, weights, feature_spec)
if component.master.build_runtime_graph:
# To simplify integration with NN compilers, assume that each feature in
# the channel extracts exactly one ID and no weights.
# TODO(googleuser): Relax this restriction?
embeddings = []
for index in range(feature_spec.size):
feature_id = component.add_cell_input(
tf.int32, [1], 'fixed_channel_{}_index_{}_ids'.format(
channel_id, index))
embeddings.append(tf.gather(embedding_matrix, feature_id))
embeddings = tf.concat(embeddings, 1)
else:
size = stride * feature_spec.size
embeddings = embedding_lookup(embedding_matrix, indices, ids, weights,
size)
dim = feature_spec.size * feature_spec.embedding_dim
return NamedTensor(
tf.reshape(embeddings, [-1, dim]), feature_spec.name, dim=dim)
......@@ -368,12 +444,16 @@ def convert_network_state_tensorarray(tensorarray):
return tf.reshape(tensor, [-1, tf.shape(tensor)[2]])
def pass_through_embedding_matrix(act_block, embedding_matrix, step_idx):
def pass_through_embedding_matrix(component, channel_id, size, act_block,
embedding_matrix, step_idx):
"""Passes the activations through the embedding_matrix.
Takes care to handle out of bounds lookups.
Args:
component: Component that produced the linked features.
channel_id: Channel that produced the linked features.
size: Number of linked embeddings in the channel.
act_block: matrix of activations.
embedding_matrix: matrix of weights.
step_idx: vector containing step indices, with -1 indicating out of bounds.
......@@ -383,14 +463,36 @@ def pass_through_embedding_matrix(act_block, embedding_matrix, step_idx):
"""
# Indicator vector for out of bounds lookups.
step_idx_mask = tf.expand_dims(tf.equal(step_idx, -1), -1)
step_idx_mask = tf.to_float(step_idx_mask)
if component.master.build_runtime_graph:
step_idx_mask = component.add_cell_input(
step_idx_mask.dtype, [size, 1],
'linked_channel_{}_out_of_bounds'.format(channel_id))
# Pad the last column of the activation vectors with the indicator.
act_block = tf.concat([act_block, tf.to_float(step_idx_mask)], 1)
act_block = tf.concat([act_block, step_idx_mask], 1)
return tf.matmul(act_block, embedding_matrix)
def lookup_named_tensor_or_none(name, named_tensors):
  """Retrieves a NamedTensor by name, or None if it doesn't exist.

  Args:
    name: Name of the tensor to retrieve.
    named_tensors: List of NamedTensor objects to search.

  Returns:
    The NamedTensor in |named_tensors| with the |name| or None.
  """
  # Return the first match, if any; generators short-circuit on a hit.
  return next(
      (candidate for candidate in named_tensors if candidate.name == name),
      None)
def lookup_named_tensor(name, named_tensors):
"""Retrieves a NamedTensor by name.
"""Retrieves a NamedTensor by name, raising KeyError if it doesn't exist.
Args:
name: Name of the tensor to retrieve.
......@@ -402,11 +504,11 @@ def lookup_named_tensor(name, named_tensors):
Raises:
KeyError: If the |name| is not found among the |named_tensors|.
"""
for named_tensor in named_tensors:
if named_tensor.name == name:
return named_tensor
raise KeyError('Name "%s" not found in named tensors: %s' % (name,
named_tensors))
result = lookup_named_tensor_or_none(name, named_tensors)
if result is None:
raise KeyError('Name "%s" not found in named tensors: %s' % (name,
named_tensors))
return result
def activation_lookup_recurrent(component, state, channel_id, source_array,
......@@ -417,9 +519,9 @@ def activation_lookup_recurrent(component, state, channel_id, source_array,
not passed through (i.e. multiplied by) an embedding matrix.
Args:
component: Component object in which to look up the fixed features.
component: Component object in which to look up the linked features.
state: MasterState object for the live ComputeSession.
channel_id: int id of the fixed feature to look up.
channel_id: int id of the linked feature to look up.
source_array: TensorArray from which to fetch feature vectors, expected to
have size [steps + 1] elements of shape [stride, D] each.
source_layer_size: int length of feature vectors before embedding.
......@@ -459,11 +561,17 @@ def activation_lookup_recurrent(component, state, channel_id, source_array,
act_block = tf.gather(act_block, flat_idx)
act_block = tf.reshape(act_block, [-1, source_layer_size])
if component.master.build_runtime_graph:
act_block = component.add_cell_input(act_block.dtype, [
feature_spec.size, source_layer_size
], 'linked_channel_{}_activations'.format(channel_id))
if feature_spec.embedding_dim != -1:
embedding_matrix = component.get_variable(
linked_embeddings_name(channel_id))
act_block = pass_through_embedding_matrix(act_block, embedding_matrix,
step_idx)
act_block = pass_through_embedding_matrix(component, channel_id,
feature_spec.size, act_block,
embedding_matrix, step_idx)
dim = feature_spec.size * feature_spec.embedding_dim
else:
# If embedding_dim is -1, just output concatenation of activations.
......@@ -481,9 +589,9 @@ def activation_lookup_other(component, state, channel_id, source_tensor,
not passed through (i.e. multiplied by) an embedding matrix.
Args:
component: Component object in which to look up the fixed features.
component: Component object in which to look up the linked features.
state: MasterState object for the live ComputeSession.
channel_id: int id of the fixed feature to look up.
channel_id: int id of the linked feature to look up.
source_tensor: Tensor from which to fetch feature vectors. Expected to have
have shape [steps + 1, stride, D].
source_layer_size: int length of feature vectors before embedding (D). It
......@@ -508,11 +616,17 @@ def activation_lookup_other(component, state, channel_id, source_tensor,
act_block = tf.gather_nd(source_tensor, indices)
act_block = tf.reshape(act_block, [-1, source_layer_size])
if component.master.build_runtime_graph:
act_block = component.add_cell_input(act_block.dtype, [
feature_spec.size, source_layer_size
], 'linked_channel_{}_activations'.format(channel_id))
if feature_spec.embedding_dim != -1:
embedding_matrix = component.get_variable(
linked_embeddings_name(channel_id))
act_block = pass_through_embedding_matrix(act_block, embedding_matrix,
step_idx)
act_block = pass_through_embedding_matrix(component, channel_id,
feature_spec.size, act_block,
embedding_matrix, step_idx)
dim = feature_spec.size * feature_spec.embedding_dim
else:
# If embedding_dim is -1, just output concatenation of activations.
......@@ -629,7 +743,7 @@ class Layer(object):
Returns:
TensorArray object
"""
check.Gt(self.dim, 0, 'Cannot create array when dimension is dynamic')
check.Ge(self.dim, 0, 'Cannot create array when dimension is dynamic')
tensor_array = ta.TensorArray(
dtype=tf.float32,
size=0,
......@@ -671,7 +785,19 @@ def get_attrs_with_defaults(parameters, defaults):
return attrs
def maybe_apply_dropout(inputs, keep_prob, per_sequence, stride=None):
def maybe_make_dropout_mask(shape, keep_prob):
  """Returns a reusable dropout mask, or None if dropout would not occur."""
  if keep_prob < 1.0:
    # Mask entries are 0 or 1/keep_prob, per tf.nn.dropout's scaling.
    return tf.nn.dropout(tf.ones(shape, dtype=tf.float32), keep_prob)
  return None
def maybe_apply_dropout(inputs,
keep_prob,
per_sequence,
stride=None,
dropout_mask=None,
name=None):
"""Applies dropout, if so configured, to an input tensor.
The input may be rank 2 or 3 depending on whether the stride (i.e., batch
......@@ -682,20 +808,27 @@ def maybe_apply_dropout(inputs, keep_prob, per_sequence, stride=None):
keep_prob: Scalar probability of keeping each input element. If >= 1.0, no
dropout is performed.
per_sequence: If true, sample the dropout mask once per sequence, instead of
once per step. Requires |stride| when true.
stride: Scalar batch size. Optional if |per_sequence| is false.
once per step. Either |stride| or |dropout_mask| must be set when true.
stride: Scalar batch size. Optional if |per_sequence| is false, or if
|dropout_mask| is provided.
dropout_mask: Precomputed dropout mask to apply to the |inputs|; must be
broadcastable to |inputs|. Optional if |per_sequence| is false, or if
|stride| is provided.
name: Optional name for the dropout operation, if dropout is applied.
Returns:
[stride * num_steps, dim] or [stride, num_steps, dim] tensor, matching the
shape of |inputs|, containing the masked or original inputs, depending on
whether dropout was actually performed.
"""
if keep_prob >= 1.0:
return inputs
if not per_sequence:
return tf.nn.dropout(inputs, keep_prob)
return tf.nn.dropout(inputs, keep_prob, name=name)
if dropout_mask is not None:
return tf.multiply(inputs, dropout_mask, name=name)
# We only check the dims if we are applying per-sequence dropout
check.Ge(inputs.get_shape().ndims, 2, 'inputs must be rank 2 or 3')
......@@ -713,7 +846,7 @@ def maybe_apply_dropout(inputs, keep_prob, per_sequence, stride=None):
# Replace |num_steps| with 1 in |noise_shape|, so the dropout mask broadcasts
# to all steps for a particular sequence.
noise_shape = [stride, 1, dim]
masked_sxnxd = tf.nn.dropout(inputs_sxnxd, keep_prob, noise_shape)
masked_sxnxd = tf.nn.dropout(inputs_sxnxd, keep_prob, noise_shape, name=name)
# If needed, flatten out the batch dimension in the return value.
return tf.reshape(masked_sxnxd, [-1, dim]) if flat else masked_sxnxd
......@@ -749,6 +882,7 @@ class NetworkUnitInterface(object):
"""
self._component = component
self._params = []
self._derived_params = []
self._layers = init_layers if init_layers else []
self._regularized_weights = []
self._context_layers = init_context_layers if init_context_layers else []
......@@ -764,7 +898,10 @@ class NetworkUnitInterface(object):
check.Gt(spec.size, 0, 'Invalid fixed feature size')
if spec.embedding_dim > 0:
fixed_dim = spec.embedding_dim
self._params.append(add_embeddings(channel_id, spec))
if spec.is_constant:
add_embeddings(channel_id, spec)
else:
self._params.append(add_embeddings(channel_id, spec))
else:
fixed_dim = 1 # assume feature ID extraction; only one ID per step
self._fixed_feature_dims[spec.name] = spec.size * fixed_dim
......@@ -802,8 +939,8 @@ class NetworkUnitInterface(object):
self._concatenated_input_dim = -1
else:
self._concatenated_input_dim = sum(input_dims)
tf.logging.info('component %s concat_input_dim %s', component.name,
self._concatenated_input_dim)
tf.logging.debug('component %s concat_input_dim %s', component.name,
self._concatenated_input_dim)
# Allocate attention parameters.
if self._component.spec.attention_component:
......@@ -845,6 +982,19 @@ class NetworkUnitInterface(object):
[attention_hidden_layer_size, component.num_actions],
initializer=tf.random_normal_initializer(stddev=1e-4)))
  def pre_create(self, stride):
    """Prepares this network for inputs of the given stride.

    This will be called before entering the main transition loop and calling
    create(). Networks can use this to pre-compute values that are reused in
    the main transition loop. Note that this may be called multiple times;
    e.g., once for the training graph, and again for the inference graph.

    Args:
      stride: Scalar batch_size * beam_size.
    """
    # Default implementation: no precomputation; subclasses may override.
    pass
@abc.abstractmethod
def create(self,
fixed_embeddings,
......@@ -878,6 +1028,18 @@ class NetworkUnitInterface(object):
def params(self):
return self._params
  @property
  def derived_params(self):
    """Gets the list of derived parameters.

    Derived parameters are similar to `params`, but reformatted slightly
    (because doing so is easier in Python).

    Returns:
      List of zero-argument getters, each of which return a tensor when called.
    """
    return self._derived_params
@property
def regularized_weights(self):
return self._regularized_weights
......@@ -919,6 +1081,38 @@ class NetworkUnitInterface(object):
"""
raise NotImplementedError()
def get_bulk_predictions(self, stride, network_tensors):
"""Returns custom bulk predictions, if supported.
The returned predictions will be used to advance the batch of states, like
logits. For example, a network may perform structured prediction, and then
return 0/1 indicators of the jointly-predicted annotations. The difference
between this and get_logits() is that this is only used at inference time.
Args:
stride: Scalar stride for segmenting bulk tensors.
network_tensors: List of tensors as returned by create().
Returns:
[stride * steps, dim] matrix of predictions, or None if not supported.
"""
del stride, network_tensors
return None
def compute_bulk_loss(self, stride, network_tensors, gold):
"""Returns a custom bulk training loss, if supported.
Args:
stride: Scalar stride for segmenting bulk tensors.
network_tensors: List of tensors as returned by create().
gold: [stride * steps] vector of gold actions.
Returns:
Tuple of (loss, correct, total), or (None, None, None) if not supported.
"""
del stride, network_tensors, gold
return (None, None, None)
def get_l2_regularized_weights(self):
"""Gets the weights that need to be regularized."""
return self.regularized_weights
......@@ -1026,6 +1220,12 @@ class FeedForwardNetwork(NetworkUnitInterface):
(https://arxiv.org/abs/1512.05287).
dropout_all_layers (False): If true, apply dropout to the input of all
hidden layers, instead of just applying it to the network input.
initialize_bias_zero (False): If true, initialize bias vectors to 0.
Otherwise, they are initialized to a small constant value.
initialize_softmax_zero (False): If true, initialize softmax weights to 0.
Otherwise, they are initialized to small random values.
initialize_hidden_orthogonal (False): If true, initialize hidden weights
orthogonally. Otherwise, they are initialized to small random values.
Hyperparameters used:
dropout_rate: The probability that an input is not dropped. Only used
......@@ -1041,9 +1241,25 @@ class FeedForwardNetwork(NetworkUnitInterface):
'nonlinearity': 'relu',
'dropout_keep_prob': -1.0,
'dropout_per_sequence': False,
'dropout_all_layers': False
'dropout_all_layers': False,
'initialize_bias_zero': False,
'initialize_softmax_zero': False,
'initialize_hidden_orthogonal': False,
})
def _make_bias_initializer():
return (tf.zeros_initializer() if self._attrs['initialize_bias_zero'] else
tf.constant_initializer(0.2, dtype=tf.float32))
def _make_softmax_initializer():
return (tf.zeros_initializer() if self._attrs['initialize_softmax_zero']
else tf.random_normal_initializer(stddev=1e-4))
def _make_hidden_initializer():
return (tf.orthogonal_initializer()
if self._attrs['initialize_hidden_orthogonal'] else
tf.random_normal_initializer(stddev=1e-4))
# Initialize the hidden layer sizes before running the base initializer, as
# the base initializer may need to know the size of the hidden layer for
# recurrent connections.
......@@ -1084,13 +1300,13 @@ class FeedForwardNetwork(NetworkUnitInterface):
for index, hidden_layer_size in enumerate(self._hidden_layer_sizes):
weights = tf.get_variable(
'weights_%d' % index, [last_layer_dim, hidden_layer_size],
initializer=tf.random_normal_initializer(stddev=1e-4))
initializer=_make_hidden_initializer())
self._params.append(weights)
if index > 0 or self._layer_norm_hidden is None:
self._params.append(
tf.get_variable(
'bias_%d' % index, [hidden_layer_size],
initializer=tf.constant_initializer(0.2, dtype=tf.float32)))
initializer=_make_bias_initializer()))
self._weights.append(weights)
self._layers.append(
......@@ -1108,7 +1324,7 @@ class FeedForwardNetwork(NetworkUnitInterface):
self._params.append(
tf.get_variable(
'weights_softmax', [last_layer_dim, component.num_actions],
initializer=tf.random_normal_initializer(stddev=1e-4)))
initializer=_make_softmax_initializer()))
self._params.append(
tf.get_variable(
'bias_softmax', [component.num_actions],
......@@ -1199,67 +1415,133 @@ class FeedForwardNetwork(NetworkUnitInterface):
class LSTMNetwork(NetworkUnitInterface):
"""Implementation of action LSTM style network."""
"""Implementation of action LSTM style network.
Note that this is not a vanilla LSTM: it adds peephole connections and couples
the input and forget gates.
This implementation treats linked features called lstm_h and lstm_c specially.
Instead of treating them as normal linked features, it uses them as the
previous LSTM states. This allows having a single LSTM component actually
consist of several LSTMs, or to have a tree-shaped LSTM.
"""
def __init__(self, component):
"""Initializes LSTM parameters.
Args:
component: parent ComponentBuilderBase object.
Parameters used to construct the network:
hidden_layer_sizes: In spite of its name, a single int indicating the
number of hidden units in each hidden layer.
factored_hidden_dim: If positive, the weight matrix is factored into a
product of two matrices with this inner dimension.
omit_logits (False): Whether to elide the logits layer.
initialize_bias_zero (False): If true, initialize bias vectors to 0.
Otherwise, they are initialized to small random values.
initialize_softmax_zero (False): If true, initialize softmax weights to 0.
Otherwise, they are initialized to small random values.
initialize_hidden_orthogonal (False): If true, initialize hidden weights
orthogonally. Otherwise, they are initialized to small random values.
input_dropout_rate (-1.0): Keep probability for inputs. If negative, fall
back to the |dropout_rate| hyperparameter.
recurrent_dropout_rate (-1.0): Keep probability for recurrences. If
negative, fall back to the |recurrent_dropout_rate| hyperparameter.
dropout_per_sequence (False): If true, sample the dropout mask once per
sequence, instead of once per step. See Gal and Ghahramani
(https://arxiv.org/abs/1512.05287).
"""
assert component.num_actions > 0, 'Component num actions must be positive.'
network_unit_spec = component.spec.network_unit
self._hidden_layer_sizes = (
int)(network_unit_spec.parameters['hidden_layer_sizes'])
self._attrs = get_attrs_with_defaults(
component.spec.network_unit.parameters,
defaults={
'hidden_layer_sizes': -1, # NB: a single dim, not a list
'factored_hidden_dim': -1,
'omit_logits': False,
'initialize_bias_zero': False,
'initialize_softmax_zero': False,
'initialize_hidden_orthogonal': False,
'input_dropout_rate': -1.0,
'recurrent_dropout_rate': -1.0,
'dropout_per_sequence': False,
})
def _make_bias_initializer():
return (tf.zeros_initializer() if self._attrs['initialize_bias_zero'] else
tf.random_normal_initializer(stddev=1e-4))
self._input_dropout_rate = component.master.hyperparams.dropout_rate
self._recurrent_dropout_rate = (
component.master.hyperparams.recurrent_dropout_rate)
def _make_softmax_initializer():
return (tf.zeros_initializer() if self._attrs['initialize_softmax_zero']
else tf.random_normal_initializer(stddev=1e-4))
self._hidden_layer_sizes = self._attrs['hidden_layer_sizes']
self._factored_hidden_dim = self._attrs['factored_hidden_dim']
self._compute_logits = not self._attrs['omit_logits']
self._dropout_per_sequence = self._attrs['dropout_per_sequence']
self._input_dropout_rate = self._attrs['input_dropout_rate']
if self._input_dropout_rate < 0.0:
self._input_dropout_rate = component.master.hyperparams.dropout_rate
self._recurrent_dropout_rate = self._attrs['recurrent_dropout_rate']
if self._recurrent_dropout_rate < 0.0:
self._recurrent_dropout_rate = (
component.master.hyperparams.recurrent_dropout_rate)
if self._recurrent_dropout_rate < 0.0:
self._recurrent_dropout_rate = component.master.hyperparams.dropout_rate
tf.logging.info('[%s] dropout: input=%s recurrent=%s per_sequence=%s',
component.name, self._input_dropout_rate,
self._recurrent_dropout_rate, self._dropout_per_sequence)
super(LSTMNetwork, self).__init__(component)
layer_input_dim = self._concatenated_input_dim
self._layer_input_dim = self._concatenated_input_dim
if self._layer_input_dim > 1:
for skipped_link in ['lstm_h', 'lstm_c']:
if skipped_link in self._linked_feature_dims:
self._layer_input_dim -= self._linked_feature_dims[skipped_link]
self._input_dropout_mask = None
self._recurrent_dropout_mask = None
self._context_layers = []
# TODO(googleuser): should we choose different initilizer,
# e.g. truncated_normal_initializer?
self._x2i = tf.get_variable(
'x2i', [layer_input_dim, self._hidden_layer_sizes],
initializer=tf.random_normal_initializer(stddev=1e-4))
self._h2i = tf.get_variable(
'h2i', [self._hidden_layer_sizes, self._hidden_layer_sizes],
initializer=tf.random_normal_initializer(stddev=1e-4))
self._c2i = tf.get_variable(
'c2i', [self._hidden_layer_sizes, self._hidden_layer_sizes],
initializer=tf.random_normal_initializer(stddev=1e-4))
self._bi = tf.get_variable(
'bi', [self._hidden_layer_sizes],
initializer=tf.random_normal_initializer(stddev=1e-4))
self._x2o = tf.get_variable(
'x2o', [layer_input_dim, self._hidden_layer_sizes],
initializer=tf.random_normal_initializer(stddev=1e-4))
self._h2o = tf.get_variable(
'h2o', [self._hidden_layer_sizes, self._hidden_layer_sizes],
initializer=tf.random_normal_initializer(stddev=1e-4))
self._c2o = tf.get_variable(
'c2o', [self._hidden_layer_sizes, self._hidden_layer_sizes],
initializer=tf.random_normal_initializer(stddev=1e-4))
self._bo = tf.get_variable(
'bo', [self._hidden_layer_sizes],
initializer=tf.random_normal_initializer(stddev=1e-4))
self._x2c = tf.get_variable(
'x2c', [layer_input_dim, self._hidden_layer_sizes],
initializer=tf.random_normal_initializer(stddev=1e-4))
self._h2c = tf.get_variable(
'h2c', [self._hidden_layer_sizes, self._hidden_layer_sizes],
initializer=tf.random_normal_initializer(stddev=1e-4))
self._bc = tf.get_variable(
'bc', [self._hidden_layer_sizes],
initializer=tf.random_normal_initializer(stddev=1e-4))
self._params.extend([
self._x2i, self._h2i, self._c2i, self._bi, self._x2o, self._h2o,
self._c2o, self._bo, self._x2c, self._h2c, self._bc
])
self._create_hidden_weights(
'x2i', [self._layer_input_dim, self._hidden_layer_sizes])
self._create_hidden_weights(
'h2i', [self._hidden_layer_sizes, self._hidden_layer_sizes])
self._create_hidden_weights(
'c2i', [self._hidden_layer_sizes, self._hidden_layer_sizes])
self._params.append(
tf.get_variable(
'bi', [self._hidden_layer_sizes],
initializer=_make_bias_initializer()))
self._create_hidden_weights(
'x2o', [self._layer_input_dim, self._hidden_layer_sizes])
self._create_hidden_weights(
'h2o', [self._hidden_layer_sizes, self._hidden_layer_sizes])
self._create_hidden_weights(
'c2o', [self._hidden_layer_sizes, self._hidden_layer_sizes])
self._params.append(
tf.get_variable(
'bo', [self._hidden_layer_sizes],
initializer=_make_bias_initializer()))
self._create_hidden_weights(
'x2c', [self._layer_input_dim, self._hidden_layer_sizes])
self._create_hidden_weights(
'h2c', [self._hidden_layer_sizes, self._hidden_layer_sizes])
self._params.append(
tf.get_variable(
'bc', [self._hidden_layer_sizes],
initializer=_make_bias_initializer()))
# Add runtime hooks for combined matrices.
self._derived_params.append(self._get_x_to_ico)
self._derived_params.append(self._get_h_to_ico)
self._derived_params.append(self._get_ico_bias)
lstm_h_layer = Layer(component, name='lstm_h', dim=self._hidden_layer_sizes)
lstm_c_layer = Layer(component, name='lstm_c', dim=self._hidden_layer_sizes)
......@@ -1272,18 +1554,92 @@ class LSTMNetwork(NetworkUnitInterface):
self._layers.append(
Layer(component, name='layer_0', dim=self._hidden_layer_sizes))
self.params.append(
tf.get_variable(
'weights_softmax',
[self._hidden_layer_sizes, component.num_actions],
initializer=tf.random_normal_initializer(stddev=1e-4)))
self.params.append(
tf.get_variable(
'bias_softmax', [component.num_actions],
initializer=tf.zeros_initializer()))
if self._compute_logits:
self.params.append(
tf.get_variable(
'weights_softmax',
[self._hidden_layer_sizes, component.num_actions],
initializer=_make_softmax_initializer()))
self.params.append(
tf.get_variable(
'bias_softmax', [component.num_actions],
initializer=tf.zeros_initializer()))
self._layers.append(
Layer(component, name='logits', dim=component.num_actions))
self._layers.append(
Layer(component, name='logits', dim=component.num_actions))
def _get_variable_name_prefix(self):
"""Returns the prefix for variable names."""
# The bias variables are always present; infer the prefix from one of them.
bi = self._component.get_variable('bi')
tokens = bi.op.name.split('/')
while tokens.pop() != 'bi':
pass # remove the last 'bi' and everything after it
return '/'.join(tokens) + '/'
def _get_x_to_ico(self):
# TODO(googleuser): Export the factored representation, if available.
x2i = self._multiply_hidden_weights(tf.eye(self._layer_input_dim), 'x2i')
x2c = self._multiply_hidden_weights(tf.eye(self._layer_input_dim), 'x2c')
x2o = self._multiply_hidden_weights(tf.eye(self._layer_input_dim), 'x2o')
prefix = self._get_variable_name_prefix()
with tf.name_scope(None):
return tf.concat([x2i, x2c, x2o], axis=1, name=prefix + 'x_to_ico')
def _get_h_to_ico(self):
# TODO(googleuser): Export the factored representation, if available.
h2i = self._multiply_hidden_weights(tf.eye(self._hidden_layer_sizes), 'h2i')
h2c = self._multiply_hidden_weights(tf.eye(self._hidden_layer_sizes), 'h2c')
h2o = self._multiply_hidden_weights(tf.eye(self._hidden_layer_sizes), 'h2o')
prefix = self._get_variable_name_prefix()
with tf.name_scope(None):
return tf.concat([h2i, h2c, h2o], axis=1, name=prefix + 'h_to_ico')
def _get_ico_bias(self):
bi = self._component.get_variable('bi')
bc = self._component.get_variable('bc')
bo = self._component.get_variable('bo')
prefix = self._get_variable_name_prefix()
with tf.name_scope(None):
return tf.concat([bi, bc, bo], axis=0, name=prefix + 'ico_bias')
def _create_hidden_weights(self, name, shape):
"""Creates params for hidden weight matrix of the given shape."""
check.Eq(len(shape), 2, 'Hidden weights %s must be a matrix' % name)
def _initializer():
return (tf.orthogonal_initializer()
if self._attrs['initialize_hidden_orthogonal'] else
tf.random_normal_initializer(stddev=1e-4))
if self._factored_hidden_dim > 0:
self._params.append(
tf.get_variable(
'%s_in' % name, [shape[0], self._factored_hidden_dim],
initializer=_initializer()))
self._params.append(
tf.get_variable(
'%s_out' % name, [self._factored_hidden_dim, shape[1]],
initializer=_initializer()))
else:
self._params.append(
tf.get_variable(name, shape, initializer=_initializer()))
def _multiply_hidden_weights(self, inputs, name):
"""Multiplies the inputs with the named hidden weight matrix."""
if self._factored_hidden_dim > 0:
inputs = tf.matmul(inputs, self._component.get_variable('%s_in' % name))
return tf.matmul(inputs, self._component.get_variable('%s_out' % name))
else:
return tf.matmul(inputs, self._component.get_variable(name))
def pre_create(self, stride):
"""Refreshes the dropout masks, if applicable."""
if self._dropout_per_sequence:
self._input_dropout_mask = maybe_make_dropout_mask(
[stride, self._layer_input_dim], self._input_dropout_rate)
self._recurrent_dropout_mask = maybe_make_dropout_mask(
[stride, self._hidden_layer_sizes], self._recurrent_dropout_rate)
def create(self,
fixed_embeddings,
......@@ -1293,51 +1649,84 @@ class LSTMNetwork(NetworkUnitInterface):
during_training,
stride=None):
"""See base class."""
input_tensor = get_input_tensor(fixed_embeddings, linked_embeddings)
# context_tensor_arrays[0] is lstm_h
# context_tensor_arrays[1] is lstm_c
assert len(context_tensor_arrays) == 2
length = context_tensor_arrays[0].size()
# Get the (possibly averaged) parameters to execute the network.
x2i = self._component.get_variable('x2i')
h2i = self._component.get_variable('h2i')
c2i = self._component.get_variable('c2i')
# Get the (possibly averaged) biases to execute the network.
bi = self._component.get_variable('bi')
x2o = self._component.get_variable('x2o')
h2o = self._component.get_variable('h2o')
c2o = self._component.get_variable('c2o')
bo = self._component.get_variable('bo')
x2c = self._component.get_variable('x2c')
h2c = self._component.get_variable('h2c')
bc = self._component.get_variable('bc')
if self._compute_logits:
weights_softmax = self._component.get_variable('weights_softmax')
bias_softmax = self._component.get_variable('bias_softmax')
i_h_tm1 = lookup_named_tensor_or_none('lstm_h', linked_embeddings)
h_from_linked = False
if i_h_tm1 is not None:
h_from_linked = True
i_h_tm1 = i_h_tm1.tensor
i_c_tm1 = lookup_named_tensor_or_none('lstm_c', linked_embeddings)
c_from_linked = False
if i_c_tm1 is not None:
c_from_linked = True
i_c_tm1 = i_c_tm1.tensor
# i_h_tm1, i_c_tm1 = h_{t-1}, c_{t-1} and label c and h inputs
if i_h_tm1 is None:
i_h_tm1 = context_tensor_arrays[0].read(length - 1)
if i_c_tm1 is None:
i_c_tm1 = context_tensor_arrays[1].read(length - 1)
i_h_tm1 = tf.identity(i_h_tm1, name='lstm_h_in')
i_c_tm1 = tf.identity(i_c_tm1, name='lstm_c_in')
# i_h_tm1, i_c_tm1 = h_{t-1}, c_{t-1}
i_h_tm1 = context_tensor_arrays[0].read(length - 1)
i_c_tm1 = context_tensor_arrays[1].read(length - 1)
# Add hard-coded recurrent inputs to the exported cell.
if self._component.master.build_runtime_graph:
shape = [1, self._hidden_layer_sizes]
if not c_from_linked:
i_c_tm1 = self._component.add_cell_input(i_c_tm1.dtype, shape, 'lstm_c',
'TYPE_RECURRENT')
if not h_from_linked:
i_h_tm1 = self._component.add_cell_input(i_h_tm1.dtype, shape, 'lstm_h',
'TYPE_RECURRENT')
# Remove 'lstm_h' and 'lstm_c' from linked_embeddings, since they are used
# in a special way.
linked_embeddings = [
x for x in linked_embeddings if x.name not in ['lstm_h', 'lstm_c']
]
# label c and h inputs
i_c_tm1 = tf.identity(i_c_tm1, name='lstm_c_in')
i_h_tm1 = tf.identity(i_h_tm1, name='lstm_h_in')
input_tensor = get_input_tensor(fixed_embeddings, linked_embeddings)
# label the feature input (for debugging purposes)
input_tensor = tf.identity(input_tensor, name='input_tensor')
# apply dropout according to http://arxiv.org/pdf/1409.2329v5.pdf
if during_training and self._input_dropout_rate < 1:
input_tensor = tf.nn.dropout(input_tensor, self._input_dropout_rate)
if during_training:
input_tensor = maybe_apply_dropout(
input_tensor,
self._input_dropout_rate,
self._dropout_per_sequence,
dropout_mask=self._input_dropout_mask)
# input -- i_t = sigmoid(affine(x_t, h_{t-1}, c_{t-1}))
i_ait = tf.matmul(input_tensor, x2i) + tf.matmul(i_h_tm1, h2i) + tf.matmul(
i_c_tm1, c2i) + bi
# Note peephole connection to previous cell state.
i_ait = (
self._multiply_hidden_weights(input_tensor, 'x2i') +
self._multiply_hidden_weights(i_h_tm1, 'h2i') +
self._multiply_hidden_weights(i_c_tm1, 'c2i') + bi)
i_it = tf.sigmoid(i_ait)
# forget -- f_t = 1 - i_t
# Note coupling with input gate.
i_ft = tf.ones([1, 1]) - i_it
# write memory cell -- tanh(affine(x_t, h_{t-1}))
i_awt = tf.matmul(input_tensor, x2c) + tf.matmul(i_h_tm1, h2c) + bc
i_awt = (
self._multiply_hidden_weights(input_tensor, 'x2c') +
self._multiply_hidden_weights(i_h_tm1, 'h2c') + bc)
i_wt = tf.tanh(i_awt)
# c_t = f_t \odot c_{t-1} + i_t \odot tanh(affine(x_t, h_{t-1}))
......@@ -1345,8 +1734,11 @@ class LSTMNetwork(NetworkUnitInterface):
tf.multiply(i_it, i_wt), tf.multiply(i_ft, i_c_tm1), name='lstm_c')
# output -- o_t = sigmoid(affine(x_t, h_{t-1}, c_t))
i_aot = tf.matmul(input_tensor, x2o) + tf.matmul(ct, c2o) + tf.matmul(
i_h_tm1, h2o) + bo
# Note peephole connection to current cell state.
i_aot = (
self._multiply_hidden_weights(input_tensor, 'x2o') +
self._multiply_hidden_weights(ct, 'c2o') +
self._multiply_hidden_weights(i_h_tm1, 'h2o') + bo)
i_ot = tf.sigmoid(i_aot)
......@@ -1354,27 +1746,35 @@ class LSTMNetwork(NetworkUnitInterface):
ph_t = tf.tanh(ct)
ht = tf.multiply(i_ot, ph_t, name='lstm_h')
if during_training and self._recurrent_dropout_rate < 1:
ht = tf.nn.dropout(
ht, self._recurrent_dropout_rate, name='lstm_h_dropout')
if during_training:
ht = maybe_apply_dropout(
ht,
self._recurrent_dropout_rate,
self._dropout_per_sequence,
dropout_mask=self._recurrent_dropout_mask,
name='lstm_h_dropout')
h = tf.identity(ht, name='layer_0')
logits = tf.nn.xw_plus_b(ht,
tf.get_variable('weights_softmax'),
tf.get_variable('bias_softmax'))
# tensors will be consistent with the layers:
# [lstm_h, lstm_c, layer_0, (optional) logits]
tensors = [ht, ct, h]
if self._component.spec.attention_component:
logits += self.attention(ht, attention_tensor)
if self._compute_logits:
logits = tf.nn.xw_plus_b(ht, weights_softmax, bias_softmax)
if self._component.spec.attention_component:
logits += self.attention(ht, attention_tensor)
logits = tf.identity(logits, name='logits')
tensors.append(logits)
logits = tf.identity(logits, name='logits')
# tensors will be consistent with the layers:
# [lstm_h, lstm_c, layer_0, logits]
tensors = [ht, ct, h, logits]
return tensors
def get_layer_size(self, layer_name):
assert layer_name == 'layer_0', 'Can only retrieve from first hidden layer.'
assert layer_name in {
'layer_0', 'lstm_h', 'lstm_c'
}, 'Can only retrieve from first hidden layer, lstm_h or lstm_c.'
return self._hidden_layer_sizes
def get_logits(self, network_tensors):
......@@ -1846,10 +2246,9 @@ class PairwiseConvNetwork(NetworkUnitInterface):
self._widths, self._dropout, self._bias_init, self._initialization
])
if not all(param_lengths[0] == param_len for param_len in param_lengths):
raise RuntimeError(
'Unmatched widths/dropout/bias_init/initialization: ' +
'%d/%d/%d/%d' % (param_lengths[0], param_lengths[1],
param_lengths[2], param_lengths[3]))
raise RuntimeError('Unmatched widths/dropout/bias_init/initialization: ' +
'%d/%d/%d/%d' % (param_lengths[0], param_lengths[1],
param_lengths[2], param_lengths[3]))
self._depths.extend(map(int, parameters['depths'].split(',')))
if len(self._depths) != len(self._widths) + 1:
......@@ -1866,9 +2265,8 @@ class PairwiseConvNetwork(NetworkUnitInterface):
self._num_labels = self._depths[-1]
if parameters['activation_layers']:
self._activation_layers = set(map(int,
parameters['activation_layers'].split(
',')))
self._activation_layers = set(
map(int, parameters['activation_layers'].split(',')))
else:
self._activation_layers = set(range(self._num_layers - 1))
......@@ -1876,7 +2274,7 @@ class PairwiseConvNetwork(NetworkUnitInterface):
for i, width in enumerate(self._widths):
if self._activation == 'glu' and i in self._activation_layers:
self._kernel_shapes.append(
[width, width, self._depths[i], 2*self._depths[i + 1]])
[width, width, self._depths[i], 2 * self._depths[i + 1]])
else:
self._kernel_shapes.append(
[width, width, self._depths[i], self._depths[i + 1]])
......@@ -1910,7 +2308,8 @@ class PairwiseConvNetwork(NetworkUnitInterface):
del context_tensor_arrays, attention_tensor # Unused.
# TODO(googleuser): Normalize the arguments to create(). 'stride'
# is unused by the recurrent network units, while 'context_tensor_arrays'
# and 'attenion_tensor_array' is unused by bulk network units. b/33587044
# and 'attenion_tensor_array' is unused by bulk network units.
if stride is None:
raise ValueError("PairwiseConvNetwork needs 'stride'")
......@@ -1926,8 +2325,9 @@ class PairwiseConvNetwork(NetworkUnitInterface):
sources_shape = tf.shape(source_tokens)
targets_shape = tf.shape(target_tokens)
num_steps = sources_shape[1]
with tf.control_dependencies([tf.assert_equal(num_steps, targets_shape[2],
name='num_steps_mismatch')]):
with tf.control_dependencies([
tf.assert_equal(num_steps, targets_shape[2], name='num_steps_mismatch')
]):
arg1 = tf.tile(source_tokens, tf.stack([1, 1, num_steps, 1]))
arg2 = tf.tile(target_tokens, tf.stack([1, num_steps, 1, 1]))
conv = tf.concat([arg1, arg2], 3)
......@@ -1935,10 +2335,10 @@ class PairwiseConvNetwork(NetworkUnitInterface):
with tf.variable_scope('conv%d' % i, reuse=True) as scope:
if during_training:
conv = maybe_apply_dropout(conv, self._dropout[i], False)
conv = tf.nn.conv2d(conv,
self._component.get_variable('weights'),
[1, 1, 1, 1],
padding='SAME')
conv = tf.nn.conv2d(
conv,
self._component.get_variable('weights'), [1, 1, 1, 1],
padding='SAME')
conv = tf.nn.bias_add(conv, self._component.get_variable('biases'))
if i in self._activation_layers:
conv = self._activation_fn(conv, name=scope.name)
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for network_units."""
......@@ -26,8 +25,6 @@ from tensorflow.python.platform import googletest
from dragnn.protos import spec_pb2
from dragnn.python import network_units
FLAGS = tf.app.flags.FLAGS
class NetworkUnitsConverterTest(test_util.TensorFlowTestCase):
......@@ -61,6 +58,7 @@ class MockComponent(object):
self.spec = component_spec
self.name = component_spec.name
self.beam_size = 1
self.num_actions = 45
self._attrs = {}
def attr(self, name):
......@@ -72,12 +70,13 @@ class MockComponent(object):
class MockMaster(object):
def __init__(self):
def __init__(self, build_runtime_graph=False):
self.spec = spec_pb2.MasterSpec()
self.hyperparams = spec_pb2.GridPoint()
self.lookup_component = {
'previous': MockComponent(self, spec_pb2.ComponentSpec())
}
self.build_runtime_graph = build_runtime_graph
class MockNetwork(object):
......@@ -167,6 +166,164 @@ class GetAttrsWithDefaultsTest(test_util.TensorFlowTestCase):
_assert_attr_is_true('TRUE')
class LstmNetworkTest(test_util.TensorFlowTestCase):
  """Tests for the LSTMNetwork unit."""

  test_spec_1 = """
  component {
    name: 'bi_lstm'
    backend { registered_name: 'TestComponent' }
    fixed_feature {
      name: 'words'
      fml: 'words'
      size: 1
      embedding_dim: 32
      vocabulary_size: 1079813,
    }
    network_unit {
      registered_name: 'LSTMNetwork'
      parameters {
        key: "hidden_layer_sizes"
        value: "128"
      }
    }
  }
  """

  test_spec_linked = """
  component {
    name: 'bi_lstm'
    backend { registered_name: 'TestComponent' }
    fixed_feature {
      name: 'words'
      fml: 'words'
      size: 1
      embedding_dim: 32
      vocabulary_size: 1079813,
    }
    linked_feature {
      name: 'lstm_h'
      fml: 'bias(0)'
      embedding_dim: -1
      size: 1
      source_component: 'bi_lstm'
      source_translator: 'history'
      source_layer: 'lstm_h'
    }
    linked_feature {
      name: 'lstm_c'
      fml: 'bias(0)'
      embedding_dim: -1
      size: 1
      source_component: 'bi_lstm'
      source_translator: 'history'
      source_layer: 'lstm_c'
    }
    network_unit {
      registered_name: 'LSTMNetwork'
      parameters {
        key: "hidden_layer_sizes"
        value: "128"
      }
    }
  }
  """

  def setUp(self):
    # Start each test from an empty graph so variables created by one test
    # cannot collide with those created by another.
    tf.reset_default_graph()

  def construct_lstm_network_unit(self, master):
    """Helper to construct a LSTMNetwork. Doesn't call create() yet."""
    component = MockComponent(master, master.spec.component[0])
    with tf.variable_scope('bi_lstm'):
      return network_units.LSTMNetwork(component)

  def get_context_tensor_arrays(self, lstm_network_unit):
    """Creates a length-1 tensor array for each context layer."""
    return [
        layer.create_array(1) for layer in lstm_network_unit.context_layers
    ]

  def fixed_word_embeddings(self):
    """Helper for returning fixed embeddings, for 1 word feature."""
    word_values = tf.constant([[1.0] * 32], dtype=tf.float32)
    return [network_units.NamedTensor(word_values, 'words')]

  def _build_network(self, spec_text, during_training):
    """Parses |spec_text|, builds an LSTMNetwork, and runs create()."""
    master = MockMaster()
    master.spec = spec_pb2.MasterSpec()
    text_format.Parse(spec_text, master.spec)
    lstm_network_unit = self.construct_lstm_network_unit(master)
    with tf.variable_scope('bi_lstm', reuse=True):
      lstm_network_unit.create(
          self.fixed_word_embeddings(), [],
          self.get_context_tensor_arrays(lstm_network_unit), None,
          during_training)
    return lstm_network_unit

  def _check_derived_params(self, lstm_network_unit, expect_names):
    """Checks shapes (and optionally names) of the derived parameters."""
    x_to_ico = lstm_network_unit.derived_params[0]()
    h_to_ico = lstm_network_unit.derived_params[1]()
    ico_bias = lstm_network_unit.derived_params[2]()

    # Word dimension (32) to 3x the hidden dimension (128).
    self.assertEqual(x_to_ico.shape, (32, 384))
    # Hidden dimension (128) to 3x the hidden dimension (128).
    self.assertEqual(h_to_ico.shape, (128, 384))
    # Equal to the hidden dimension (128) times 3.
    self.assertEqual(ico_bias.shape, (384,))
    if expect_names:
      self.assertEqual(x_to_ico.op.name, 'bi_lstm/x_to_ico')
      self.assertEqual(h_to_ico.op.name, 'bi_lstm/h_to_ico')
      self.assertEqual(ico_bias.op.name, 'bi_lstm/ico_bias')

  def testCanCreate(self):
    """Smoke test that the create() function doesn't raise errors."""
    self._build_network(self.test_spec_1, during_training=True)

  def testCanCreateLinked(self):
    """Smoke test that the create() function doesn't raise errors."""
    self._build_network(self.test_spec_linked, during_training=True)

  def testRuntimeConcatentatedMatrices(self):
    """Test generation of concatenated matrices."""
    # TODO(googleuser): Make MockComponent support runtime graph generation.
    network = self._build_network(self.test_spec_1, during_training=False)
    self._check_derived_params(network, expect_names=True)

  def testRuntimeConcatentatedMatricesLinked(self):
    """Test generation of concatenated matrices."""
    # TODO(googleuser): Make MockComponent support runtime graph generation.
    network = self._build_network(self.test_spec_linked, during_training=False)
    self._check_derived_params(network, expect_names=False)
class GatherNetworkTest(test_util.TensorFlowTestCase):
def setUp(self):
......@@ -214,12 +371,30 @@ class GatherNetworkTest(test_util.TensorFlowTestCase):
network = network_units.GatherNetwork(self._component)
# Construct a batch of two items with 3 and 2 steps, respectively.
indices = tf.constant([[1], [2], [0], # item 1
[-1], [0], [-1]], # item 2
dtype=tf.int64)
features = tf.constant([[1.0, 1.5], [2.0, 2.5], [3.0, 3.5], # item 1
[4.0, 4.5], [5.0, 5.5], [6.0, 6.5]], # item 2
dtype=tf.float32)
indices = tf.constant(
[
# item 1
[1],
[2],
[0],
# item 2
[-1],
[0],
[-1]
],
dtype=tf.int64)
features = tf.constant(
[
# item 1
[1.0, 1.5],
[2.0, 2.5],
[3.0, 3.5],
# item 2
[4.0, 4.5],
[5.0, 5.5],
[6.0, 6.5]
],
dtype=tf.float32)
fixed_embeddings = []
linked_embeddings = [
......@@ -233,13 +408,16 @@ class GatherNetworkTest(test_util.TensorFlowTestCase):
gathered = outputs[0]
# Zeros will be substituted for index -1.
self.assertAllEqual(gathered.eval(),
[[2.0, 2.5], # gathered from 1
[3.0, 3.5], # gathered from 2
[1.0, 1.5], # gathered from 0
[0.0, 0.0], # gathered from -1
[4.0, 4.5], # gathered from 0
[0.0, 0.0]]) # gathered from -1
self.assertAllEqual(
gathered.eval(),
[
[2.0, 2.5], # gathered from 1
[3.0, 3.5], # gathered from 2
[1.0, 1.5], # gathered from 0
[0.0, 0.0], # gathered from -1
[4.0, 4.5], # gathered from 0
[0.0, 0.0] # gathered from -1
])
def testTrainablePadding(self):
self._component.spec.network_unit.parameters['trainable_padding'] = 'true'
......@@ -248,12 +426,30 @@ class GatherNetworkTest(test_util.TensorFlowTestCase):
network = network_units.GatherNetwork(self._component)
# Construct a batch of two items with 3 and 2 steps, respectively.
indices = tf.constant([[1], [2], [0], # item 1
[-1], [0], [-1]], # item 2
dtype=tf.int64)
features = tf.constant([[1.0, 1.5], [2.0, 2.5], [3.0, 3.5], # item 1
[4.0, 4.5], [5.0, 5.5], [6.0, 6.5]], # item 2
dtype=tf.float32)
indices = tf.constant(
[
# item 1
[1],
[2],
[0],
# item 2
[-1],
[0],
[-1]
],
dtype=tf.int64)
features = tf.constant(
[
# item 1
[1.0, 1.5],
[2.0, 2.5],
[3.0, 3.5],
# item 2
[4.0, 4.5],
[5.0, 5.5],
[6.0, 6.5]
],
dtype=tf.float32)
fixed_embeddings = []
linked_embeddings = [
......@@ -299,8 +495,8 @@ class IdentityInitializerTest(test_util.TensorFlowTestCase):
"""
with tf.Graph().as_default(), self.test_session() as session:
np.random.seed(4)
tensor = network_units.add_var_initialized('tensor', shape, 'identity',
divisor=divisor, stddev=std)
tensor = network_units.add_var_initialized(
'tensor', shape, 'identity', divisor=divisor, stddev=std)
session.run(tf.global_variables_initializer())
actual = session.run(tensor)
self.assertAllClose(actual, expected, 1e-8, 1e-8)
......@@ -345,13 +541,13 @@ class IdentityInitializerTest(test_util.TensorFlowTestCase):
divisor = 3.
std = 1e-3
shape = (6, 3)
m = divisor/shape[-1]
expected = [[m, 4.99951362e-04, -9.95908980e-04],
[m, -4.18301526e-04, -1.58457726e-03],
[-6.47706795e-04, m, 3.32250027e-04],
[-1.14747661e-03, m, -8.79869258e-05],
[4.25072387e-04, 3.32253141e-04, m],
[3.50997143e-04, -6.06887275e-04, m]]
m = divisor / shape[-1]
expected = [[m, 4.99951362e-04,
-9.95908980e-04], [m, -4.18301526e-04, -1.58457726e-03],
[-6.47706795e-04, m,
3.32250027e-04], [-1.14747661e-03, m, -8.79869258e-05],
[4.25072387e-04, 3.32253141e-04,
m], [3.50997143e-04, -6.06887275e-04, m]]
self.IdentityInitializerHelper(shape, expected, divisor, std)
def testIdentityInitializerNonSquareRank2FirstDimSmaller(self):
......@@ -368,14 +564,14 @@ class IdentityInitializerTest(test_util.TensorFlowTestCase):
std = 1e-3
shape = (2, 2, 6)
m = divisor / shape[-1]
expected = [[[5.05617063e-05, 4.99951362e-04, -9.95908980e-04,
6.93598529e-04, -4.18301526e-04, -1.58457726e-03],
[-6.47706795e-04, 5.98575163e-04, 3.32250027e-04,
-1.14747661e-03, 6.18669670e-04, -8.79869258e-05]],
[[m, m, m,
3.50997143e-04, -6.06887275e-04, 1.54697930e-03],
[7.23341596e-04, 4.61355667e-05, -9.82991653e-04,
m, m, m]]]
expected = [[[
5.05617063e-05, 4.99951362e-04, -9.95908980e-04, 6.93598529e-04,
-4.18301526e-04, -1.58457726e-03
], [
-6.47706795e-04, 5.98575163e-04, 3.32250027e-04, -1.14747661e-03,
6.18669670e-04, -8.79869258e-05
]], [[m, m, m, 3.50997143e-04, -6.06887275e-04, 1.54697930e-03],
[7.23341596e-04, 4.61355667e-05, -9.82991653e-04, m, m, m]]]
self.IdentityInitializerHelper(shape, expected, divisor, std)
def testIdentityInitializerNonSquareRank4(self):
......@@ -383,40 +579,110 @@ class IdentityInitializerTest(test_util.TensorFlowTestCase):
std = 1e-3
shape = (2, 3, 2, 8)
m = divisor / float(shape[-1])
expected = [
[[[5.05617063e-05, 4.99951362e-04, -9.95908980e-04, 6.93598529e-04,
-4.18301526e-04, -1.58457726e-03, -6.47706795e-04, 5.98575163e-04],
[3.32250027e-04, -1.14747661e-03, 6.18669670e-04, -8.79869258e-05,
4.25072387e-04, 3.32253141e-04, -1.15681626e-03, 3.50997143e-04]],
[[-6.06887275e-04, 1.54697930e-03, 7.23341596e-04, 4.61355667e-05,
-9.82991653e-04, 5.44327377e-05, 1.59892938e-04, -1.20894820e-03],
[2.22336012e-03, 3.94295203e-04, 1.69235771e-03, -1.11281220e-03,
1.63574750e-03, -1.36096554e-03, -6.51225855e-04, 5.42451337e-04]],
[[4.80062481e-05, -2.35807360e-03, -1.10558409e-03, 8.37836356e-04,
2.08787085e-03, 9.14840959e-04, -2.76203355e-04, 7.96511886e-04],
[-1.14379858e-03, 5.09919773e-04, -1.34746032e-03, -9.36010019e-06,
-1.30704633e-04, 8.02086608e-04, -3.02963977e-04, 1.20200263e-03]]],
[[[-1.96745284e-04, 8.36528721e-04, 7.86602264e-04, -1.84087583e-03,
3.75474883e-05, 3.59280530e-05, -7.78739923e-04, 1.79410708e-04],
[-1.45553437e-03, 5.56185201e-04, 5.09778853e-04, 3.00445536e-04,
2.47658417e-03, 3.52343399e-04, 6.74710027e-05, -7.32264714e-04]],
[[m, m, m, m,
1.58469542e-04, 1.99008291e-03, 1.16418756e-03, 2.42660157e-04],
[1.37992005e-03, -5.45587063e-05, 7.95233937e-04, 1.90899627e-05,
m, m, m, m]],
[[-1.09712186e-03, -5.28196048e-04, -2.37977528e-03, -6.07683673e-04,
-1.07529014e-03, 2.02240516e-03, -5.64875314e-04, -1.54292909e-03],
[8.70841788e-04, -1.75210531e-04, 4.86030076e-05, 1.88646198e-04,
2.09313483e-04, -3.74444906e-04, 9.54698597e-04, 5.23247640e-04]]]
]
expected = [[[[
5.05617063e-05, 4.99951362e-04, -9.95908980e-04, 6.93598529e-04,
-4.18301526e-04, -1.58457726e-03, -6.47706795e-04, 5.98575163e-04
], [
3.32250027e-04, -1.14747661e-03, 6.18669670e-04, -8.79869258e-05,
4.25072387e-04, 3.32253141e-04, -1.15681626e-03, 3.50997143e-04
]], [[
-6.06887275e-04, 1.54697930e-03, 7.23341596e-04, 4.61355667e-05,
-9.82991653e-04, 5.44327377e-05, 1.59892938e-04, -1.20894820e-03
], [
2.22336012e-03, 3.94295203e-04, 1.69235771e-03, -1.11281220e-03,
1.63574750e-03, -1.36096554e-03, -6.51225855e-04, 5.42451337e-04
]], [[
4.80062481e-05, -2.35807360e-03, -1.10558409e-03, 8.37836356e-04,
2.08787085e-03, 9.14840959e-04, -2.76203355e-04, 7.96511886e-04
], [
-1.14379858e-03, 5.09919773e-04, -1.34746032e-03, -9.36010019e-06,
-1.30704633e-04, 8.02086608e-04, -3.02963977e-04, 1.20200263e-03
]]], [[[
-1.96745284e-04, 8.36528721e-04, 7.86602264e-04, -1.84087583e-03,
3.75474883e-05, 3.59280530e-05, -7.78739923e-04, 1.79410708e-04
], [
-1.45553437e-03, 5.56185201e-04, 5.09778853e-04, 3.00445536e-04,
2.47658417e-03, 3.52343399e-04, 6.74710027e-05, -7.32264714e-04
]], [[
m, m, m, m, 1.58469542e-04, 1.99008291e-03, 1.16418756e-03,
2.42660157e-04
], [
1.37992005e-03, -5.45587063e-05, 7.95233937e-04, 1.90899627e-05, m, m,
m, m
]], [[
-1.09712186e-03, -5.28196048e-04, -2.37977528e-03, -6.07683673e-04,
-1.07529014e-03, 2.02240516e-03, -5.64875314e-04, -1.54292909e-03
], [
8.70841788e-04, -1.75210531e-04, 4.86030076e-05, 1.88646198e-04,
2.09313483e-04, -3.74444906e-04, 9.54698597e-04, 5.23247640e-04
]]]]
self.IdentityInitializerHelper(shape, expected, divisor, std)
class FeatureIdDropoutTest(test_util.TensorFlowTestCase):
  """Tests for apply_feature_id_dropout()."""

  def setUp(self):
    # Start each test from an empty graph so variables created by one test
    # cannot collide with those created by another.
    tf.reset_default_graph()

  def testApplyFeatureIdDropout(self):
    channel = spec_pb2.FixedFeatureChannel()
    text_format.Parse("""
      vocabulary_size: 10
      dropout_id: 8
      dropout_keep_probability: [0.0, 0.25, 0.5, 0.75, 1.0]
    """, channel)

    with tf.Graph().as_default(), self.test_session():
      with tf.variable_scope('test_scope'):
        feature_ids = tf.constant(list(range(10)), dtype=tf.int64)
        feature_weights = tf.constant([1.0] * 10, dtype=tf.float32)
        outputs = network_units.apply_feature_id_dropout(
            feature_ids, feature_weights, channel)
        perturbed_ids = outputs[0].eval()
        tf.logging.info('perturbed_ids = %s', perturbed_ids)

        # Given the dropout_keep_probability values specified above:
        #   * ID 0 is never kept.
        #   * IDs 1-3 are randomly kept with varying probability.
        #   * IDs 4-9 are always kept.
        # Only the deterministic extremes (never/always kept) are pinned
        # exactly, to keep the test free of randomness; behavior in between
        # should interpolate between the two extremes.
        self.assertEqual(perturbed_ids[0], channel.dropout_id)
        for feature_id in (1, 2, 3):
          self.assertIn(perturbed_ids[feature_id],
                        (feature_id, channel.dropout_id))
        self.assertAllEqual(perturbed_ids[4:], [4, 5, 6, 7, 8, 9])

  def testApplyFeatureIdDropoutSkip(self):
    channel = spec_pb2.FixedFeatureChannel()
    text_format.Parse("""
      vocabulary_size: 2
      dropout_id: 2
      dropout_keep_probability: [0.0, 1.0]
    """, channel)

    with tf.Graph().as_default(), self.test_session():
      with tf.variable_scope('test_scope'):
        feature_ids = tf.constant([0, 1], dtype=tf.int64)
        feature_weights = tf.constant([1.0, 1.0], dtype=tf.float32)
        outputs = network_units.apply_feature_id_dropout(
            feature_ids, feature_weights, channel)
        perturbed_ids = outputs[0].eval()
        perturbed_weights = outputs[1].eval()
        tf.logging.info('perturbed_ids = %s', perturbed_ids)
        tf.logging.info('perturbed_weights = %s', perturbed_weights)

        # Given the dropout_keep_probability values specified above:
        #   * ID 0 is never kept, and its weight is set to 0.
        #   * ID 1 is always kept.
        # Only the deterministic extremes (never/always kept) are checked.
        self.assertEqual(perturbed_ids[0], channel.dropout_id)
        self.assertEqual(perturbed_weights[0], 0)
        self.assertEqual(perturbed_ids[1], 1)
        self.assertEqual(perturbed_weights[1], 1)
# Run all tests in this module when executed as a script.
if __name__ == '__main__':
  googletest.main()
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils for supporting the DRAGNN runtime from the TF side."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import re
import tensorflow as tf
from dragnn.python import network_units
from syntaxnet.util import check
def add_hooks(component, cell_subgraph_spec):
  """Adds "hook" nodes to the graph, for use by the runtime.

  Hook nodes are not on the path to any required output, so TF-based DRAGNN
  never executes them; as long as the TF graph is not pruned, the DRAGNN
  runtime can still call them. Hook nodes may perform any TF computation,
  e.g. applying stable names to existing tensors (via tf.identity()) or
  converting variable data from a TF- or training-friendly format into a
  runtime-friendly one.

  NB: There are several restrictions on the context in which this function is
  called. In brief, call ComponentBuilderBase._add_runtime_hooks() at the top
  of each ComponentBuilderSubclass.build_*() method. In detail, this:
    * Must be called in the variable scope of the |component|, so variable
      references in component.get_variable() work.
    * Must be called, possibly transitively, from one of the |component|'s
      build_*() methods, so MasterBuilder.read_from_avg is set properly for
      component.get_variable().
    * Must not be called from within a tf.while_loop(), or the hook nodes
      will not work. In particular, NetworkUnitInterface.create() is called
      from a tf.while_loop() in DynamicComponentBuilder.

  Args:
    component: Component for which to add hooks.
    cell_subgraph_spec: CellSubgraphSpec for which to add hooks.
  """
  for channel, spec in enumerate(component.spec.linked_feature):
    if spec.embedding_dim != -1:
      _add_hooks_for_linked_embedding_matrix(component, channel)

  for channel, spec in enumerate(component.spec.fixed_feature):
    if spec.embedding_dim != -1:
      _add_hooks_for_fixed_embedding_matrix(component, channel)

  for variable in component.network.params:
    _add_hooks_for_trainable_params(component, variable)
  for getter in component.network.derived_params:
    _add_hooks_for_derived_parameter(getter)

  # Also export the serialized CellSubgraphSpec itself as a hook node.
  _add_hook_node(
      tf.constant(cell_subgraph_spec.SerializeToString(), tf.string),
      '{}/EXPORT/CellSubgraphSpec'.format(component.name))
def _blocked_and_dtype_transformations(tensor):
  """Yields variants of a tensor, for standard blocking/dtype variants.

  Args:
    tensor (tf.Tensor): Input tensor.

  Yields:
    (modified_tensor, suffix) pairs, where `modified_tensor` is a transformed
    version of the input, and `suffix` is a string like "/blocked32".
  """
  for block_size in (32, 48):
    padded = make_padded_blocked_matrix(tensor, block_size)
    bfloat16_padded = tf.to_bfloat16(bfloat16_permutation(padded))
    base = '/blocked{}'.format(block_size)
    yield padded, base
    yield bfloat16_padded, base + '/bfloat16'
def _add_hooks_for_linked_embedding_matrix(component, channel_id):
  """Adds runtime hooks for a linked embedding matrix.

  network_units.pass_through_embedding_matrix() computes the equivalent of:

    for i in range(stride):
      if step_idx[i] == -1:
        outputs[i,:] = out_of_bounds_vector
      else:
        outputs[i,:] = tf.matmul(act_block[i,:], weight_matrix)

  in a single matmul per batch, by extending the weight_matrix with the
  out_of_bounds_vector and each activation vector with a 0/1 out-of-bounds
  indicator (with act_block[i,:] zeroed for out-of-bounds links).

  That layout suits training but not the runtime:
    * Appending the 0/1 indicator requires copying the input activations,
      instead of using them by reference.
    * Accessing the |out_of_bounds_vector| as a contiguous array forces the
      runtime to load the matrix row-major, which may be slow for arithmetic.
    * Extending a 2^n dimension by one element maximizes trailing padding,
      wasting memory and cache.

  The hooks therefore split the extended matrix into a separate weight matrix
  and out-of-bounds vector.

  Args:
    component: Component for which to add hooks.
    channel_id: Linked embedding channel for which to add hooks.
  """
  var_name = network_units.linked_embeddings_name(channel_id)
  extended = component.get_variable(var_name)
  num_rows = tf.shape(extended)[0]
  # Split off the trailing out-of-bounds row.
  weights, out_of_bounds = tf.split(extended, [num_rows - 1, 1], 0)
  weights_t = tf.transpose(weights)

  hook_name = functools.partial(_get_hook_name, component, var_name)
  _add_hook_node(weights, hook_name('/weights'))
  _add_hook_node(weights_t, hook_name('/weights/transposed'))

  # Blocked versions of the matrix and its transpose.
  for tensor, prefix in ((weights, '/weights/matrix'),
                         (weights_t, '/weights/transposed')):
    for blocked, suffix in _blocked_and_dtype_transformations(tensor):
      _add_hook_node(blocked, hook_name(prefix + suffix))

  # Shape and out-of-bounds information.
  _add_hook_node(tf.shape(weights_t), hook_name('/weights/transposed/shape'))
  _add_hook_node(out_of_bounds,
                 _get_hook_name(component, var_name, '/out_of_bounds'))
def _add_hooks_for_fixed_embedding_matrix(component, channel_id):
  """Adds runtime hooks for a fixed embedding matrix.

  The hooks remove the last row from the embedding matrix.  The extra row was
  probably intended for out-of-vocabulary items, but those are handled in the
  feature system and the extra row is never used.

  Args:
    component: Component for which to add hooks.
    channel_id: Fixed embedding channel for which to add hooks.
  """
  embedding_name = network_units.fixed_embeddings_name(channel_id)
  padded_matrix = component.get_variable(embedding_name)
  padded_num_rows = tf.shape(padded_matrix)[0]
  # Drop the final (unused) row; the -1 size keeps the full column extent.
  trimmed = tf.slice(padded_matrix, [0, 0], [padded_num_rows - 1, -1])

  # TODO(googleuser): If the extra row is removed from the variable itself, remove
  # the tf.slice() and point the hook directly at the variable.
  _add_hook_node(trimmed, _get_hook_name(component, embedding_name, '/trimmed'))
def _add_hooks_for_derived_parameter(getter):
  """Adds hooks for derived parameters.

  Derived parameters are typically slight format modifications of regular
  parameters, exposed because doing the computation in Python is more convenient
  than as VariableStore wrappers.

  Args:
    getter: Function which, when called, will return the derived tensor.
  """
  derived = getter()
  derived_name = derived.op.name

  # Matrix-style hooks only make sense for rank-2 tensors.
  if derived.shape.ndims != 2:
    tf.logging.info('Not adding matrix hooks for derived parameter %s',
                    derived_name)
    return

  _add_hook_node(tf.transpose(derived), derived_name + '/transposed')
  for variant, variant_suffix in _blocked_and_dtype_transformations(derived):
    _add_hook_node(variant, derived_name + '/matrix' + variant_suffix)
def _add_hooks_for_trainable_params(component, params):
  """Adds runtime hooks for a variable of trainable parameters.

  Ignores parameters that are not statically-deducible as matrices.

  Args:
    component: Component for which to add hooks.
    params: Variable for which to add hooks.
  """
  params_name = params.op.name
  matrix = component.get_variable(var_params=params)

  # Only add hooks for tensors that are statically-deducible as matrices.
  if params.shape.ndims != 2:
    tf.logging.info('Not adding hooks for trainable params %s', params_name)
    return

  # The |matrix| may be a moving average of |params|, in which case its op name
  # carries an extra suffix.  Recover that suffix so hook names match it.
  suffix = re.sub('^' + re.escape(params_name), '', matrix.op.name)
  check.Ne(suffix, matrix.op.name,
           'Failed to find suffix for params %s' % params_name)

  def _qualified(base_name):
    """Returns a hook node name with the averaging suffix appended."""
    return params_name + base_name + suffix

  # Add the matrix and its transpose.
  transpose = tf.transpose(matrix)
  _add_hook_node(matrix, _qualified('/matrix'))
  _add_hook_node(transpose, _qualified('/transposed'))

  # Add blocked versions of the matrix and its transpose.
  for variant, variant_suffix in _blocked_and_dtype_transformations(matrix):
    _add_hook_node(variant, _qualified('/matrix' + variant_suffix))
  for variant, variant_suffix in _blocked_and_dtype_transformations(transpose):
    _add_hook_node(variant, _qualified('/transposed' + variant_suffix))

  # Also add hooks for the original shapes, which are obscured by padding.
  _add_hook_node(tf.shape(matrix), _qualified('/matrix/shape'))
  _add_hook_node(tf.shape(transpose), _qualified('/transposed/shape'))
def make_padded_blocked_matrix(matrix, block_size):
  """Converts a matrix to padded column-blocked format.

  For example, given a [64,127] matrix and block_size=16, this function returns
  an [8,64,16] tensor where the 8 inner sub-matrices, when concatenated left to
  right, re-constitute the original matrix.  Note that the 8th sub-matrix has a
  final column of padding.

  Args:
    matrix: The matrix to convert.
    block_size: The number of columns per block.

  Returns:
    Padded column-blocked matrix.
  """
  dims = tf.shape(matrix)
  num_rows = dims[0]
  num_columns = dims[1]

  # Padding needed to fill out the final block, and the resulting block count.
  last_block_size = num_columns % block_size
  padding_size = (block_size - last_block_size) % block_size
  num_blocks = (num_columns + padding_size) // block_size

  # Somehow the obvious approach based on tf.split() and tf.stack() doesn't work
  # (seems that the number of splits needs to be statically-known), but this
  # alternative based on tf.transpose() and tf.reshape() does.  Continuing the
  # example from the docstring...
  padded = tf.pad(matrix, [[0, 0], [0, padding_size]])  # [64,127] => [64,128]
  transposed = tf.transpose(padded)  # => [128,64]
  blocked = tf.reshape(
      transposed, [num_blocks, block_size, num_rows])  # => [8,16,64]
  return tf.transpose(blocked, [0, 2, 1])  # => [8,64,16]
def bfloat16_permutation(tensor):
  """Permutes values in the last dimension of a tensor.

  This permutation is used so that we can directly use unpacklo/unpackhi AVX2
  instructions on the matrix coefficients.  These unpacking instructions
  effectively permute the data.  See FastUnpackPermutation() and
  AvxFloatVecArray::Load(const TruncatedFloat16 *) in avx_vector_array.h for
  more details.

  Args:
    tensor: Blocked matrix, the result of make_padded_blocked_matrix().  Must
        have its last dimension statically known and a multiple of 16.

  Returns:
    Permuted matrix, suitable for calling tf.to_bfloat16() on.  For testing
    convenience we don't do so in this method.

  Raises:
    ValueError: If the matrix's block dimension is not a multiple of 16.
  """
  orig_shape = tensor.shape
  # NOTE: The original code iterated with Python-2-only xrange(), which raises
  # NameError under Python 3; plain range() works on both.  Coerce the static
  # dimension to int so range() accepts it regardless of the Dimension type.
  block_dim = int(orig_shape[-1])
  if block_dim % 16 != 0:
    raise ValueError('Bad block dimension, must be divisible by 16')

  # Within each 16-element group, swap the middle two 4-element sub-groups so
  # that AVX2 unpacklo/unpackhi reads land on the right coefficients.
  permutation = [0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15]
  indices = tf.constant(
      [16 * (i // 16) + permutation[i % 16] for i in range(block_dim)])
  return tf.gather(tensor, indices, axis=len(orig_shape) - 1)
def _get_hook_name(component, variable_name, suffix):
  """Builds the name of a hook node.

  Specifically, the name of the hook node is:

    <component.name>/<variable_name><suffix><remainder>

  where <remainder> is whatever follows <variable_name> in the name of the op
  that produces the named variable.  Recall that component.get_variable() may
  return either the original variable or its moving average.  These might have
  names like:

    foo_component/bar_variable
    foo_component/bar_variable/ExponentialMovingAverage

  In the examples above, the <remainder> is "" for the original variable and
  "/ExponentialMovingAverage" for its moving average.  Calling this function
  with suffix="/baz_suffix" in either case would add hook nodes named:

    foo_component/bar_variable/baz_suffix
    foo_component/bar_variable/baz_suffix/ExponentialMovingAverage

  Note that the suffix is inserted after the variable name, not necessarily at
  the end of the entire op name.

  Args:
    component: Component that the hook node belongs to.
    variable_name: Variable that the hook node name is based on.
    suffix: Suffix to append to the variable name.

  Returns:
    Name of the hook node.
  """
  full_name = component.get_variable(variable_name).op.name
  prefix = '/'.join([component.name, variable_name])
  hook_name = re.sub('^' + re.escape(prefix), prefix + suffix, full_name)

  # If re.sub() did not match anything, it returns the unmodified input (i.e.,
  # |full_name|).  Enforce that some change was made.
  check.Ne(
      full_name, hook_name,
      'Failed to match expected variable prefix "{}" in variable "{}"'.format(
          prefix, full_name))
  return hook_name
def _add_hook_node(tensor, fully_qualified_name):
  """Adds a hook node that outputs a tensor with a fully-qualified name."""
  # The hook name is already fully scoped, so escape to the top-level name
  # scope before re-emitting |tensor| under it.
  with tf.name_scope(None):
    tf.identity(tensor, name=fully_qualified_name)
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the runtime support utils."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from dragnn.protos import export_pb2
from dragnn.protos import spec_pb2
from dragnn.python import network_units
from dragnn.python import runtime_support
class MockNetwork(object):
  """Mock network exposing regular and derived parameters for tests."""

  def __init__(self):
    rank2 = tf.get_variable('rank2', [64, 127], tf.float32)
    rank3 = tf.get_variable('rank3', [64, 127, 250], tf.float32)
    self.params = [rank2, rank3]
    self.derived_params = [
        self._fake_derived_vector,
        self._fake_derived_parameter,
    ]

  def _fake_derived_vector(self):
    """Returns a constant vector named in the top-level name scope."""
    vector = tf.constant([1, 2, 3], dtype=tf.float32)
    with tf.name_scope(None):
      return tf.identity(vector, name='derived/vector')

  def _fake_derived_parameter(self):
    """Returns a derived matrix named alongside self.params[0]."""
    # Use absolute scoping to put the derived parameter in the same namespace.
    namespace = self.params[0].op.name.rsplit('/', 1)[0]
    with tf.name_scope(None):
      return tf.concat(
          [self.params[0], self.params[0]],
          axis=0,
          name='{}/derived'.format(namespace))
class MockComponent(object):
  """Mock component wrapping a MockNetwork for tests."""

  def __init__(self):
    self.name = 'test_component'
    self.spec = spec_pb2.ComponentSpec()
    with tf.variable_scope(self.name):
      self.network = MockNetwork()

  def get_variable(self, var_name=None, var_params=None):
    """Returns the named variable, or |var_params| when no name is given."""
    return tf.get_variable(var_name) if var_name else var_params
class RuntimeSupportTest(tf.test.TestCase):
  """Testing rig."""

  def testAddLinkedHooks(self):
    """Tests the hooks added for linked embedding matrices."""
    component = MockComponent()
    link0 = component.spec.linked_feature.add()
    link1 = component.spec.linked_feature.add()
    link0.embedding_dim = -1  # direct link
    link1.embedding_dim = 32  # transformed link
    link0_matrix_name = network_units.linked_embeddings_name(0)
    link1_matrix_name = network_units.linked_embeddings_name(1)
    with self.test_session() as session:
      graph = session.graph

      # Create linked embedding matrices. Only channel 1 uses one.
      with tf.variable_scope(component.name):
        tf.get_variable(link1_matrix_name, shape=[64 + 1, 32], dtype=tf.float32)

      # Add hooks. This should ignore channel 0 and add hooks for channel 1.
      with tf.variable_scope(component.name, reuse=True):
        runtime_support.add_hooks(component, export_pb2.CellSubgraphSpec())

      # Check that no hooks were added for channel 0.
      with self.assertRaises(KeyError):
        graph.get_tensor_by_name(
            '{}/{}/weights:0'.format(component.name, link0_matrix_name))
      with self.assertRaises(KeyError):
        graph.get_tensor_by_name('{}/{}/weights/transposed:0'.format(
            component.name, link0_matrix_name))
      with self.assertRaises(KeyError):
        graph.get_tensor_by_name('{}/{}/weights/transposed/shape:0'.format(
            component.name, link0_matrix_name))
      with self.assertRaises(KeyError):
        graph.get_tensor_by_name('{}/{}/weights/transposed/blocked32:0'.format(
            component.name, link0_matrix_name))
      with self.assertRaises(KeyError):
        graph.get_tensor_by_name('{}/{}/weights/transposed/blocked48:0'.format(
            component.name, link0_matrix_name))
      with self.assertRaises(KeyError):
        graph.get_tensor_by_name(
            '{}/{}/out_of_bounds:0'.format(component.name, link0_matrix_name))

      # Get the hooks added for channel 1.
      weights = graph.get_tensor_by_name(
          '{}/{}/weights:0'.format(component.name, link1_matrix_name))
      transposed = graph.get_tensor_by_name('{}/{}/weights/transposed:0'.format(
          component.name, link1_matrix_name))
      transposed_shape = graph.get_tensor_by_name(
          '{}/{}/weights/transposed/shape:0'.format(component.name,
                                                    link1_matrix_name))
      transposed32 = graph.get_tensor_by_name(
          '{}/{}/weights/transposed/blocked32:0'.format(component.name,
                                                        link1_matrix_name))
      transposed48 = graph.get_tensor_by_name(
          '{}/{}/weights/transposed/blocked48:0'.format(component.name,
                                                        link1_matrix_name))
      out_of_bounds = graph.get_tensor_by_name(
          '{}/{}/out_of_bounds:0'.format(component.name, link1_matrix_name))

      # Check dimensions of the hooks.  The last row of the [64+1,32] variable
      # is split off as the out-of-bounds vector.
      tf.global_variables_initializer().run()
      self.assertAllEqual(tf.shape(weights).eval(), [64, 32])
      self.assertAllEqual(tf.shape(transposed).eval(), [32, 64])
      self.assertAllEqual(transposed_shape.eval(), [32, 64])
      self.assertAllEqual(tf.shape(transposed32).eval(), [2, 32, 32])
      self.assertAllEqual(tf.shape(transposed48).eval(), [2, 32, 48])
      self.assertAllEqual(tf.shape(out_of_bounds).eval(), [1, 32])

  def testAddFixedHooks(self):
    """Tests the hooks added for fixed embedding matrices."""
    component = MockComponent()
    fixed0 = component.spec.fixed_feature.add()
    fixed1 = component.spec.fixed_feature.add()
    fixed0.embedding_dim = -1
    fixed1.embedding_dim = 32
    fixed0.vocabulary_size = 100
    fixed1.vocabulary_size = 1000
    fixed0_matrix_name = network_units.fixed_embeddings_name(0)
    fixed1_matrix_name = network_units.fixed_embeddings_name(1)
    with self.test_session() as session:
      graph = session.graph

      # Create fixed embedding matrices. Only channel 1 uses one.
      with tf.variable_scope(component.name):
        tf.get_variable(
            fixed1_matrix_name, shape=[1000 + 1, 32], dtype=tf.float32)

      # Add hooks. This should ignore channel 0 and add hooks for channel 1.
      with tf.variable_scope(component.name, reuse=True):
        runtime_support.add_hooks(component, export_pb2.CellSubgraphSpec())

      # Check that no hooks were added for channel 0.
      with self.assertRaises(KeyError):
        graph.get_tensor_by_name(
            '{}/{}/trimmed:0'.format(component.name, fixed0_matrix_name))

      # Get the hooks added for channel 1.
      trimmed = graph.get_tensor_by_name(
          '{}/{}/trimmed:0'.format(component.name, fixed1_matrix_name))

      # Check dimensions of the hooks.  The extra OOV row is trimmed off.
      tf.global_variables_initializer().run()
      self.assertAllEqual(tf.shape(trimmed).eval(), [1000, 32])

  def testAddParamsHooks(self):
    """Tests the hooks added for trainable parameter matrices."""
    component = MockComponent()
    rank2_name = 'rank2'
    rank3_name = 'rank3'
    with self.test_session() as session:
      graph = session.graph

      # Add hooks. This should add hooks for all rank-2 params.
      with tf.variable_scope(component.name, reuse=True):
        runtime_support.add_hooks(component, export_pb2.CellSubgraphSpec())

      # Check that no hooks were added for the rank-3 params.
      with self.assertRaises(KeyError):
        graph.get_tensor_by_name(
            '{}/{}/matrix:0'.format(component.name, rank3_name))
      with self.assertRaises(KeyError):
        graph.get_tensor_by_name(
            '{}/{}/transposed:0'.format(component.name, rank3_name))
      with self.assertRaises(KeyError):
        graph.get_tensor_by_name(
            '{}/{}/matrix/blocked32:0'.format(component.name, rank3_name))
      with self.assertRaises(KeyError):
        graph.get_tensor_by_name(
            '{}/{}/matrix/blocked48:0'.format(component.name, rank3_name))
      with self.assertRaises(KeyError):
        graph.get_tensor_by_name(
            '{}/{}/transposed/blocked32:0'.format(component.name, rank3_name))
      with self.assertRaises(KeyError):
        graph.get_tensor_by_name(
            '{}/{}/transposed/blocked48:0'.format(component.name, rank3_name))
      with self.assertRaises(KeyError):
        graph.get_tensor_by_name(
            '{}/{}/matrix/shape:0'.format(component.name, rank3_name))
      with self.assertRaises(KeyError):
        graph.get_tensor_by_name(
            '{}/{}/transposed/shape:0'.format(component.name, rank3_name))

      # Get the hooks added for each variable.
      matrix = graph.get_tensor_by_name(
          '{}/{}/matrix:0'.format(component.name, rank2_name))
      transposed = graph.get_tensor_by_name(
          '{}/{}/transposed:0'.format(component.name, rank2_name))
      matrix32 = graph.get_tensor_by_name(
          '{}/{}/matrix/blocked32:0'.format(component.name, rank2_name))
      matrix48 = graph.get_tensor_by_name(
          '{}/{}/matrix/blocked48:0'.format(component.name, rank2_name))
      transposed32 = graph.get_tensor_by_name(
          '{}/{}/transposed/blocked32:0'.format(component.name, rank2_name))
      transposed48 = graph.get_tensor_by_name(
          '{}/{}/transposed/blocked48:0'.format(component.name, rank2_name))
      matrix_shape = graph.get_tensor_by_name(
          '{}/{}/matrix/shape:0'.format(component.name, rank2_name))
      transposed_shape = graph.get_tensor_by_name(
          '{}/{}/transposed/shape:0'.format(component.name, rank2_name))

      # Check dimensions of the hooks.  Blocked shapes round the column count
      # up to a multiple of the block size (127 -> 128 or 144).
      tf.global_variables_initializer().run()
      self.assertAllEqual(tf.shape(matrix).eval(), [64, 127])
      self.assertAllEqual(tf.shape(transposed).eval(), [127, 64])
      self.assertAllEqual(matrix_shape.eval(), [64, 127])
      self.assertAllEqual(transposed_shape.eval(), [127, 64])
      self.assertAllEqual(tf.shape(matrix32).eval(), [4, 64, 32])
      self.assertAllEqual(tf.shape(matrix48).eval(), [3, 64, 48])
      self.assertAllEqual(tf.shape(transposed32).eval(), [2, 127, 32])
      self.assertAllEqual(tf.shape(transposed48).eval(), [2, 127, 48])

  def testAddDerivedParamHooks(self):
    """Tests the hooks added for derived parameters."""
    component = MockComponent()
    derived_name = 'derived'
    with self.test_session() as session:
      graph = session.graph

      # Add hooks.
      with tf.variable_scope(component.name, reuse=True):
        runtime_support.add_hooks(component, export_pb2.CellSubgraphSpec())
      session.run(tf.global_variables_initializer())

      # Get hooks for the derived vector.
      vector = graph.get_tensor_by_name('derived/vector:0')
      self.assertEqual(vector.shape, (3,))

      # Get the hooks for the derived variable.
      matrix = graph.get_tensor_by_name(
          '{}/{}/matrix/blocked32:0'.format(component.name, derived_name))
      self.assertAllEqual(tf.shape(matrix).eval(), [4, 128, 32])

      # Check the bfloat16 version. It should have the same shape.
      bfloat16_matrix = graph.get_tensor_by_name(
          '{}/{}/matrix/blocked32/bfloat16:0'.format(component.name,
                                                     derived_name))
      self.assertAllEqual(tf.shape(bfloat16_matrix).eval(), [4, 128, 32])

  def testMakePaddedBlockedMatrix(self):
    """Tests padded column-blocked conversion on a small concrete matrix."""
    with self.test_session():
      matrix = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10], [11, 12, 13, 14, 15],
                [16, 17, 18, 19, 20]]
      expected_blocked = [[[1, 2], [6, 7], [11, 12],
                           [16, 17]], [[3, 4], [8, 9], [13, 14], [18, 19]],
                          [[5, 0], [10, 0], [15, 0], [20, 0]]]
      matrix = tf.constant(matrix, tf.float32)
      actual_blocked = runtime_support.make_padded_blocked_matrix(matrix, 2)
      self.assertAllEqual(actual_blocked.eval(), expected_blocked)

  def testBfloat16Permutation(self):
    """Tests the AVX2 permutation on a single 16-column block."""
    with self.test_session():
      matrix = [list(range(16))]
      expected_permuted = [[
          0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
      ]]
      matrix = tf.constant(matrix, tf.float32)
      actual_permuted = runtime_support.bfloat16_permutation(matrix)
      self.assertAllEqual(actual_permuted.eval(), expected_permuted)

  def testLargerBfloat16Permutation(self):
    """Spot-checks the permutation on a larger multi-block tensor."""
    with self.test_session() as session:
      matrix = tf.random_uniform((3, 4, 32))
      permuted = runtime_support.bfloat16_permutation(matrix)
      matrix, actual_permuted = session.run([matrix, permuted])

      # Just check a few items for now, hopefully that's sufficient to ensure
      # the permutation is okay.
      self.assertEqual(matrix[0, 0, 0], actual_permuted[0, 0, 0])
      self.assertEqual(matrix[0, 0, 1], actual_permuted[0, 0, 1])
      self.assertEqual(matrix[1, 1, 16], actual_permuted[1, 1, 16])
      self.assertEqual(matrix[2, 0, 4], actual_permuted[2, 0, 8])
      self.assertEqual(matrix[2, 0, 5], actual_permuted[2, 0, 9])
      self.assertEqual(matrix[2, 1, 8], actual_permuted[2, 1, 4])
      self.assertEqual(matrix[2, 1, 8 + 16], actual_permuted[2, 1, 4 + 16])

  def testAddCellSubgraphSpecHook(self):
    """Tests the hook that exports the serialized CellSubgraphSpec proto."""
    component = MockComponent()
    cell = export_pb2.CellSubgraphSpec()
    cell.input.add(
        name='feature',
        tensor='feature_tensor',
        type=export_pb2.CellSubgraphSpec.Input.TYPE_FEATURE)
    cell.input.add(
        name='recurrent',
        tensor='recurrent_tensor',
        type=export_pb2.CellSubgraphSpec.Input.TYPE_RECURRENT)
    cell.output.add(name='layer_0', tensor='layer_0_tensor')
    cell.output.add(name='logits', tensor='logits_tensor')
    with self.test_session() as session:
      graph = session.graph

      # Add hooks for the cell constructed above.
      with tf.variable_scope(component.name, reuse=True):
        runtime_support.add_hooks(component, cell)

      # Get the hook containing the wire-format proto.
      cell_wire_format = graph.get_tensor_by_name(
          '{}/EXPORT/CellSubgraphSpec:0'.format(component.name))

      # Check that the hook matches the cell.
      tf.global_variables_initializer().run()
      self.assertEqual(cell_wire_format.eval(), cell.SerializeToString())
if __name__ == '__main__':
  # Run all tests in this module under the TensorFlow test runner.
  tf.test.main()
......@@ -16,30 +16,19 @@
import os
import tensorflow as tf
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
from dragnn.python import dragnn_ops
from dragnn.python import sentence_io
from syntaxnet import sentence_pb2
FLAGS = tf.app.flags.FLAGS
def setUpModule():
if not hasattr(FLAGS, 'test_srcdir'):
FLAGS.test_srcdir = ''
if not hasattr(FLAGS, 'test_tmpdir'):
FLAGS.test_tmpdir = tf.test.get_temp_dir()
from syntaxnet import test_flags
class ConllSentenceReaderTest(test_util.TensorFlowTestCase):
class ConllSentenceReaderTest(tf.test.TestCase):
def setUp(self):
# This dataset contains 54 sentences.
self.filepath = os.path.join(
FLAGS.test_srcdir,
test_flags.source_root(),
'syntaxnet/testdata/mini-training-set')
self.batch_size = 20
......@@ -82,4 +71,4 @@ class ConllSentenceReaderTest(test_util.TensorFlowTestCase):
if __name__ == '__main__':
googletest.main()
tf.test.main()
......@@ -15,7 +15,6 @@
"""Utils for building DRAGNN specs."""
from six.moves import xrange
import tensorflow as tf
from dragnn.protos import spec_pb2
......@@ -110,7 +109,9 @@ class ComponentSpecBuilder(object):
if transition_spec.registered_name == 'arc-standard':
return 'shift-reduce-step'
if transition_spec.registered_name in ('shift-only', 'tagger'):
if transition_spec.registered_name in ('shift-only', 'tagger', 'morpher',
'lm-transitions', 'dependency-label',
'category'):
if 'left_to_right' in transition_spec.parameters:
if transition_spec.parameters['left_to_right'] == 'false':
return 'reverse-token'
......
......@@ -27,15 +27,6 @@ from dragnn.python import spec_builder
from syntaxnet import parser_trainer
FLAGS = tf.app.flags.FLAGS
def setUpModule():
if not hasattr(FLAGS, 'test_srcdir'):
FLAGS.test_srcdir = ''
if not hasattr(FLAGS, 'test_tmpdir'):
FLAGS.test_tmpdir = tf.test.get_temp_dir()
class SpecBuilderTest(tf.test.TestCase):
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions to build DRAGNN MasterSpecs and schedule model training.
Provides functions to finish a MasterSpec, building required lexicons for it and
......@@ -23,13 +22,12 @@ import random
import tensorflow as tf
from six.moves import xrange
from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.python.framework import errors
from tensorflow.python.platform import gfile
flags = tf.app.flags
FLAGS = flags.FLAGS
from syntaxnet.util import check
def calculate_component_accuracies(eval_res_values):
......@@ -59,7 +57,9 @@ def annotate_dataset(sess, annotator, eval_corpus):
end = min(start + batch_size, len(eval_corpus))
serialized_annotations = sess.run(
annotator['annotations'],
feed_dict={annotator['input_batch']: eval_corpus[start:end]})
feed_dict={
annotator['input_batch']: eval_corpus[start:end]
})
assert len(serialized_annotations) == end - start
processed.extend(serialized_annotations)
tf.logging.info('Done. Produced %d annotations', len(processed))
......@@ -81,16 +81,60 @@ def get_summary_writer(tensorboard_dir):
return summary_writer
def generate_target_per_step_schedule(pretrain_steps, train_steps):
  """Generates a sampled training schedule.

  Each target first runs its pre-training steps back-to-back; afterwards,
  training steps are interleaved by sampling targets in proportion to their
  remaining step budgets.

  Arguments:
    pretrain_steps: List, number of pre-training steps per each target.
    train_steps: List, number of sampled training steps per each target.

  Returns:
    Python list of length sum(pretrain_steps + train_steps), containing
    target numbers per step.
  """
  check.Eq(len(pretrain_steps), len(train_steps))
  # Arbitrary seed to make sure the return is deterministic.
  random.seed(0x31337)
  tf.logging.info('Determining the training schedule...')
  target_per_step = []
  # Pre-training phase: each target's steps run contiguously, in target order.
  for target_idx in xrange(len(pretrain_steps)):
    target_per_step += [target_idx] * pretrain_steps[target_idx]
  # Copy before mutating so the caller's list is left untouched.
  train_steps = list(train_steps)
  while sum(train_steps) > 0:
    # Sample a target with probability proportional to its remaining steps,
    # by drawing a uniform index into the concatenated remaining budgets.
    step = random.randint(0, sum(train_steps) - 1)
    cumulative_steps = 0
    for target_idx in xrange(len(train_steps)):
      cumulative_steps += train_steps[target_idx]
      if step < cumulative_steps:
        break
    assert train_steps[target_idx] > 0
    train_steps[target_idx] -= 1
    target_per_step.append(target_idx)
  tf.logging.info('Training schedule defined!')
  return target_per_step
def run_training_step(sess, trainer, train_corpus, batch_size):
  """Runs a single iteration of train_op on a randomly sampled batch."""
  minibatch = random.sample(train_corpus, batch_size)
  feed_dict = {trainer['input_batch']: minibatch}
  sess.run(trainer['run'], feed_dict=feed_dict)
def run_training(sess, trainers, annotator, evaluator, pretrain_steps,
train_steps, train_corpus, eval_corpus, eval_gold,
batch_size, summary_writer, report_every, saver,
checkpoint_filename, checkpoint_stats=None):
def run_training(sess,
trainers,
annotator,
evaluator,
pretrain_steps,
train_steps,
train_corpus,
eval_corpus,
eval_gold,
batch_size,
summary_writer,
report_every,
saver,
checkpoint_filename,
checkpoint_stats=None):
"""Runs multi-task DRAGNN training on a single corpus.
Arguments:
......@@ -117,30 +161,15 @@ def run_training(sess, trainers, annotator, evaluator, pretrain_steps,
checkpoint_filename: File to save checkpoints to.
checkpoint_stats: Stats of checkpoint.
"""
random.seed(0x31337)
if not checkpoint_stats:
checkpoint_stats = [0] * (len(train_steps) + 1)
tf.logging.info('Determining the training schedule...')
target_for_step = []
for target_idx in xrange(len(pretrain_steps)):
target_for_step += [target_idx] * pretrain_steps[target_idx]
while sum(train_steps) > 0:
step = random.randint(0, sum(train_steps) - 1)
cumulative_steps = 0
for target_idx in xrange(len(train_steps)):
cumulative_steps += train_steps[target_idx]
if step < cumulative_steps:
break
assert train_steps[target_idx] > 0
train_steps[target_idx] -= 1
target_for_step.append(target_idx)
tf.logging.info('Training schedule defined!')
target_per_step = generate_target_per_step_schedule(pretrain_steps,
train_steps)
best_eval_metric = -1.0
tf.logging.info('Starting training...')
actual_step = sum(checkpoint_stats[1:])
for step, target_idx in enumerate(target_for_step):
for step, target_idx in enumerate(target_per_step):
run_training_step(sess, trainers[target_idx], train_corpus, batch_size)
checkpoint_stats[target_idx + 1] += 1
if step % 100 == 0:
......
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for dragnn.python.trainer_lib."""
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
from dragnn.python import trainer_lib
class TrainerLibTest(test_util.TensorFlowTestCase):
  """Tests for the training-schedule generator in trainer_lib."""

  def testImmutabilityOfArguments(self):
    """Tests that training schedule generation does not change its arguments."""
    pretrain_steps = [1, 2, 3]
    train_steps = [5, 5, 5]
    trainer_lib.generate_target_per_step_schedule(pretrain_steps, train_steps)

    # The generator must treat its argument lists as read-only.
    self.assertEqual([1, 2, 3], pretrain_steps)
    self.assertEqual([5, 5, 5], train_steps)

  def testTrainingScheduleGenerationAndDeterminism(self):
    """Non-trivial schedule, check generation and determinism."""
    schedule = trainer_lib.generate_target_per_step_schedule([1, 2, 3],
                                                             [5, 5, 5])
    self.assertEqual(
        [0, 1, 1, 2, 2, 2, 1, 0, 2, 1, 0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2],
        schedule)

  def testNoPretrainSteps(self):
    """Edge case, 1 target, no pretrain."""
    schedule = trainer_lib.generate_target_per_step_schedule([0], [10])
    self.assertEqual([0] * 10, schedule)

  def testNoTrainSteps(self):
    """Edge case, 1 target, only pretrain."""
    schedule = trainer_lib.generate_target_per_step_schedule([10], [0])
    self.assertEqual([0] * 10, schedule)
if __name__ == '__main__':
  # Run all tests in this module under the Google test runner.
  googletest.main()
......@@ -330,7 +330,7 @@ class LayerNormBasicLSTMNetwork(BaseLSTMNetwork):
def _cell_closure(scope):
"""Applies the LSTM cell to the current inputs and state."""
return cell(input_tensor, state, scope)
return cell(input_tensor, state, scope=scope)
unused_h, state = self._apply_with_captured_variables(_cell_closure)
......
# Everything in this package is publicly visible.
package(
    default_visibility = ["//visibility:public"],
)

load(
    "@org_tensorflow//tensorflow:tensorflow.bzl",
    "if_linux_x86_64",
)

load(
    "//dragnn/runtime:multiarch.bzl",
    "dragnn_cc_multiarch_binary",
    "dragnn_cc_multiarch_library",
    "dragnn_cc_multiarch_test",
)

# Aggressive math/vectorization copts, applied only on linux-x86_64 builds.
FAST_MATH_COPTS = if_linux_x86_64([
    "-O3",
    "-msse4.2",
    "-ffast-math",
    "-ftree-vectorize",
])

# Test fixture: the RNN tagger model files under testdata/.
filegroup(
    name = "test_rnn_tagger",
    srcs = glob(["testdata/rnn_tagger/**"]),
)
# Header-only alignment utilities.
cc_library(
    name = "alignment",
    hdrs = ["alignment.h"],
    deps = [
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Unit test for :alignment.
cc_test(
    name = "alignment_test",
    size = "small",
    srcs = ["alignment_test.cc"],
    deps = [
        ":alignment",
        "//dragnn/core/test:generic",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Memory-mapped file support.
cc_library(
    name = "mmap",
    srcs = ["mmap.cc"],
    hdrs = ["mmap.h"],
    deps = [
        ":alignment",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Unit test for :mmap; uses on-disk fixtures from testdata/.
cc_test(
    name = "mmap_test",
    size = "small",
    srcs = ["mmap_test.cc"],
    data = [
        "testdata/empty_file",
        "testdata/ten_bytes",
    ],
    deps = [
        ":mmap",
        "//dragnn/core/test:generic",
        "//syntaxnet:base",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Operand storage used by network computations.
cc_library(
    name = "operands",
    srcs = ["operands.cc"],
    hdrs = ["operands.h"],
    deps = [
        ":alignment",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Unit test for :operands.
cc_test(
    name = "operands_test",
    size = "small",
    srcs = ["operands_test.cc"],
    deps = [
        ":alignment",
        ":operands",
        "//dragnn/runtime/math:types",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# Abstract variable-store interface (header-only).
cc_library(
    name = "variable_store",
    hdrs = ["variable_store.h"],
    deps = [
        ":alignment",
        "//dragnn/protos:runtime_proto_cc",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Unit test for :variable_store.
cc_test(
    name = "variable_store_test",
    size = "small",
    srcs = ["variable_store_test.cc"],
    deps = [
        ":variable_store",
        "//dragnn/core/test:generic",
        "//dragnn/protos:runtime_proto_cc",
        "//dragnn/runtime/test:fake_variable_store",
        "//dragnn/runtime/test:helpers",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Wrapper around a TF SavedModel-based trained model.
cc_library(
    name = "trained_model",
    srcs = ["trained_model.cc"],
    hdrs = ["trained_model.h"],
    deps = [
        "//dragnn/core:dragnn_bulk_ops_cc",
        "//dragnn/core:dragnn_ops_cc",
        "//syntaxnet:base",
        "//syntaxnet:parser_ops_cc",
        "@org_tensorflow//tensorflow/cc/saved_model:loader",
        "@org_tensorflow//tensorflow/cc/saved_model:tag_constants",
        "@org_tensorflow//tensorflow/core:core_cpu",
        "@org_tensorflow//tensorflow/core:framework_headers_lib",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:protos_all_cc",
    ],
)

# Unit test for :trained_model; loads the RNN tagger fixture.
cc_test(
    name = "trained_model_test",
    size = "small",
    timeout = "moderate",
    srcs = ["trained_model_test.cc"],
    data = [":test_rnn_tagger"],
    deps = [
        ":trained_model",
        "//dragnn/components/syntaxnet:syntaxnet_component",
        "//dragnn/core/test:generic",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# VariableStore implementation backed by a :trained_model.
cc_library(
    name = "trained_model_variable_store",
    srcs = ["trained_model_variable_store.cc"],
    hdrs = ["trained_model_variable_store.h"],
    deps = [
        ":alignment",
        ":trained_model",
        ":variable_store",
        "//dragnn/protos:runtime_proto_cc",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:core_cpu",
        "@org_tensorflow//tensorflow/core:framework_headers_lib",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:protos_all_cc",
        "@org_tensorflow//tensorflow/core:tensorflow",
    ],
)

# Unit test for :trained_model_variable_store; sharded for speed.
cc_test(
    name = "trained_model_variable_store_test",
    size = "small",
    timeout = "moderate",
    srcs = ["trained_model_variable_store_test.cc"],
    data = [":test_rnn_tagger"],
    shard_count = 13,
    deps = [
        ":trained_model_variable_store",
        "//dragnn/components/syntaxnet:syntaxnet_component",
        "//dragnn/core/test:generic",
        "//dragnn/protos:runtime_proto_cc",
        "//dragnn/runtime/math:avx_vector_array",
        "//dragnn/runtime/math:float16_types",
        "//syntaxnet:base",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# VariableStore decorators/wrappers.
cc_library(
    name = "variable_store_wrappers",
    srcs = ["variable_store_wrappers.cc"],
    hdrs = ["variable_store_wrappers.h"],
    deps = [
        ":alignment",
        ":flexible_matrix_kernel",
        ":variable_store",
        "//dragnn/protos:runtime_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Unit test for :variable_store_wrappers.
cc_test(
    name = "variable_store_wrappers_test",
    size = "small",
    srcs = ["variable_store_wrappers_test.cc"],
    deps = [
        ":flexible_matrix_kernel",
        ":variable_store_wrappers",
        "//dragnn/core/test:generic",
        "//dragnn/protos:runtime_proto_cc",
        "//dragnn/runtime/math:transformations",
        "//dragnn/runtime/math:types",
        "//dragnn/runtime/test:fake_variable_store",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# VariableStore backed by a flat array of variable data.
cc_library(
    name = "array_variable_store",
    srcs = ["array_variable_store.cc"],
    hdrs = ["array_variable_store.h"],
    deps = [
        ":alignment",
        ":variable_store",
        "//dragnn/protos:runtime_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)
# Unit test covering :array_variable_store and its file/mmap subclasses.
cc_test(
    name = "array_variable_store_test",
    size = "small",
    srcs = ["array_variable_store_test.cc"],
    data = [
        "testdata/array_variable_store_data",
        "testdata/array_variable_store_spec",
        "testdata/empty_file",
    ],
    deps = [
        ":alignment",
        ":array_variable_store",
        ":file_array_variable_store",
        ":mmap_array_variable_store",
        "//dragnn/core/test:generic",
        "//dragnn/protos:runtime_proto_cc",
        "//dragnn/runtime/math:types",
        "//dragnn/runtime/test:helpers",
        "//syntaxnet:base",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Builds the flat array representation consumed by :array_variable_store.
cc_library(
    name = "array_variable_store_builder",
    srcs = ["array_variable_store_builder.cc"],
    hdrs = ["array_variable_store_builder.h"],
    deps = [
        ":alignment",
        ":array_variable_store",
        ":variable_store_wrappers",
        "//dragnn/protos:runtime_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Unit test for :array_variable_store_builder.
cc_test(
    name = "array_variable_store_builder_test",
    size = "small",
    srcs = ["array_variable_store_builder_test.cc"],
    data = [
        "testdata/array_variable_store_data",
        "testdata/array_variable_store_spec",
    ],
    deps = [
        ":alignment",
        ":array_variable_store_builder",
        "//dragnn/core/test:generic",
        "//dragnn/protos:runtime_proto_cc",
        "//dragnn/runtime/test:helpers",
        "//syntaxnet:base",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Tested in array_variable_store_test.
cc_library(
    name = "file_array_variable_store",
    srcs = ["file_array_variable_store.cc"],
    hdrs = ["file_array_variable_store.h"],
    deps = [
        ":alignment",
        ":array_variable_store",
        "//dragnn/protos:runtime_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Tested in array_variable_store_test.
cc_library(
    name = "mmap_array_variable_store",
    srcs = ["mmap_array_variable_store.cc"],
    hdrs = ["mmap_array_variable_store.h"],
    deps = [
        ":array_variable_store",
        ":mmap",
        "//dragnn/protos:runtime_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)
# Per-component network state storage.
cc_library(
    name = "network_states",
    srcs = ["network_states.cc"],
    hdrs = ["network_states.h"],
    deps = [
        ":alignment",
        ":operands",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Unit test for :network_states.
cc_test(
    name = "network_states_test",
    size = "small",
    srcs = ["network_states_test.cc"],
    deps = [
        ":alignment",
        ":network_states",
        "//dragnn/core/test:generic",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Generic extension mechanism.
cc_library(
    name = "extensions",
    srcs = ["extensions.cc"],
    hdrs = ["extensions.h"],
    deps = [
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Unit test for :extensions.
cc_test(
    name = "extensions_test",
    size = "small",
    srcs = ["extensions_test.cc"],
    deps = [
        ":extensions",
        "//syntaxnet:base",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Linked (recurrent) embedding extraction; built per-architecture with
# fast-math copts.
dragnn_cc_multiarch_library(
    name = "linked_embeddings",
    srcs = ["linked_embeddings.cc"],
    hdrs = ["linked_embeddings.h"],
    copts = FAST_MATH_COPTS,
    opts_self = True,
    deps = [
        ":alignment",
        ":flexible_matrix_kernel",
        ":network_states",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/protos:data_proto_cc",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:arithmetic",
        "//dragnn/runtime/math:types",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Multi-architecture test for :linked_embeddings.
dragnn_cc_multiarch_test(
    name = "linked_embeddings_test",
    size = "small",
    srcs = ["linked_embeddings_test.cc"],
    deps = [
        ":linked_embeddings",
        ":network_states",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/test:network_test_base",
        "//syntaxnet:base",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Fixed-feature embedding extraction; built per-architecture with
# fast-math copts.
dragnn_cc_multiarch_library(
    name = "fixed_embeddings",
    srcs = ["fixed_embeddings.cc"],
    hdrs = ["fixed_embeddings.h"],
    copts = FAST_MATH_COPTS,
    opts_self = True,
    deps = [
        ":alignment",
        ":network_states",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:arithmetic",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Multi-architecture test for :fixed_embeddings.
dragnn_cc_multiarch_test(
    name = "fixed_embeddings_test",
    size = "small",
    srcs = ["fixed_embeddings_test.cc"],
    deps = [
        ":fixed_embeddings",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/test:network_test_base",
        "//syntaxnet:base",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Header-only container keyed by C++ type.
cc_library(
    name = "type_keyed_set",
    hdrs = ["type_keyed_set.h"],
)

# Unit test for :type_keyed_set.
cc_test(
    name = "type_keyed_set_test",
    size = "small",
    srcs = ["type_keyed_set_test.cc"],
    deps = [
        ":type_keyed_set",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Per-session state: extensions plus network states (header-only).
cc_library(
    name = "session_state",
    hdrs = ["session_state.h"],
    deps = [
        ":extensions",
        ":network_states",
    ],
)

# Pool of reusable SessionState instances.
cc_library(
    name = "session_state_pool",
    srcs = ["session_state_pool.cc"],
    hdrs = ["session_state_pool.h"],
    deps = [
        ":session_state",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Unit test for :session_state_pool.
cc_test(
    name = "session_state_pool_test",
    size = "small",
    srcs = ["session_state_pool_test.cc"],
    deps = [
        ":session_state",
        ":session_state_pool",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# Bulk (whole-batch) dynamic component; alwayslink so the component
# registers itself at load time.
dragnn_cc_multiarch_library(
    name = "bulk_dynamic_component",
    srcs = ["bulk_dynamic_component.cc"],
    deps = [
        ":bulk_network_unit",
        ":component",
        ":extensions",
        ":network_states",
        ":network_unit_base",
        ":transition_system_traits",
        ":variable_store",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/protos:trace_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)

# Multi-architecture test for :bulk_dynamic_component.
dragnn_cc_multiarch_test(
    name = "bulk_dynamic_component_test",
    srcs = ["bulk_dynamic_component_test.cc"],
    deps = [
        ":bulk_dynamic_component",
        ":bulk_network_unit",
        ":component",
        ":extensions",
        ":network_states",
        ":variable_store",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/test:network_test_base",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Sequence-based variant of the bulk dynamic component; alwayslink for
# registration.
dragnn_cc_multiarch_library(
    name = "sequence_bulk_dynamic_component",
    srcs = ["sequence_bulk_dynamic_component.cc"],
    deps = [
        ":bulk_network_unit",
        ":component",
        ":extensions",
        ":fixed_embeddings",
        ":linked_embeddings",
        ":network_states",
        ":sequence_model",
        ":session_state",
        ":variable_store",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/protos:trace_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)

# Multi-architecture test for :sequence_bulk_dynamic_component.
dragnn_cc_multiarch_test(
    name = "sequence_bulk_dynamic_component_test",
    srcs = ["sequence_bulk_dynamic_component_test.cc"],
    deps = [
        ":bulk_network_unit",
        ":component",
        ":extensions",
        ":network_states",
        ":sequence_backend",
        ":sequence_bulk_dynamic_component",
        ":sequence_extractor",
        ":sequence_linker",
        ":sequence_predictor",
        ":variable_store",
        "//dragnn/core:input_batch_cache",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//dragnn/runtime/test:network_test_base",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Abstract Component interface plus its registry.
cc_library(
    name = "component",
    srcs = ["component.cc"],
    hdrs = ["component.h"],
    deps = [
        ":extensions",
        ":network_states",
        ":session_state",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/protos:trace_proto_cc",
        "//syntaxnet:base",
        "//syntaxnet:registry",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Unit test for :component.
cc_test(
    name = "component_test",
    size = "small",
    srcs = ["component_test.cc"],
    deps = [
        ":component",
        ":extensions",
        ":network_states",
        ":session_state",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/protos:trace_proto_cc",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Shared LSTM computation kernel; fast-math copts, alwayslink for
# registration.
dragnn_cc_multiarch_library(
    name = "lstm_network_kernel",
    srcs = ["lstm_network_kernel.cc"],
    hdrs = ["lstm_network_kernel.h"],
    copts = FAST_MATH_COPTS,
    opts_self = True,
    deps = [
        ":attributes",
        ":extensions",
        ":feed_forward_network_layer",
        ":flexible_matrix_kernel",
        ":network_states",
        ":session_state",
        ":transition_system_traits",
        ":variable_store",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/lstm_cell:cell_function",
        "//dragnn/runtime/math:avx_activation_functions",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)

# Multi-architecture test for :lstm_network_kernel.
dragnn_cc_multiarch_test(
    name = "lstm_network_kernel_test",
    srcs = ["lstm_network_kernel_test.cc"],
    deps = [
        ":lstm_network_kernel",
        ":variable_store",
        "//dragnn/core/test:generic",
        "//dragnn/protos:runtime_proto_cc",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/lstm_cell:cell_function",
        "//dragnn/runtime/test:helpers",
        "//dragnn/runtime/test:network_test_base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Step-wise LSTM network unit built on :lstm_network_kernel.
dragnn_cc_multiarch_library(
    name = "lstm_network",
    srcs = ["lstm_network.cc"],
    copts = FAST_MATH_COPTS,
    opts_self = True,
    deps = [
        ":extensions",
        ":lstm_network_kernel",
        ":network_unit",
        ":network_unit_base",
        ":transition_system_traits",
        ":variable_store",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/lstm_cell:cell_function",
        "//dragnn/runtime/math:avx_activation_functions",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)

# Multi-architecture test for :lstm_network.
dragnn_cc_multiarch_test(
    name = "lstm_network_test",
    srcs = ["lstm_network_test.cc"],
    deps = [
        ":flexible_matrix_kernel",
        ":lstm_network",
        ":network_unit",
        ":variable_store",
        "//dragnn/core/test:generic",
        "//dragnn/protos:runtime_proto_cc",
        "//dragnn/runtime/lstm_cell:cell_function",
        "//dragnn/runtime/test:network_test_base",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Bulk (whole-batch) LSTM network unit built on :lstm_network_kernel.
dragnn_cc_multiarch_library(
    name = "bulk_lstm_network",
    srcs = ["bulk_lstm_network.cc"],
    copts = FAST_MATH_COPTS,
    opts_self = True,
    deps = [
        ":bulk_network_unit",
        ":extensions",
        ":lstm_network_kernel",
        ":network_states",
        ":session_state",
        ":variable_store",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)

# Multi-architecture test for :bulk_lstm_network.
dragnn_cc_multiarch_test(
    name = "bulk_lstm_network_test",
    srcs = ["bulk_lstm_network_test.cc"],
    deps = [
        ":bulk_lstm_network",
        ":bulk_network_unit",
        ":flexible_matrix_kernel",
        "//dragnn/core/test:generic",
        "//dragnn/protos:runtime_proto_cc",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/lstm_cell:cell_function",
        "//dragnn/runtime/test:helpers",
        "//dragnn/runtime/test:network_test_base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# Top-level runtime driver that owns components and session state.
cc_library(
    name = "master",
    srcs = ["master.cc"],
    hdrs = ["master.h"],
    deps = [
        ":component",
        ":extensions",
        ":network_states",
        ":session_state",
        ":session_state_pool",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/protos:runtime_proto_cc",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/protos:trace_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Unit test for :master.
cc_test(
    name = "master_test",
    size = "small",
    srcs = ["master_test.cc"],
    deps = [
        ":alignment",
        ":component",
        ":extensions",
        ":master",
        ":network_states",
        ":session_state",
        ":variable_store",
        "//dragnn/core/test:generic",
        "//dragnn/core/test:mock_compute_session",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/protos:trace_proto_cc",
        "//dragnn/runtime/test:fake_variable_store",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Abstract step-wise network unit interface plus its registry.
cc_library(
    name = "network_unit",
    srcs = ["network_unit.cc"],
    hdrs = ["network_unit.h"],
    deps = [
        ":extensions",
        ":network_states",
        ":session_state",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "//syntaxnet:registry",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Unit test for :network_unit.
cc_test(
    name = "network_unit_test",
    size = "small",
    srcs = ["network_unit_test.cc"],
    deps = [
        ":extensions",
        ":network_states",
        ":network_unit",
        ":session_state",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Abstract bulk network unit interface plus its registry.
cc_library(
    name = "bulk_network_unit",
    srcs = ["bulk_network_unit.cc"],
    hdrs = ["bulk_network_unit.h"],
    deps = [
        ":extensions",
        ":network_states",
        ":network_unit",
        ":session_state",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "//syntaxnet:registry",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Unit test for :bulk_network_unit.
cc_test(
    name = "bulk_network_unit_test",
    size = "small",
    srcs = ["bulk_network_unit_test.cc"],
    deps = [
        ":bulk_network_unit",
        ":extensions",
        ":network_states",
        ":session_state",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Step-wise dynamic component; alwayslink so it registers itself.
cc_library(
    name = "dynamic_component",
    srcs = ["dynamic_component.cc"],
    deps = [
        ":component",
        ":extensions",
        ":network_states",
        ":network_unit",
        ":session_state",
        ":transition_system_traits",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/protos:trace_proto_cc",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)

# Unit test for :dynamic_component.
cc_test(
    name = "dynamic_component_test",
    size = "small",
    srcs = ["dynamic_component_test.cc"],
    deps = [
        ":component",
        ":dynamic_component",
        ":extensions",
        ":network_states",
        ":network_unit",
        ":session_state",
        ":variable_store",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//dragnn/runtime/test:network_test_base",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Shared base implementation for network units (embedding handling).
dragnn_cc_multiarch_library(
    name = "network_unit_base",
    srcs = ["network_unit_base.cc"],
    hdrs = ["network_unit_base.h"],
    deps = [
        ":extensions",
        ":fixed_embeddings",
        ":linked_embeddings",
        ":network_states",
        ":network_unit",
        ":session_state",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Multi-architecture test for :network_unit_base.
dragnn_cc_multiarch_test(
    name = "network_unit_base_test",
    size = "small",
    srcs = ["network_unit_base_test.cc"],
    deps = [
        ":extensions",
        ":fixed_embeddings",
        ":linked_embeddings",
        ":network_states",
        ":network_unit_base",
        ":session_state",
        ":variable_store",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//dragnn/runtime/test:network_test_base",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Attribute parsing helpers.
cc_library(
    name = "attributes",
    srcs = ["attributes.cc"],
    hdrs = ["attributes.h"],
    deps = [
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Unit test for :attributes.
cc_test(
    name = "attributes_test",
    size = "small",
    srcs = ["attributes_test.cc"],
    deps = [
        ":attributes",
        "//dragnn/core/test:generic",
        "//syntaxnet:base",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# Header-only activation function implementations.
cc_library(
    name = "activation_functions",
    hdrs = ["activation_functions.h"],
    deps = [
        "//dragnn/runtime/math:arithmetic",
        "//dragnn/runtime/math:types",
    ],
)

# Unit test for :activation_functions.
cc_test(
    name = "activation_functions_test",
    size = "small",
    srcs = ["activation_functions_test.cc"],
    deps = [
        ":activation_functions",
        "//dragnn/runtime/math:types",
        "//dragnn/runtime/test:helpers",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Matrix-vector kernel that selects an implementation per matrix format.
cc_library(
    name = "flexible_matrix_kernel",
    srcs = ["flexible_matrix_kernel.cc"],
    hdrs = ["flexible_matrix_kernel.h"],
    deps = [
        ":alignment",
        ":variable_store",
        "//dragnn/runtime/math:arithmetic",
        "//dragnn/runtime/math:avx_vector_array",
        "//dragnn/runtime/math:sgemvv",
        "//dragnn/runtime/math:types",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Multi-architecture test for :flexible_matrix_kernel; fast-math copts to
# match the library build.
dragnn_cc_multiarch_test(
    name = "flexible_matrix_kernel_test",
    srcs = ["flexible_matrix_kernel_test.cc"],
    copts = FAST_MATH_COPTS,
    deps = [
        ":flexible_matrix_kernel",
        "//dragnn/core/test:generic",
        "//dragnn/protos:runtime_proto_cc",
        "//dragnn/runtime/math:transformations",
        "//dragnn/runtime/test:fake_variable_store",
        "//dragnn/runtime/test:helpers",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Single feed-forward layer; fast-math copts, alwayslink for registration.
dragnn_cc_multiarch_library(
    name = "feed_forward_network_layer",
    srcs = ["feed_forward_network_layer.cc"],
    hdrs = ["feed_forward_network_layer.h"],
    copts = FAST_MATH_COPTS,
    opts_self = True,
    deps = [
        ":activation_functions",
        ":flexible_matrix_kernel",
        ":network_states",
        ":variable_store",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)

# Multi-architecture test for :feed_forward_network_layer.
dragnn_cc_multiarch_test(
    name = "feed_forward_network_layer_test",
    size = "small",
    srcs = ["feed_forward_network_layer_test.cc"],
    deps = [
        ":activation_functions",
        ":feed_forward_network_layer",
        ":flexible_matrix_kernel",
        ":network_states",
        ":variable_store",
        "//dragnn/core/test:generic",
        "//dragnn/runtime/math:types",
        "//dragnn/runtime/test:helpers",
        "//dragnn/runtime/test:network_test_base",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Shared multi-layer feed-forward kernel; fast-math copts, alwayslink.
dragnn_cc_multiarch_library(
    name = "feed_forward_network_kernel",
    srcs = ["feed_forward_network_kernel.cc"],
    hdrs = ["feed_forward_network_kernel.h"],
    copts = FAST_MATH_COPTS,
    opts_self = True,
    deps = [
        ":activation_functions",
        ":attributes",
        ":feed_forward_network_layer",
        ":flexible_matrix_kernel",
        ":network_states",
        ":transition_system_traits",
        ":variable_store",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)

# Multi-architecture test for :feed_forward_network_kernel.
dragnn_cc_multiarch_test(
    name = "feed_forward_network_kernel_test",
    size = "small",
    srcs = ["feed_forward_network_kernel_test.cc"],
    deps = [
        ":activation_functions",
        ":feed_forward_network_kernel",
        ":flexible_matrix_kernel",
        ":network_states",
        ":variable_store",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//dragnn/runtime/test:network_test_base",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Step-wise feed-forward network unit; fast-math copts, alwayslink.
dragnn_cc_multiarch_library(
    name = "feed_forward_network",
    srcs = ["feed_forward_network.cc"],
    copts = FAST_MATH_COPTS,
    opts_self = True,
    deps = [
        ":extensions",
        ":feed_forward_network_kernel",
        ":feed_forward_network_layer",
        ":network_states",
        ":network_unit",
        ":network_unit_base",
        ":session_state",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)

# Multi-architecture test for :feed_forward_network.
dragnn_cc_multiarch_test(
    name = "feed_forward_network_test",
    size = "small",
    srcs = ["feed_forward_network_test.cc"],
    deps = [
        ":dynamic_component",
        ":feed_forward_network",
        ":flexible_matrix_kernel",
        ":network_states",
        ":network_unit",
        ":variable_store",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//dragnn/runtime/test:network_test_base",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Bulk (whole-batch) feed-forward network unit; fast-math copts, alwayslink.
dragnn_cc_multiarch_library(
    name = "bulk_feed_forward_network",
    srcs = ["bulk_feed_forward_network.cc"],
    copts = FAST_MATH_COPTS,
    opts_self = True,
    deps = [
        ":bulk_network_unit",
        ":extensions",
        ":feed_forward_network_kernel",
        ":feed_forward_network_layer",
        ":network_states",
        ":session_state",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)

# Multi-architecture test for :bulk_feed_forward_network.
dragnn_cc_multiarch_test(
    name = "bulk_feed_forward_network_test",
    size = "small",
    srcs = ["bulk_feed_forward_network_test.cc"],
    deps = [
        ":bulk_feed_forward_network",
        ":bulk_network_unit",
        ":dynamic_component",
        ":flexible_matrix_kernel",
        ":network_states",
        ":variable_store",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//dragnn/runtime/test:network_test_base",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# Converts a trained model into the array-variable-store format.
cc_library(
    name = "conversion",
    srcs = ["conversion.cc"],
    hdrs = ["conversion.h"],
    deps = [
        ":array_variable_store_builder",
        ":master",
        ":trained_model_variable_store",
        ":variable_store",
        ":variable_store_wrappers",
        "//dragnn/protos:runtime_proto_cc",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Multi-architecture test for :conversion; loads the RNN tagger fixture
# and golden conversion outputs; sharded for speed.
dragnn_cc_multiarch_test(
    name = "conversion_test",
    size = "small",
    timeout = "moderate",
    srcs = ["conversion_test.cc"],
    data = [
        "testdata/conversion_output_variables_data",
        "testdata/conversion_output_variables_spec",
        ":test_rnn_tagger",
    ],
    shard_count = 6,
    deps = [
        ":conversion",
        ":dynamic_component",
        ":feed_forward_network",
        ":lstm_network",
        "//dragnn/components/syntaxnet:syntaxnet_component",
        "//dragnn/core/test:generic",
        "//dragnn/protos:runtime_proto_cc",
        "//syntaxnet:base",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Registry-based transformations over component specs.
cc_library(
    name = "component_transformation",
    srcs = ["component_transformation.cc"],
    hdrs = ["component_transformation.h"],
    deps = [
        ":component",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "//syntaxnet:registry",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Unit test for :component_transformation.
cc_test(
    name = "component_transformation_test",
    size = "small",
    srcs = ["component_transformation_test.cc"],
    deps = [
        ":component_transformation",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Helpers for parsing FML feature specifications.
cc_library(
    name = "fml_parsing",
    srcs = ["fml_parsing.cc"],
    hdrs = ["fml_parsing.h"],
    deps = [
        ":attributes",
        "//syntaxnet:base",
        "//syntaxnet:feature_extractor_proto_cc",
        "//syntaxnet:fml_parser",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Unit test for :fml_parsing.
cc_test(
    name = "fml_parsing_test",
    size = "small",
    srcs = ["fml_parsing_test.cc"],
    deps = [
        ":fml_parsing",
        "//dragnn/core/test:generic",
        "//syntaxnet:base",
        "//syntaxnet:feature_extractor_proto_cc",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Utilities for working with term-map resources in component specs.
cc_library(
    name = "term_map_utils",
    srcs = ["term_map_utils.cc"],
    hdrs = ["term_map_utils.h"],
    deps = [
        ":fml_parsing",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "//syntaxnet:feature_extractor_proto_cc",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Unit test for :term_map_utils.
cc_test(
    name = "term_map_utils_test",
    size = "small",
    srcs = ["term_map_utils_test.cc"],
    deps = [
        ":term_map_utils",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/test:term_map_helpers",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Derives traits of a transition system from its component spec.
cc_library(
    name = "transition_system_traits",
    srcs = ["transition_system_traits.cc"],
    hdrs = ["transition_system_traits.h"],
    deps = [
        ":attributes",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Unit test for :transition_system_traits.
cc_test(
    name = "transition_system_traits_test",
    size = "small",
    srcs = ["transition_system_traits_test.cc"],
    deps = [
        ":transition_system_traits",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Dictionary keyed by unicode characters, built from a term-frequency map.
cc_library(
    name = "unicode_dictionary",
    srcs = ["unicode_dictionary.cc"],
    hdrs = ["unicode_dictionary.h"],
    deps = [
        "//syntaxnet:base",
        "//syntaxnet:term_frequency_map",
        "//util/utf8:unicodetext",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Unit test for :unicode_dictionary.
cc_test(
    name = "unicode_dictionary_test",
    size = "small",
    timeout = "moderate",
    srcs = ["unicode_dictionary_test.cc"],
    deps = [
        ":unicode_dictionary",
        "//dragnn/core/test:generic",
        "//dragnn/runtime/test:term_map_helpers",
        "//syntaxnet:base",
        "//syntaxnet:term_frequency_map",
        "//third_party/utf",
        "//util/utf8:unicodetext",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# Abstract sequence-extractor interface plus its registry.
cc_library(
    name = "sequence_extractor",
    srcs = ["sequence_extractor.cc"],
    hdrs = ["sequence_extractor.h"],
    deps = [
        "//dragnn/core:input_batch_cache",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "//syntaxnet:registry",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Unit test for :sequence_extractor.
cc_test(
    name = "sequence_extractor_test",
    size = "small",
    srcs = ["sequence_extractor_test.cc"],
    deps = [
        ":sequence_extractor",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Base class for sequence extractors backed by a shared term map
# (header-only).
cc_library(
    name = "term_map_sequence_extractor",
    hdrs = ["term_map_sequence_extractor.h"],
    deps = [
        ":sequence_extractor",
        ":term_map_utils",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "//syntaxnet:shared_store",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)

# Unit test for :term_map_sequence_extractor.
cc_test(
    name = "term_map_sequence_extractor_test",
    size = "small",
    srcs = ["term_map_sequence_extractor_test.cc"],
    deps = [
        ":term_map_sequence_extractor",
        "//dragnn/core:input_batch_cache",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/test:term_map_helpers",
        "//syntaxnet:base",
        "//syntaxnet:term_frequency_map",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Character-based sequence extractor for SyntaxNet sentences; alwayslink
# so it registers itself.
cc_library(
    name = "syntaxnet_character_sequence_extractor",
    srcs = ["syntaxnet_character_sequence_extractor.cc"],
    deps = [
        ":sequence_extractor",
        ":term_map_sequence_extractor",
        ":term_map_utils",
        ":transition_system_traits",
        ":unicode_dictionary",
        "//dragnn/core:input_batch_cache",
        "//dragnn/io:sentence_input_batch",
        "//dragnn/io:syntaxnet_sentence",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "//syntaxnet:segmenter_utils",
        "//util/utf8:unicodetext",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)

# Unit test for :syntaxnet_character_sequence_extractor.
cc_test(
    name = "syntaxnet_character_sequence_extractor_test",
    size = "small",
    srcs = ["syntaxnet_character_sequence_extractor_test.cc"],
    deps = [
        ":sequence_extractor",
        ":syntaxnet_character_sequence_extractor",
        "//dragnn/core:input_batch_cache",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/test:term_map_helpers",
        "//syntaxnet:base",
        "//syntaxnet:sentence_proto_cc",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Word-based sequence extractor for SyntaxNet sentences; alwayslink so it
# registers itself.
cc_library(
    name = "syntaxnet_word_sequence_extractor",
    srcs = ["syntaxnet_word_sequence_extractor.cc"],
    deps = [
        ":sequence_extractor",
        ":term_map_sequence_extractor",
        ":term_map_utils",
        ":transition_system_traits",
        "//dragnn/core:input_batch_cache",
        "//dragnn/io:sentence_input_batch",
        "//dragnn/io:syntaxnet_sentence",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "//syntaxnet:term_frequency_map",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)

# Unit test for :syntaxnet_word_sequence_extractor.
cc_test(
    name = "syntaxnet_word_sequence_extractor_test",
    size = "small",
    srcs = ["syntaxnet_word_sequence_extractor_test.cc"],
    deps = [
        ":sequence_extractor",
        ":syntaxnet_word_sequence_extractor",
        "//dragnn/core:input_batch_cache",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/test:term_map_helpers",
        "//syntaxnet:base",
        "//syntaxnet:sentence_proto_cc",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)

# Fixed-feature extraction over sequences, built on :sequence_extractor.
dragnn_cc_multiarch_library(
    name = "sequence_features",
    srcs = ["sequence_features.cc"],
    hdrs = ["sequence_features.h"],
    deps = [
        ":alignment",
        ":fixed_embeddings",
        ":sequence_extractor",
        "//dragnn/core:input_batch_cache",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)
dragnn_cc_multiarch_test(
name = "sequence_features_test",
size = "small",
srcs = ["sequence_features_test.cc"],
deps = [
":fixed_embeddings",
":sequence_extractor",
":sequence_features",
"//dragnn/core:input_batch_cache",
"//dragnn/core/test:generic",
"//dragnn/protos:spec_proto_cc",
"//dragnn/runtime/test:network_test_base",
"//syntaxnet:base",
"@org_tensorflow//tensorflow/core:lib",
"@org_tensorflow//tensorflow/core:test",
],
)
# --- Sequence linking ---
# The SequenceLinker interface and its concrete implementations (identity,
# reversed, recurrent, character-based), plus the SequenceLinks wrapper.
cc_library(
    name = "sequence_linker",
    srcs = ["sequence_linker.cc"],
    hdrs = ["sequence_linker.h"],
    deps = [
        "//dragnn/core:input_batch_cache",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "//syntaxnet:registry",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)
cc_test(
    name = "sequence_linker_test",
    size = "small",
    srcs = ["sequence_linker_test.cc"],
    deps = [
        ":sequence_linker",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# alwayslink = 1 on the implementations below keeps each self-registering
# linker linked in even when no target references its symbols directly.
cc_library(
    name = "identity_sequence_linker",
    srcs = ["identity_sequence_linker.cc"],
    deps = [
        ":sequence_linker",
        ":transition_system_traits",
        "//dragnn/core:input_batch_cache",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)
cc_test(
    name = "identity_sequence_linker_test",
    size = "small",
    srcs = ["identity_sequence_linker_test.cc"],
    deps = [
        ":identity_sequence_linker",
        ":sequence_linker",
        "//dragnn/core:input_batch_cache",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
cc_library(
    name = "reversed_sequence_linker",
    srcs = ["reversed_sequence_linker.cc"],
    deps = [
        ":sequence_linker",
        ":transition_system_traits",
        "//dragnn/core:input_batch_cache",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)
cc_test(
    name = "reversed_sequence_linker_test",
    size = "small",
    srcs = ["reversed_sequence_linker_test.cc"],
    deps = [
        ":reversed_sequence_linker",
        ":sequence_linker",
        "//dragnn/core:input_batch_cache",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
cc_library(
    name = "recurrent_sequence_linkers",
    srcs = ["recurrent_sequence_linkers.cc"],
    deps = [
        ":sequence_linker",
        ":transition_system_traits",
        "//dragnn/core:input_batch_cache",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)
cc_test(
    name = "recurrent_sequence_linkers_test",
    size = "small",
    srcs = ["recurrent_sequence_linkers_test.cc"],
    deps = [
        ":recurrent_sequence_linkers",
        ":sequence_linker",
        "//dragnn/core:input_batch_cache",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
cc_library(
    name = "syntaxnet_character_sequence_linkers",
    srcs = ["syntaxnet_character_sequence_linkers.cc"],
    deps = [
        ":sequence_linker",
        ":transition_system_traits",
        "//dragnn/core:input_batch_cache",
        "//dragnn/io:sentence_input_batch",
        "//dragnn/io:syntaxnet_sentence",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "//util/utf8:unicodetext",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)
cc_test(
    name = "syntaxnet_character_sequence_linkers_test",
    size = "small",
    srcs = ["syntaxnet_character_sequence_linkers_test.cc"],
    deps = [
        ":sequence_linker",
        ":syntaxnet_character_sequence_linkers",
        "//dragnn/core:input_batch_cache",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "//syntaxnet:sentence_proto_cc",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
dragnn_cc_multiarch_library(
    name = "sequence_links",
    srcs = ["sequence_links.cc"],
    hdrs = ["sequence_links.h"],
    deps = [
        ":alignment",
        ":linked_embeddings",
        ":network_states",
        ":sequence_linker",
        "//dragnn/core:input_batch_cache",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)
dragnn_cc_multiarch_test(
    name = "sequence_links_test",
    size = "small",
    srcs = ["sequence_links_test.cc"],
    deps = [
        ":linked_embeddings",
        ":network_states",
        ":sequence_linker",
        ":sequence_links",
        "//dragnn/core:input_batch_cache",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/test:network_test_base",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# --- Sequence prediction and backend ---
# The SequencePredictor interface, term-map-backed predictors, and the
# sequence component backend.
cc_library(
    name = "sequence_predictor",
    srcs = ["sequence_predictor.cc"],
    hdrs = ["sequence_predictor.h"],
    deps = [
        "//dragnn/core:input_batch_cache",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "//syntaxnet:registry",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)
cc_test(
    name = "sequence_predictor_test",
    size = "small",
    srcs = ["sequence_predictor_test.cc"],
    deps = [
        ":sequence_predictor",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
cc_library(
    name = "term_map_sequence_predictor",
    srcs = ["term_map_sequence_predictor.cc"],
    hdrs = ["term_map_sequence_predictor.h"],
    deps = [
        ":sequence_predictor",
        ":term_map_utils",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "//syntaxnet:shared_store",
        "//syntaxnet:term_frequency_map",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)
cc_test(
    name = "term_map_sequence_predictor_test",
    size = "small",
    srcs = ["term_map_sequence_predictor_test.cc"],
    deps = [
        ":term_map_sequence_predictor",
        "//dragnn/core:input_batch_cache",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//dragnn/runtime/test:term_map_helpers",
        "//syntaxnet:base",
        "//syntaxnet:test_main",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# alwayslink = 1 keeps this self-registering predictor linked in even when no
# target references its symbols directly.
cc_library(
    name = "syntaxnet_tag_sequence_predictor",
    srcs = ["syntaxnet_tag_sequence_predictor.cc"],
    deps = [
        ":sequence_predictor",
        ":term_map_sequence_predictor",
        ":transition_system_traits",
        "//dragnn/core:input_batch_cache",
        "//dragnn/io:sentence_input_batch",
        "//dragnn/io:syntaxnet_sentence",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "//syntaxnet:sentence_proto_cc",
        "//syntaxnet:term_frequency_map",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)
cc_test(
    name = "syntaxnet_tag_sequence_predictor_test",
    size = "small",
    srcs = ["syntaxnet_tag_sequence_predictor_test.cc"],
    deps = [
        ":alignment",
        ":sequence_predictor",
        ":syntaxnet_tag_sequence_predictor",
        "//dragnn/core:input_batch_cache",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/test:helpers",
        "//dragnn/runtime/test:term_map_helpers",
        "//syntaxnet:base",
        "//syntaxnet:sentence_proto_cc",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
cc_library(
    name = "sequence_backend",
    srcs = ["sequence_backend.cc"],
    hdrs = ["sequence_backend.h"],
    deps = [
        "//dragnn/core:component_registry",
        "//dragnn/core:input_batch_cache",
        "//dragnn/core/interfaces:component",
        "//dragnn/core/interfaces:transition_state",
        "//dragnn/core/util:label",
        "//dragnn/protos:data_proto_cc",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/protos:trace_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)
cc_test(
    name = "sequence_backend_test",
    size = "small",
    srcs = ["sequence_backend_test.cc"],
    deps = [
        ":sequence_backend",
        "//dragnn/components/util:bulk_feature_extractor",
        "//dragnn/core:input_batch_cache",
        "//dragnn/core/interfaces:transition_state",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/protos:trace_proto_cc",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# --- Component transformers ---
# Spec-rewriting passes applied during model conversion; each registers itself
# via alwayslink = 1 so it is linked in even when unreferenced.
cc_library(
    name = "select_best_component_transformer",
    srcs = ["select_best_component_transformer.cc"],
    deps = [
        ":component",
        ":component_transformation",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)
cc_test(
    name = "select_best_component_transformer_test",
    size = "small",
    srcs = ["select_best_component_transformer_test.cc"],
    deps = [
        ":component",
        ":component_transformation",
        ":extensions",
        ":select_best_component_transformer",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
cc_library(
    name = "sequence_component_transformer",
    srcs = ["sequence_component_transformer.cc"],
    deps = [
        ":component_transformation",
        ":sequence_extractor",
        ":sequence_linker",
        ":sequence_predictor",
        ":transition_system_traits",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)
cc_test(
    name = "sequence_component_transformer_test",
    size = "small",
    srcs = ["sequence_component_transformer_test.cc"],
    deps = [
        ":component_transformation",
        ":sequence_component_transformer",
        ":sequence_extractor",
        ":sequence_linker",
        ":sequence_predictor",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
cc_library(
    name = "stateless_component_transformer",
    srcs = ["stateless_component_transformer.cc"],
    deps = [
        ":component_transformation",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)
cc_test(
    name = "stateless_component_transformer_test",
    size = "small",
    srcs = ["stateless_component_transformer_test.cc"],
    deps = [
        ":component_transformation",
        ":stateless_component_transformer",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
cc_library(
    name = "clear_dropout_component_transformer",
    srcs = ["clear_dropout_component_transformer.cc"],
    deps = [
        ":component_transformation",
        "//dragnn/protos:spec_proto_cc",
        "//syntaxnet:base",
        "//syntaxnet:feature_extractor_proto_cc",
        "//syntaxnet:fml_parser",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)
cc_test(
    name = "clear_dropout_component_transformer_test",
    size = "small",
    srcs = ["clear_dropout_component_transformer_test.cc"],
    deps = [
        ":clear_dropout_component_transformer",
        ":component_transformation",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# --- Sequence model ---
# Ties together the sequence backend, features, links, and predictors; built
# per-architecture via the multiarch wrapper.
dragnn_cc_multiarch_library(
    name = "sequence_model",
    srcs = ["sequence_model.cc"],
    hdrs = ["sequence_model.h"],
    deps = [
        ":attributes",
        ":fixed_embeddings",
        ":linked_embeddings",
        ":network_states",
        ":sequence_backend",
        ":sequence_features",
        ":sequence_links",
        ":sequence_predictor",
        ":session_state",
        ":transition_system_traits",
        "//dragnn/core:compute_session",
        "//dragnn/core:input_batch_cache",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)
dragnn_cc_multiarch_test(
    name = "sequence_model_test",
    size = "small",
    srcs = ["sequence_model_test.cc"],
    deps = [
        ":fixed_embeddings",
        ":linked_embeddings",
        ":network_states",
        ":sequence_backend",
        ":sequence_extractor",
        ":sequence_linker",
        ":sequence_model",
        ":sequence_predictor",
        ":session_state",
        "//dragnn/core:input_batch_cache",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//dragnn/runtime/test:network_test_base",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# --- Parsing components ---
# Biaffine digraph scoring, head selection, and MST solving components.
# The biaffine component opts into fast-math compiler flags for its math.
dragnn_cc_multiarch_library(
    name = "biaffine_digraph_component",
    srcs = ["biaffine_digraph_component.cc"],
    copts = FAST_MATH_COPTS,
    deps = [
        ":component",
        ":extensions",
        ":network_states",
        ":network_unit",
        ":session_state",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/protos:trace_proto_cc",
        "//dragnn/runtime/math:eigen",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)
dragnn_cc_multiarch_test(
    name = "biaffine_digraph_component_test",
    size = "small",
    srcs = ["biaffine_digraph_component_test.cc"],
    deps = [
        ":biaffine_digraph_component",
        ":component",
        ":extensions",
        ":network_states",
        ":session_state",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//dragnn/runtime/test:network_test_base",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
cc_library(
    name = "head_selection_component_base",
    srcs = ["head_selection_component_base.cc"],
    hdrs = ["head_selection_component_base.h"],
    deps = [
        ":alignment",
        ":component",
        ":extensions",
        ":network_states",
        ":session_state",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/protos:trace_proto_cc",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)
cc_test(
    name = "head_selection_component_base_test",
    size = "small",
    srcs = ["head_selection_component_base_test.cc"],
    deps = [
        ":head_selection_component_base",
        ":network_states",
        ":session_state",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/core:input_batch_cache",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/protos:trace_proto_cc",
        "//dragnn/runtime/test:network_test_base",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# alwayslink = 1 keeps this self-registering component linked in even when no
# target references its symbols directly (likewise for the MST solver below).
cc_library(
    name = "syntaxnet_head_selection_component",
    srcs = ["syntaxnet_head_selection_component.cc"],
    deps = [
        ":head_selection_component_base",
        ":session_state",
        "//dragnn/core:compute_session",
        "//dragnn/core:input_batch_cache",
        "//dragnn/io:sentence_input_batch",
        "//dragnn/io:syntaxnet_sentence",
        "//dragnn/protos:trace_proto_cc",
        "//syntaxnet:base",
        "//syntaxnet:sentence_proto_cc",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)
cc_test(
    name = "syntaxnet_head_selection_component_test",
    size = "small",
    srcs = ["syntaxnet_head_selection_component_test.cc"],
    deps = [
        ":component",
        ":syntaxnet_head_selection_component",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/core:input_batch_cache",
        "//dragnn/core/test:generic",
        "//dragnn/io:sentence_input_batch",
        "//dragnn/io:syntaxnet_sentence",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//dragnn/runtime/test:network_test_base",
        "//syntaxnet:sentence_proto_cc",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
cc_library(
    name = "mst_solver_component_base",
    srcs = ["mst_solver_component_base.cc"],
    hdrs = ["mst_solver_component_base.h"],
    deps = [
        ":attributes",
        ":component",
        ":extensions",
        ":network_states",
        ":network_unit",
        ":session_state",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/mst:mst_solver",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
    ],
)
cc_test(
    name = "mst_solver_component_base_test",
    size = "small",
    srcs = ["mst_solver_component_base_test.cc"],
    deps = [
        ":mst_solver_component_base",
        ":network_states",
        ":session_state",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/core:input_batch_cache",
        "//dragnn/core/test:generic",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/protos:trace_proto_cc",
        "//dragnn/runtime/test:network_test_base",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
cc_library(
    name = "syntaxnet_mst_solver_component",
    srcs = ["syntaxnet_mst_solver_component.cc"],
    deps = [
        ":mst_solver_component_base",
        ":session_state",
        "//dragnn/core:compute_session",
        "//dragnn/core:input_batch_cache",
        "//dragnn/io:sentence_input_batch",
        "//dragnn/io:syntaxnet_sentence",
        "//dragnn/protos:trace_proto_cc",
        "//syntaxnet:base",
        "//syntaxnet:sentence_proto_cc",
        "@org_tensorflow//tensorflow/core:lib",
    ],
    alwayslink = 1,
)
cc_test(
    name = "syntaxnet_mst_solver_component_test",
    size = "small",
    srcs = ["syntaxnet_mst_solver_component_test.cc"],
    deps = [
        ":component",
        ":syntaxnet_mst_solver_component",
        ":variable_store",
        "//dragnn/core:compute_session",
        "//dragnn/core:input_batch_cache",
        "//dragnn/core/test:generic",
        "//dragnn/io:sentence_input_batch",
        "//dragnn/io:syntaxnet_sentence",
        "//dragnn/protos:spec_proto_cc",
        "//dragnn/runtime/math:types",
        "//dragnn/runtime/test:network_test_base",
        "//syntaxnet:sentence_proto_cc",
        "@org_tensorflow//tensorflow/core:lib",
        "@org_tensorflow//tensorflow/core:test",
    ],
)
# --- Model converter ---
# The converter binary links every registered component, network, linker,
# extractor, predictor, and transformer so conversion can handle any spec.
cc_library(
    name = "converter_main",
    srcs = ["converter.cc"],
    deps = [
        ":component_transformation",
        ":conversion",
        "//dragnn/runtime/myelin:myelination",
        "//dragnn/runtime/xla:xla_compilation",
        "//syntaxnet:base",
        "@org_tensorflow//tensorflow/core:lib",
        "@sling//sling/base",
    ],
)
dragnn_cc_multiarch_binary(
    name = "converter",
    target_arch = "generic",
    deps = [
        ":biaffine_digraph_component",
        ":bulk_dynamic_component",
        ":bulk_feed_forward_network",
        ":bulk_lstm_network",
        ":clear_dropout_component_transformer",
        ":converter_main",
        ":dynamic_component",
        ":feed_forward_network",
        ":identity_sequence_linker",
        ":lstm_network",
        ":recurrent_sequence_linkers",
        ":reversed_sequence_linker",
        ":select_best_component_transformer",
        ":sequence_backend",
        ":sequence_bulk_dynamic_component",
        ":sequence_component_transformer",
        ":stateless_component_transformer",
        ":syntaxnet_character_sequence_extractor",
        ":syntaxnet_character_sequence_linkers",
        ":syntaxnet_head_selection_component",
        ":syntaxnet_mst_solver_component",
        ":syntaxnet_tag_sequence_predictor",
        ":syntaxnet_word_sequence_extractor",
        "//dragnn/components/stateless:stateless_component",
        "//dragnn/components/syntaxnet:syntaxnet_component",
        "//dragnn/mst:mst_ops_cc",
        "//dragnn/runtime/myelin:myelin_dynamic_component",
        "//dragnn/runtime/myelin:sequence_myelin_dynamic_component",
        "//dragnn/runtime/xla:xla_dynamic_component",
        "//syntaxnet:parser_transitions",
    ],
)
# End-to-end shell test driving the converter binary over checked-in testdata.
sh_test(
    name = "converter_test",
    size = "medium",
    srcs = ["converter_test.sh"],
    data = [":converter"] + glob([
        "testdata/converter_output/**",
        "testdata/rnn_tagger/**",
    ]),
)
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Definitions of activation functions for neural networks.
#ifndef DRAGNN_RUNTIME_ACTIVATION_FUNCTIONS_H_
#define DRAGNN_RUNTIME_ACTIVATION_FUNCTIONS_H_
#include "dragnn/runtime/math/arithmetic.h"
#include "dragnn/runtime/math/types.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Possible types of activation functions.
//
// TODO(googleuser): If many activation functions are added, or if functions start
// using configuration parameters (e.g., leakiness of a leaky ReLU), then switch
// to a registered class.
enum class ActivationFunction {
  kIdentity,  // pass-through, useful for classification logits
  kRelu,      // ReLU; i.e., max(0,x)
};
// Applies the |activation_function| element-wise to |values|, modifying the
// vector in place.
template <class T>
void ApplyActivationFunction(ActivationFunction activation_function,
                             MutableVector<T> values);
// Implementation details below.
template <class T>
void ApplyActivationFunction(ActivationFunction activation_function,
                             MutableVector<T> values) {
  // kIdentity requires no work; kRelu clamps every element to be >= T(),
  // which for numeric T is max(0,x).
  if (activation_function == ActivationFunction::kRelu) {
    MaxElements(T(), values);
  }
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_ACTIVATION_FUNCTIONS_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/activation_functions.h"
#include "dragnn/runtime/math/types.h"
#include "dragnn/runtime/test/helpers.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// Tests that kIdentity is a pass-through.
// Tests that kIdentity is a pass-through.
TEST(ActivationFunctionsTest, ApplyIdentity) {
  // NB: every test value is exactly representable in float, so EXPECT_EQ is
  // safe here despite comparing floating-point values.
  UniqueVector<float> values({1.25f, -1.5f, 0.0f, 0.0625f, -0.03125});
  ApplyActivationFunction(ActivationFunction::kIdentity, *values);
  // All elements, positive and negative alike, must be unchanged.
  EXPECT_EQ((*values)[0], 1.25);
  EXPECT_EQ((*values)[1], -1.5);
  EXPECT_EQ((*values)[2], 0.0);
  EXPECT_EQ((*values)[3], 0.0625);
  EXPECT_EQ((*values)[4], -0.03125);
}
// Tests that kRelu clips to zero.
// Tests that kRelu clips to zero.
TEST(ActivationFunctionsTest, ApplyRelu) {
  // NB: every test value is exactly representable in float, so EXPECT_EQ is
  // safe here despite comparing floating-point values.
  UniqueVector<float> values({1.25f, -1.5f, 0.0f, 0.0625f, -0.03125});
  ApplyActivationFunction(ActivationFunction::kRelu, *values);
  // Positive elements pass through; negative elements are clipped to zero.
  EXPECT_EQ((*values)[0], 1.25);
  EXPECT_EQ((*values)[1], 0.0);  // clipped
  EXPECT_EQ((*values)[2], 0.0);  // boundary
  EXPECT_EQ((*values)[3], 0.0625);
  EXPECT_EQ((*values)[4], 0.0);  // clipped
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
// Utils for working with aligned memory blocks. The DRAGNN runtime requires
// aligned memory for use in vectorized math. Do not rely on any particular
// value of the alignment requirement, because it will vary over time and in
// different build configurations.
#ifndef DRAGNN_RUNTIME_ALIGNMENT_H_
#define DRAGNN_RUNTIME_ALIGNMENT_H_
#include <stddef.h>
#include <stdint.h>
#include <string.h>
#include <memory>
#include <type_traits>
#include <vector>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
// This is a type that has some private methods (so non-POD), but is known to be
// trivially destructible. Ergo we add some special handling so
// IsAlignable<bfloat16> returns true.
namespace tensorflow {
struct bfloat16;  // forward declaration; avoids pulling in the TF header here
}
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Returns true if |T| can be used in an aligned memory block.
template <class T>
constexpr bool IsAlignable();
// Returns OK iff the |pointer| satisfies the alignment requirement.
tensorflow::Status OkIfAligned(const void *pointer);
// Returns the next alignment boundary at or after the |byte_offset|.
size_t PadToAlignment(size_t byte_offset);
// As above, but for pointers.
template <class T>
T *PadToAlignment(T *pointer);
// Returns the number of bytes required to store a sequence of |num_arrays|
// aligned arrays of |array_size| bytes, including alignment padding. See
// (Mutable)AlignedArea below.
size_t ComputeAlignedAreaSize(size_t num_arrays, size_t array_size);
// Returns the number of bytes required to store a sequence of byte arrays of
// the given |sizes|, including alignment padding after each array.
size_t ComputeTotalBytesWithAlignmentPadding(const std::vector<size_t> &sizes);
// Forward-declared for friendship below (Operands and UniqueAlignedArray are
// friends of the view/area classes; BlockedMatrixFormat is declared here for
// use elsewhere in the runtime).
class Operands;
class UniqueAlignedArray;
enum class BlockedMatrixFormat;
namespace internal {
// A non-owning view of an aligned byte array. Templated so const and mutable
// versions can share implementation. Do not use this class directly, instead
// use (Mutable)AlignedView below.
template <class Byte>
class AlignedViewImpl {
 public:
  static_assert(sizeof(Byte) == 1, "Byte must be byte-sized");
  // Creates an empty view.
  AlignedViewImpl() = default;
  // Points this at the same bytes as |that|, possibly reinterpreting type
  // (e.g., converting a mutable view into a const view).
  template <class OtherByte>
  explicit AlignedViewImpl(AlignedViewImpl<OtherByte> that);
  template <class OtherByte>
  AlignedViewImpl &operator=(AlignedViewImpl<OtherByte> that);
  // Points this at [|data|,|data|+|size|). On error, returns non-OK and
  // modifies nothing.
  tensorflow::Status Reset(Byte *data, size_t size);
  // Splits this into a list of |views| of the |sizes|, possibly reinterpreting
  // type. The |views| need not completely cover all bytes of this. Requires
  // that this spans ComputeTotalBytesWithAlignmentPadding(|sizes|) bytes. On
  // error, returns non-OK and modifies nothing.
  template <class OtherByte>
  tensorflow::Status Split(
      const std::vector<size_t> &sizes,
      std::vector<AlignedViewImpl<OtherByte>> *views) const;
  // Accessors.
  Byte *data() const { return data_; }
  size_t size() const { return size_; }
  bool empty() const { return size() == 0; }
 private:
  template <class OtherByte>
  friend class AlignedViewImpl;
  template <class OtherByte>
  friend class AlignedAreaImpl;
  // These friends use the private unchecked constructor below.
  friend Operands;
  friend UniqueAlignedArray;
  // Directly creates an aligned view, bypassing alignment checks.
  AlignedViewImpl(Byte *data, size_t size);
  // Pointer to the start of the view.
  Byte *data_ = nullptr;
  // Number of bytes in the view.
  size_t size_ = 0;
};
// A non-owning view of an aligned, 2-dimensional byte array. Templated so
// const and mutable versions can share implementation. Do not use this class
// directly, instead use (Mutable)AlignedArea below.
template <class Byte>
class AlignedAreaImpl {
 public:
  static_assert(sizeof(Byte) == 1, "Byte must be byte-sized");
  // Creates an empty area.
  AlignedAreaImpl() = default;
  // Points this at the same bytes as |that|, possibly reinterpreting type
  // (e.g., converting a mutable area into a const area).
  template <class OtherByte>
  explicit AlignedAreaImpl(AlignedAreaImpl<OtherByte> that);
  template <class OtherByte>
  AlignedAreaImpl &operator=(AlignedAreaImpl<OtherByte> that);
  // Resets this to a sequence of |num_views| aligned sub-views of the |view|,
  // each |view_size| bytes wide. The first sub-view covers [0,|view_size|) of
  // |view|, and each subsequent sub-view starts at the next alignment boundary.
  // Requires that |view| spans ComputeAlignedAreaSize(|num_views|,|view_size|)
  // bytes or more. On error, returns non-OK and modifies nothing.
  template <class OtherByte>
  tensorflow::Status Reset(AlignedViewImpl<OtherByte> view, size_t num_views,
                           size_t view_size);
  // Accessors.
  AlignedViewImpl<Byte> view(size_t index) const;
  Byte *data() const { return data_; }
  size_t num_views() const { return num_views_; }
  size_t view_size() const { return view_size_; }
  size_t view_stride() const { return view_stride_; }
  bool empty() const { return num_views() == 0; }
 private:
  template <class OtherByte>
  friend class AlignedAreaImpl;
  // Operands uses the private unchecked constructor below.
  friend Operands;
  // Directly creates an aligned view, bypassing alignment checks.
  AlignedAreaImpl(Byte *data, size_t num_views, size_t view_size,
                  size_t view_stride);
  // Pointer to the start of the first view.
  Byte *data_ = nullptr;
  // Number of views in the area.
  size_t num_views_ = 0;
  // Size of each view in bytes, excluding alignment padding.
  size_t view_size_ = 0;
  // Number of bytes between the starts of consecutive views. NB: This is not
  // necessarily equal to PadToAlignment(|view_size_|).
  size_t view_stride_ = 0;
};
} // namespace internal
// Public aliases; use these. The const-char variants are read-only views; the
// char variants permit writing through the view.
using AlignedView = internal::AlignedViewImpl<const char>;
using AlignedArea = internal::AlignedAreaImpl<const char>;
using MutableAlignedView = internal::AlignedViewImpl<char>;
using MutableAlignedArea = internal::AlignedAreaImpl<char>;
// A uniquely-owned aligned byte array.
class UniqueAlignedArray {
 public:
  // Creates an empty byte array.
  UniqueAlignedArray() = default;
  // Reallocates this to |new_size| bytes, and discards the current byte array.
  // Contents are uninitialized.
  void Reset(size_t new_size);
  // Like Reset(), but only reallocates if |new_size| is more than the current
  // capacity. NB: Does not preserve current content when reallocation occurs;
  // use Resize() if that is desired.
  void Reserve(size_t new_size);
  // Resizes this to contain |new_size| bytes, preserving current content. If
  // |new_size| exceeds the current size, the added bytes are uninitialized. If
  // |new_size| exceeds the current capacity, reallocates, and copies current
  // content. Returns true if reallocation occurred.
  bool Resize(size_t new_size);
  // Returns the aligned byte array.
  MutableAlignedView view() const { return view_; }
 private:
  // Underlying byte array, which is padded for alignment.
  std::unique_ptr<char[]> padded_array_;
  // Size of the aligned portion of |padded_array_|.
  size_t capacity_ = 0;
  // Active range of the aligned portion of |padded_array_|.
  MutableAlignedView view_;
};
// Implementation details below.
namespace internal {
// Required alignment for memory blocks. Only the runtime framework should use
// this; otherwise, DO NOT access or otherwise depend on this value. Must be a
// power of two: the mask arithmetic in PadToAlignment() relies on it.
enum : size_t { kAlignmentBytes = 32 };
} // namespace internal
template <class T>
constexpr bool IsAlignable() {
  // Either T is divisible into alignment windows, or an alignment window is
  // divisible into Ts. Likewise for T's alignment requirement. Finally, T
  // must be POD because we won't call its constructor or destructor.
  // (tensorflow::bfloat16 is special-cased: it is not formally POD, but is
  // known to be trivially destructible; see the forward declaration above.)
  return (sizeof(T) % internal::kAlignmentBytes == 0 ||
          internal::kAlignmentBytes % sizeof(T) == 0) &&
         (alignof(T) % internal::kAlignmentBytes == 0 ||
          internal::kAlignmentBytes % alignof(T) == 0) &&
         (std::is_pod<T>::value ||
          std::is_same<T, tensorflow::bfloat16>::value);
}
// Returns OK iff |pointer| satisfies the runtime alignment requirement.
inline tensorflow::Status OkIfAligned(const void *pointer) {
  const uintptr_t address = reinterpret_cast<uintptr_t>(pointer);
  if (address % internal::kAlignmentBytes == 0) {
    return tensorflow::Status::OK();
  }
  return tensorflow::errors::InvalidArgument(
      "Pointer fails alignment requirement: ", address, " vs required ",
      internal::kAlignmentBytes);
}
// Returns the smallest multiple of the alignment that is >= |byte_offset|.
inline size_t PadToAlignment(size_t byte_offset) {
  // Add (alignment - 1), then clear the low-order bits.  Clearing the bits
  // rounds down to an alignment boundary, so the net effect is rounding up
  // (and offsets that are already aligned pass through unchanged).
  const size_t mask = internal::kAlignmentBytes - 1;
  return (byte_offset + mask) & ~mask;
}
// Returns the first aligned address at or after |pointer|.
template <class T>
T *PadToAlignment(T *pointer) {
  static_assert(IsAlignable<T>(), "T is not alignable");
  // Same round-up-by-masking trick as the size_t overload, applied to the
  // pointer's integral representation.
  const uintptr_t mask = internal::kAlignmentBytes - 1;
  const uintptr_t aligned =
      (reinterpret_cast<uintptr_t>(pointer) + mask) & ~mask;
  return reinterpret_cast<T *>(aligned);
}
// Returns the number of bytes needed to hold |num_arrays| arrays of
// |array_size| bytes each, where every array starts at an aligned address.
inline size_t ComputeAlignedAreaSize(size_t num_arrays, size_t array_size) {
  const size_t padded_array_size = PadToAlignment(array_size);
  return padded_array_size * num_arrays;
}
// Returns the number of bytes needed to hold byte arrays of the given |sizes|,
// where each array starts at an aligned address.
inline size_t ComputeTotalBytesWithAlignmentPadding(
    const std::vector<size_t> &sizes) {
  size_t num_bytes = 0;
  for (const size_t unpadded_size : sizes) {
    num_bytes += PadToAlignment(unpadded_size);
  }
  return num_bytes;
}
namespace internal {

// Converting copy constructor; reinterprets the bytes of |that| as Byte.
// Enables implicit conversion from MutableAlignedView to AlignedView.
template <class Byte>
template <class OtherByte>
AlignedViewImpl<Byte>::AlignedViewImpl(AlignedViewImpl<OtherByte> that)
    : data_(reinterpret_cast<Byte *>(that.data())), size_(that.size()) {}

// Converting assignment; shallow-copies the pointer and size from |that|.
template <class Byte>
template <class OtherByte>
AlignedViewImpl<Byte> &AlignedViewImpl<Byte>::operator=(
    AlignedViewImpl<OtherByte> that) {
  data_ = reinterpret_cast<Byte *>(that.data());
  size_ = that.size();
  return *this;
}

// Points this view at [data, data + size).  Fails (leaving this unmodified)
// if |data| is misaligned.
template <class Byte>
tensorflow::Status AlignedViewImpl<Byte>::Reset(Byte *data, size_t size) {
  TF_RETURN_IF_ERROR(OkIfAligned(data));
  // Success; make modifications.
  data_ = data;
  size_ = size;
  return tensorflow::Status::OK();
}

// Splits this view into consecutive aligned sub-views of the given |sizes|.
// Fails without modifying |views| if this view is too small.
template <class Byte>
template <class OtherByte>
tensorflow::Status AlignedViewImpl<Byte>::Split(
    const std::vector<size_t> &sizes,
    std::vector<AlignedViewImpl<OtherByte>> *views) const {
  const size_t total_bytes = ComputeTotalBytesWithAlignmentPadding(sizes);
  if (size() < total_bytes) {
    return tensorflow::errors::InvalidArgument(
        "View is too small to be split into sizes [",
        tensorflow::str_util::Join(sizes, ", "), "]: need ", total_bytes,
        " bytes but have ", size(), " bytes");
  }
  // Success; make modifications.
  views->clear();
  views->reserve(sizes.size());
  Byte *base = data();
  for (const size_t size : sizes) {
    views->push_back(AlignedViewImpl<OtherByte>(base, size));
    // Each sub-view starts at the next alignment boundary after the previous
    // sub-view's end.
    base = PadToAlignment(base + size);
  }
  DCHECK_EQ(base - data(), total_bytes);
  return tensorflow::Status::OK();
}

// Directly-initializing constructor; debug-checks alignment of |data|.
template <class Byte>
AlignedViewImpl<Byte>::AlignedViewImpl(Byte *data, size_t size)
    : data_(data), size_(size) {
  TF_DCHECK_OK(OkIfAligned(data_));
}
// Converting copy constructor; reinterprets the bytes of |that| as Byte.
// Enables implicit conversion from MutableAlignedArea to AlignedArea.
template <class Byte>
template <class OtherByte>
AlignedAreaImpl<Byte>::AlignedAreaImpl(AlignedAreaImpl<OtherByte> that)
    : data_(reinterpret_cast<Byte *>(that.data_)),
      num_views_(that.num_views()),
      view_size_(that.view_size()),
      view_stride_(that.view_stride_) {}

// Converting assignment; shallow-copies the pointer and geometry from |that|.
template <class Byte>
template <class OtherByte>
AlignedAreaImpl<Byte> &AlignedAreaImpl<Byte>::operator=(
    AlignedAreaImpl<OtherByte> that) {
  data_ = reinterpret_cast<Byte *>(that.data_);
  num_views_ = that.num_views();
  view_size_ = that.view_size();
  view_stride_ = that.view_stride_;
  return *this;
}

// Carves |num_views| sub-views of |view_size| bytes each out of |view|.
// Fails without modifying this area if |view| is too small.
template <class Byte>
template <class OtherByte>
tensorflow::Status AlignedAreaImpl<Byte>::Reset(AlignedViewImpl<OtherByte> view,
                                                size_t num_views,
                                                size_t view_size) {
  const size_t total_bytes = ComputeAlignedAreaSize(num_views, view_size);
  if (view.size() < total_bytes) {
    return tensorflow::errors::InvalidArgument(
        "View is too small for area of ", num_views, " views of ", view_size,
        " bytes: need ", total_bytes, " bytes but got ", view.size(), " bytes");
  }
  // Success; make modifications.
  data_ = reinterpret_cast<Byte *>(view.data());
  num_views_ = num_views;
  view_size_ = view_size;
  // Space the views by the padded view size, so each view starts aligned.
  view_stride_ = PadToAlignment(view_size_);
  return tensorflow::Status::OK();
}

// Returns the |index|'th sub-view of this area.
template <class Byte>
AlignedViewImpl<Byte> AlignedAreaImpl<Byte>::view(size_t index) const {
  DCHECK_LT(index, num_views());
  return AlignedViewImpl<Byte>(data_ + view_stride_ * index, view_size_);
}
// Directly-initializing constructor; debug-checks that both the base pointer
// and the stride between views are aligned.
template <class Byte>
AlignedAreaImpl<Byte>::AlignedAreaImpl(Byte *data, size_t num_views,
                                       size_t view_size, size_t view_stride)
    : data_(data),
      num_views_(num_views),
      view_size_(view_size),
      view_stride_(view_stride) {
  TF_DCHECK_OK(OkIfAligned(data_));
  // Check the stride's alignment by forming an address whose value equals the
  // stride.  NB: The previous formulation, adding |view_stride_| to a null
  // char pointer, is undefined behavior for nonzero strides ([expr.add]);
  // casting the integral stride directly checks the same modulus safely.
  TF_DCHECK_OK(OkIfAligned(reinterpret_cast<const void *>(view_stride_)));
}

}  // namespace internal
// Reallocates this to |new_size| bytes at an aligned address.  Always discards
// the previous allocation; the new contents are uninitialized.
inline void UniqueAlignedArray::Reset(size_t new_size) {
  // Pad the |new_size| to the next alignment boundary, so the final bytes of
  // the array are still in a full alignment window.  E.g., if we resize to 48
  // bytes with 32-byte alignment, then we allocate 64 bytes so the final 16
  // bytes are still part of a full 32-byte alignment window.
  const size_t aligned_size = PadToAlignment(new_size);
  // To obtain an aligned address, allocate a sufficiently-padded byte array and
  // find an aligned address near the start of the block.
  //
  // TODO(googleuser): Alternatively, we could use library functions such as
  // memalign(), posix_memalign(), or aligned_alloc(), but those may not be
  // present on all platforms.  Consider adding some #ifs to allow use of those
  // library functions when available.
  padded_array_.reset(new char[aligned_size + internal::kAlignmentBytes - 1]);
  capacity_ = aligned_size;
  // Expose |new_size| bytes starting at the first aligned address within the
  // padded allocation.  (Reset is a friend of the view, hence the direct
  // member writes.)
  view_.size_ = new_size;
  view_.data_ = PadToAlignment(padded_array_.get());
  TF_DCHECK_OK(OkIfAligned(view_.data_));
}
// Ensures room for |new_size| bytes, reallocating only when growing past the
// current capacity.  Content is not preserved when reallocation occurs.
inline void UniqueAlignedArray::Reserve(size_t new_size) {
  if (new_size <= capacity_) {
    // Fits within the current allocation; just adjust the visible size.
    view_.size_ = new_size;
  } else {
    Reset(new_size);
  }
}
// Resizes to |new_size| bytes, preserving current content.  Returns true iff
// a reallocation (and copy) occurred.
inline bool UniqueAlignedArray::Resize(size_t new_size) {
  // Avoid reallocation, if possible.
  if (new_size <= capacity_) {
    view_.size_ = new_size;
    return false;
  }
  // Reallocate and copy.  Extend the life of the old array until it is copied.
  //
  // Note: C realloc() can extend a byte array in place (i.e., without copying).
  // Unfortunately, there is no aligned version of realloc().  Moreover, adding
  // alignment padding could cause double-copying: first, when realloc() copies
  // the data to the new buffer, and second, if the amount of padding required
  // at the new address is not the same as before.
  const std::unique_ptr<char[]> old_array = std::move(padded_array_);
  const MutableAlignedView old_view = view_;
  // Allocate double the requested size for amortized-constant growth.
  Reset(2 * new_size);
  memcpy(view_.data(), old_view.data(), old_view.size());
  view_.size_ = new_size;
  return true;
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
#endif // DRAGNN_RUNTIME_ALIGNMENT_H_
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/alignment.h"
#include <stddef.h>
#include <string>
#include <vector>
#include "dragnn/core/test/generic.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
namespace {
// The tests below assume a nontrivial alignment requirement.
static_assert(internal::kAlignmentBytes >= 4, "alignment too small");

// Expects that two pointers have the same address.
void ExpectSameAddress(const void *pointer1, const void *pointer2) {
  EXPECT_EQ(pointer1, pointer2);
}
// Tests that standard scalar types are alignable.
TEST(IsAlignableTest, Alignable) {
  EXPECT_TRUE(IsAlignable<char>());
  EXPECT_TRUE(IsAlignable<float>());
  EXPECT_TRUE(IsAlignable<double>());
}

// Tests that objects of odd sizes are not alignable.  (7919 is prime, so it
// cannot divide or be divided by the alignment.)
TEST(IsAlignableTest, NotAlignable) {
  EXPECT_FALSE(IsAlignable<char[3]>());
  EXPECT_FALSE(IsAlignable<char[7]>());
  EXPECT_FALSE(IsAlignable<char[7919]>());
}
// Tests that OkIfAligned() returns OK on aligned pointers.
TEST(OkIfAlignedTest, Aligned) {
  const char *ptr = nullptr;
  TF_EXPECT_OK(OkIfAligned(ptr));
  ptr += internal::kAlignmentBytes;
  TF_EXPECT_OK(OkIfAligned(ptr));
  ptr += 123 * internal::kAlignmentBytes;
  TF_EXPECT_OK(OkIfAligned(ptr));
}

// Tests that OkIfAligned() returns non-OK on misaligned pointers.
TEST(OkIfAlignedTest, NotAligned) {
  const char *ptr = nullptr;
  EXPECT_THAT(OkIfAligned(ptr + 1),
              test::IsErrorWithSubstr("Pointer fails alignment requirement"));
  EXPECT_THAT(OkIfAligned(ptr + 23),
              test::IsErrorWithSubstr("Pointer fails alignment requirement"));
}

// Tests that any window of |internal::kAlignmentBytes| bytes contains exactly
// one aligned address.
TEST(OkIfAlignedTest, OnePerAlignmentWindow) {
  // Note that |bytes| does not necessarily start at an aligned address.  Even
  // so, it is still guaranteed to contain exactly one aligned address, in the
  // same sense that any sequence of 10 consecutive integers contains exactly
  // one whose decimal representation ends in '0'.  This property is exploited
  // in UniqueAlignedArray::Reset().
  const string bytes(internal::kAlignmentBytes, ' ');
  int num_ok = 0;
  for (int i = 0; i < bytes.size(); ++i) {
    if (OkIfAligned(bytes.data() + i).ok()) ++num_ok;
  }
  EXPECT_EQ(num_ok, 1);
}
// Tests that PadToAlignment() produces an aligned byte offset.
TEST(PadToAlignmentTest, Offset) {
  EXPECT_EQ(PadToAlignment(0), 0);
  EXPECT_EQ(PadToAlignment(1), internal::kAlignmentBytes);
  EXPECT_EQ(PadToAlignment(internal::kAlignmentBytes + 1),
            2 * internal::kAlignmentBytes);
  EXPECT_EQ(PadToAlignment(99 * internal::kAlignmentBytes + 3),
            100 * internal::kAlignmentBytes);
}

// Tests that PadToAlignment() produces an aligned pointer.
TEST(PadToAlignmentTest, Pointer) {
  const string bytes = "hello";
  TF_EXPECT_OK(OkIfAligned(PadToAlignment(bytes.data())));
  const std::vector<float> reals(10);
  TF_EXPECT_OK(OkIfAligned(PadToAlignment(reals.data())));
}
// Tests that ComputeAlignedAreaSize() calculates the correct size.
TEST(ComputeAlignedAreaSizeTest, Basic) {
  EXPECT_EQ(ComputeAlignedAreaSize(0, 0), 0);
  EXPECT_EQ(ComputeAlignedAreaSize(0, 1), 0);
  EXPECT_EQ(ComputeAlignedAreaSize(1, 0), 0);
  EXPECT_EQ(ComputeAlignedAreaSize(1, 1), internal::kAlignmentBytes);
  EXPECT_EQ(ComputeAlignedAreaSize(1, internal::kAlignmentBytes),
            internal::kAlignmentBytes);
  // Each array is padded up to a multiple of the alignment before multiplying.
  EXPECT_EQ(ComputeAlignedAreaSize(3, internal::kAlignmentBytes + 1),
            6 * internal::kAlignmentBytes);
  EXPECT_EQ(ComputeAlignedAreaSize(11, internal::kAlignmentBytes - 1),
            11 * internal::kAlignmentBytes);
  EXPECT_EQ(ComputeAlignedAreaSize(7, internal::kAlignmentBytes),
            7 * internal::kAlignmentBytes);
}

// Tests that ComputeTotalBytesWithAlignmentPadding() calculates the correct
// total size.
TEST(ComputeTotalBytesWithAlignmentPaddingTest, DifferentSizes) {
  EXPECT_EQ(ComputeTotalBytesWithAlignmentPadding({}), 0);
  EXPECT_EQ(ComputeTotalBytesWithAlignmentPadding({0}), 0);
  EXPECT_EQ(ComputeTotalBytesWithAlignmentPadding({0, 0, 0}), 0);
  EXPECT_EQ(ComputeTotalBytesWithAlignmentPadding({1}),
            internal::kAlignmentBytes);
  EXPECT_EQ(ComputeTotalBytesWithAlignmentPadding({1, 1, 1}),
            3 * internal::kAlignmentBytes);
  EXPECT_EQ(ComputeTotalBytesWithAlignmentPadding(
                {1, internal::kAlignmentBytes, internal::kAlignmentBytes + 1}),
            4 * internal::kAlignmentBytes);
  // Sizes 1..kAlignmentBytes each pad up to exactly one alignment window.
  std::vector<size_t> sizes;
  for (size_t i = 1; i <= internal::kAlignmentBytes; ++i) sizes.push_back(i);
  EXPECT_EQ(ComputeTotalBytesWithAlignmentPadding(sizes),
            internal::kAlignmentBytes * internal::kAlignmentBytes);
}

// Tests that ComputeTotalBytesWithAlignmentPadding() is equivalent to
// ComputeAlignedAreaSize() when all sizes are equal.
TEST(ComputeTotalBytesWithAlignmentPaddingTest, AllSameSize) {
  EXPECT_EQ(ComputeTotalBytesWithAlignmentPadding({1, 1, 1, 1}),
            ComputeAlignedAreaSize(4, 1));
  EXPECT_EQ(ComputeTotalBytesWithAlignmentPadding({7, 7, 7, 7, 7, 7}),
            ComputeAlignedAreaSize(6, 7));
  EXPECT_EQ(ComputeTotalBytesWithAlignmentPadding({77, 77, 77}),
            ComputeAlignedAreaSize(3, 77));
}
// Tests that UniqueAlignedArray is empty by default.
TEST(UniqueAlignedArrayTest, EmptyByDefault) {
  UniqueAlignedArray array;
  EXPECT_EQ(array.view().size(), 0);
  EXPECT_TRUE(array.view().empty());
}

// Tests that UniqueAlignedArray::Reset() always reallocates.
TEST(UniqueAlignedArrayTest, Reset) {
  UniqueAlignedArray array;
  // Reset to non-empty.
  array.Reset(10);
  const MutableAlignedView view1 = array.view();
  TF_EXPECT_OK(OkIfAligned(view1.data()));
  EXPECT_EQ(view1.size(), 10);
  // Calling view() again should return the same byte array.
  const MutableAlignedView view2 = array.view();
  ExpectSameAddress(view2.data(), view1.data());
  EXPECT_EQ(view2.size(), view1.size());
  // Reset to a different size.
  array.Reset(33);
  const MutableAlignedView view3 = array.view();
  TF_EXPECT_OK(OkIfAligned(view3.data()));
  EXPECT_EQ(view3.size(), 33);
}

// Tests that UniqueAlignedArray::Reserve() reallocates only when growing past
// the current capacity.
TEST(UniqueAlignedArrayTest, Reserve) {
  UniqueAlignedArray array;
  // Reserve to non-empty.
  array.Reserve(20);
  const MutableAlignedView view1 = array.view();
  TF_EXPECT_OK(OkIfAligned(view1.data()));
  EXPECT_EQ(view1.size(), 20);
  // Shrink to a smaller size; should not reallocate.
  array.Reserve(7);
  const MutableAlignedView view2 = array.view();
  ExpectSameAddress(view2.data(), view1.data());
  EXPECT_EQ(view2.size(), 7);
  // Grow but still remain within capacity; should not reallocate.
  array.Reserve(14);
  const MutableAlignedView view3 = array.view();
  ExpectSameAddress(view3.data(), view1.data());
  EXPECT_EQ(view3.size(), 14);
}

// Tests that UniqueAlignedArray::Resize() reallocates when growing and
// preserves existing contents.
TEST(UniqueAlignedArrayTest, Resize) {
  UniqueAlignedArray array;
  // Resize to non-empty.
  EXPECT_TRUE(array.Resize(10));
  const MutableAlignedView view1 = array.view();
  TF_EXPECT_OK(OkIfAligned(view1.data()));
  EXPECT_EQ(view1.size(), 10);
  // Write some stuff.
  for (int i = 0; i < 10; ++i) view1.data()[i] = '1';
  // Resize to a larger size.
  EXPECT_TRUE(array.Resize(33));
  const MutableAlignedView view2 = array.view();
  TF_EXPECT_OK(OkIfAligned(view2.data()));
  EXPECT_EQ(view2.size(), 33);
  // Check that content was preserved.
  for (int i = 0; i < 10; ++i) EXPECT_EQ(view2.data()[i], '1');
  // Append more stuff.
  for (int i = 10; i < 33; ++i) view2.data()[i] = '2';
  // Resize to a smaller size.
  EXPECT_FALSE(array.Resize(15));
  const MutableAlignedView view3 = array.view();
  TF_EXPECT_OK(OkIfAligned(view3.data()));
  ExpectSameAddress(view3.data(), view2.data());
  EXPECT_EQ(view3.size(), 15);
  // Check that content was preserved.
  for (int i = 0; i < 10; ++i) EXPECT_EQ(view3.data()[i], '1');
  for (int i = 10; i < 15; ++i) EXPECT_EQ(view3.data()[i], '2');
  // Overwrite with new stuff.
  for (int i = 0; i < 15; ++i) view3.data()[i] = '3';
  // Resize to a larger size, but still below capacity (Resize(33) doubled the
  // allocation), so no reallocation occurs.
  EXPECT_FALSE(array.Resize(20));
  const MutableAlignedView view4 = array.view();
  TF_EXPECT_OK(OkIfAligned(view4.data()));
  ExpectSameAddress(view4.data(), view2.data());
  EXPECT_EQ(view4.size(), 20);
  // Check that content was preserved.
  for (int i = 0; i < 15; ++i) EXPECT_EQ(view4.data()[i], '3');
}
// Tests that (Mutable)AlignedView is empty by default.
TEST(AlignedViewTest, EmptyByDefault) {
  AlignedView view1;
  EXPECT_EQ(view1.size(), 0);
  EXPECT_TRUE(view1.empty());
  MutableAlignedView view2;
  EXPECT_EQ(view2.size(), 0);
  EXPECT_TRUE(view2.empty());
}

// Tests that (Mutable)AlignedView::Reset() works on aligned pointers.
TEST(AlignedViewTest, ResetValid) {
  char *pointer = nullptr;
  pointer += 3 * internal::kAlignmentBytes;
  AlignedView view1;
  TF_EXPECT_OK(view1.Reset(pointer, 100));
  ExpectSameAddress(view1.data(), pointer);
  EXPECT_EQ(view1.size(), 100);
  EXPECT_FALSE(view1.empty());
  MutableAlignedView view2;
  TF_EXPECT_OK(view2.Reset(pointer, 100));
  ExpectSameAddress(view2.data(), pointer);
  EXPECT_EQ(view2.size(), 100);
  EXPECT_FALSE(view2.empty());
}

// Tests that (Mutable)AlignedView::Reset() fails on misaligned pointers.
TEST(AlignedViewTest, ResetInvalid) {
  char *pointer = nullptr;
  ++pointer;  // not aligned
  AlignedView view1;
  EXPECT_THAT(view1.Reset(pointer, 10),
              test::IsErrorWithSubstr("Pointer fails alignment requirement"));
  MutableAlignedView view2;
  EXPECT_THAT(view2.Reset(pointer, 10),
              test::IsErrorWithSubstr("Pointer fails alignment requirement"));
}

// Tests that (Mutable)AlignedView::Reset() can empty the view.
TEST(AlignedViewTest, ResetEmpty) {
  char *pointer = nullptr;
  pointer += 11 * internal::kAlignmentBytes;
  // First point to a non-empty byte array.
  AlignedView view1;
  TF_EXPECT_OK(view1.Reset(pointer, 100));
  ExpectSameAddress(view1.data(), pointer);
  EXPECT_EQ(view1.size(), 100);
  EXPECT_FALSE(view1.empty());
  // Then reset to empty.
  TF_EXPECT_OK(view1.Reset(pointer, 0));
  EXPECT_EQ(view1.size(), 0);
  EXPECT_TRUE(view1.empty());
  // First point to a non-empty byte array.
  MutableAlignedView view2;
  TF_EXPECT_OK(view2.Reset(pointer, 100));
  ExpectSameAddress(view2.data(), pointer);
  EXPECT_EQ(view2.size(), 100);
  EXPECT_FALSE(view2.empty());
  // Then reset to empty.
  TF_EXPECT_OK(view2.Reset(pointer, 0));
  EXPECT_EQ(view2.size(), 0);
  EXPECT_TRUE(view2.empty());
}
// Tests that (Mutable)AlignedView supports copy-construction and assignment
// with shallow-copy semantics, and reinterprets from char* to const char*.
TEST(AlignedViewTest, CopyAndAssign) {
  char *pointer1 = nullptr;
  pointer1 += 3 * internal::kAlignmentBytes;
  const char *pointer2 = nullptr;
  pointer2 += 7 * internal::kAlignmentBytes;
  MutableAlignedView view1;
  TF_ASSERT_OK(view1.Reset(pointer1, 100));
  AlignedView view2;
  TF_ASSERT_OK(view2.Reset(pointer2, 200));
  // Copy-construct mutable from mutable.
  MutableAlignedView view3(view1);
  ExpectSameAddress(view3.data(), pointer1);
  EXPECT_EQ(view3.size(), 100);
  EXPECT_FALSE(view3.empty());
  view3 = MutableAlignedView();
  EXPECT_EQ(view3.size(), 0);
  EXPECT_TRUE(view3.empty());
  view3 = view1;
  ExpectSameAddress(view3.data(), pointer1);
  EXPECT_EQ(view3.size(), 100);
  EXPECT_FALSE(view3.empty());
  // Copy-construct const from mutable.
  AlignedView view4(view1);  // reinterprets type
  ExpectSameAddress(view4.data(), pointer1);
  EXPECT_EQ(view4.size(), 100);
  EXPECT_FALSE(view4.empty());
  view4 = AlignedView();
  EXPECT_EQ(view4.size(), 0);
  EXPECT_TRUE(view4.empty());
  view4 = view2;
  ExpectSameAddress(view4.data(), pointer2);
  EXPECT_EQ(view4.size(), 200);
  EXPECT_FALSE(view4.empty());
  view4 = view1;  // reinterprets type
  ExpectSameAddress(view4.data(), pointer1);
  EXPECT_EQ(view4.size(), 100);
  EXPECT_FALSE(view4.empty());
  view4 = MutableAlignedView();  // reinterprets type
  EXPECT_EQ(view4.size(), 0);
  EXPECT_TRUE(view4.empty());
}
// Tests that AlignedView can split itself into sub-views with specified sizes.
TEST(AlignedViewTest, SplitConst) {
  const std::vector<size_t> sizes = {1, internal::kAlignmentBytes,
                                     internal::kAlignmentBytes + 1, 1, 123};
  const size_t total_bytes = ComputeTotalBytesWithAlignmentPadding(sizes);
  AlignedView view;
  TF_ASSERT_OK(view.Reset(nullptr, total_bytes));
  std::vector<AlignedView> views(100);  // will be resized
  TF_ASSERT_OK(view.Split(sizes, &views));
  ASSERT_EQ(views.size(), 5);
  // Each sub-view starts at the next alignment boundary after the previous.
  const char *base = view.data();
  ExpectSameAddress(views[0].data(), base);
  EXPECT_EQ(views[0].size(), 1);
  base += internal::kAlignmentBytes;
  ExpectSameAddress(views[1].data(), base);
  EXPECT_EQ(views[1].size(), internal::kAlignmentBytes);
  base += internal::kAlignmentBytes;
  ExpectSameAddress(views[2].data(), base);
  EXPECT_EQ(views[2].size(), internal::kAlignmentBytes + 1);
  base += 2 * internal::kAlignmentBytes;
  ExpectSameAddress(views[3].data(), base);
  EXPECT_EQ(views[3].size(), 1);
  base += internal::kAlignmentBytes;
  ExpectSameAddress(views[4].data(), base);
  EXPECT_EQ(views[4].size(), 123);
}

// Tests that MutableAlignedView can split itself into sub-views with specified
// sizes, and reinterprets from char* to const char*.
TEST(AlignedViewTest, SplitMutable) {
  const std::vector<size_t> sizes = {1, internal::kAlignmentBytes,
                                     internal::kAlignmentBytes + 1, 1, 123};
  const size_t total_bytes = ComputeTotalBytesWithAlignmentPadding(sizes);
  // Also add some padding to check that we can split part of the view.
  MutableAlignedView view;
  TF_ASSERT_OK(view.Reset(nullptr, total_bytes + 10));
  std::vector<AlignedView> const_views(99);         // will be resized
  std::vector<MutableAlignedView> mutable_views(2);  // will be resized
  TF_ASSERT_OK(view.Split(sizes, &const_views));
  TF_ASSERT_OK(view.Split(sizes, &mutable_views));
  ASSERT_EQ(const_views.size(), 5);
  ASSERT_EQ(mutable_views.size(), 5);
  const char *base = view.data();
  ExpectSameAddress(const_views[0].data(), base);
  ExpectSameAddress(mutable_views[0].data(), base);
  EXPECT_EQ(const_views[0].size(), 1);
  EXPECT_EQ(mutable_views[0].size(), 1);
  base += internal::kAlignmentBytes;
  ExpectSameAddress(const_views[1].data(), base);
  ExpectSameAddress(mutable_views[1].data(), base);
  EXPECT_EQ(const_views[1].size(), internal::kAlignmentBytes);
  EXPECT_EQ(mutable_views[1].size(), internal::kAlignmentBytes);
  base += internal::kAlignmentBytes;
  ExpectSameAddress(const_views[2].data(), base);
  ExpectSameAddress(mutable_views[2].data(), base);
  EXPECT_EQ(const_views[2].size(), internal::kAlignmentBytes + 1);
  EXPECT_EQ(mutable_views[2].size(), internal::kAlignmentBytes + 1);
  base += 2 * internal::kAlignmentBytes;
  ExpectSameAddress(const_views[3].data(), base);
  ExpectSameAddress(mutable_views[3].data(), base);
  EXPECT_EQ(const_views[3].size(), 1);
  EXPECT_EQ(mutable_views[3].size(), 1);
  base += internal::kAlignmentBytes;
  ExpectSameAddress(const_views[4].data(), base);
  ExpectSameAddress(mutable_views[4].data(), base);
  EXPECT_EQ(const_views[4].size(), 123);
  EXPECT_EQ(mutable_views[4].size(), 123);
}

// Tests that Split() fails when the view cannot hold the requested sizes.
TEST(AlignedViewTest, SplitTooSmall) {
  const std::vector<size_t> sizes = {1, internal::kAlignmentBytes,
                                     internal::kAlignmentBytes + 1, 1, 123};
  const size_t total_bytes = ComputeTotalBytesWithAlignmentPadding(sizes);
  // Make the view just a bit too small.
  MutableAlignedView view;
  TF_ASSERT_OK(view.Reset(nullptr, total_bytes - 1));
  std::vector<MutableAlignedView> views;
  EXPECT_THAT(view.Split(sizes, &views),
              test::IsErrorWithSubstr("View is too small to be split"));
}
// Tests that (Mutable)AlignedArea is empty by default.
TEST(AlignedAreaTest, EmptyByDefault) {
  AlignedArea area1;
  EXPECT_EQ(area1.num_views(), 0);
  EXPECT_EQ(area1.view_size(), 0);
  EXPECT_TRUE(area1.empty());
  MutableAlignedArea area2;
  EXPECT_EQ(area2.num_views(), 0);
  EXPECT_EQ(area2.view_size(), 0);
  EXPECT_TRUE(area2.empty());
}

// Tests that (Mutable)AlignedArea::Reset() can initialize to a single view.
TEST(AlignedAreaTest, ResetSingleton) {
  const char *pointer1 = nullptr;
  pointer1 += 3 * internal::kAlignmentBytes;
  char *pointer2 = nullptr;
  pointer2 += 7 * internal::kAlignmentBytes;
  AlignedView view1;
  TF_ASSERT_OK(view1.Reset(pointer1, internal::kAlignmentBytes));
  MutableAlignedView view2;
  TF_ASSERT_OK(view2.Reset(pointer2, internal::kAlignmentBytes + 1));
  AlignedArea area1;
  TF_ASSERT_OK(area1.Reset(view1, 1, 1));
  EXPECT_EQ(area1.num_views(), 1);
  EXPECT_EQ(area1.view_size(), 1);
  EXPECT_FALSE(area1.empty());
  ExpectSameAddress(area1.view(0).data(), pointer1);
  EXPECT_EQ(area1.view(0).size(), 1);
  // A const area can also be reset from a mutable view.
  TF_ASSERT_OK(area1.Reset(view2, 1, 2));
  EXPECT_EQ(area1.num_views(), 1);
  EXPECT_EQ(area1.view_size(), 2);
  EXPECT_FALSE(area1.empty());
  ExpectSameAddress(area1.view(0).data(), pointer2);
  EXPECT_EQ(area1.view(0).size(), 2);
  TF_ASSERT_OK(area1.Reset(view2, 1, 1));
  EXPECT_EQ(area1.num_views(), 1);
  EXPECT_EQ(area1.view_size(), 1);
  EXPECT_FALSE(area1.empty());
  ExpectSameAddress(area1.view(0).data(), pointer2);
  EXPECT_EQ(area1.view(0).size(), 1);
  MutableAlignedArea area2;
  TF_ASSERT_OK(area2.Reset(view2, 1, 2));
  EXPECT_EQ(area2.num_views(), 1);
  EXPECT_EQ(area2.view_size(), 2);
  EXPECT_FALSE(area2.empty());
  ExpectSameAddress(area2.view(0).data(), pointer2);
  EXPECT_EQ(area2.view(0).size(), 2);
  TF_ASSERT_OK(area2.Reset(view2, 1, 1));
  EXPECT_EQ(area2.num_views(), 1);
  EXPECT_EQ(area2.view_size(), 1);
  EXPECT_FALSE(area2.empty());
  ExpectSameAddress(area2.view(0).data(), pointer2);
  EXPECT_EQ(area2.view(0).size(), 1);
}
// Tests that (Mutable)AlignedArea::Reset() can initialize to a sequence of
// multiple views.
TEST(AlignedAreaTest, ResetMultiple) {
  const char *pointer1 = nullptr;
  pointer1 += 3 * internal::kAlignmentBytes;
  char *pointer2 = nullptr;
  pointer2 += 7 * internal::kAlignmentBytes;
  AlignedView view1;
  TF_ASSERT_OK(view1.Reset(pointer1, 11 * internal::kAlignmentBytes));
  MutableAlignedView view2;
  TF_ASSERT_OK(view2.Reset(pointer2, 2 * internal::kAlignmentBytes));
  AlignedArea area1;
  TF_ASSERT_OK(area1.Reset(view1, 11, 1));
  EXPECT_EQ(area1.num_views(), 11);
  EXPECT_EQ(area1.view_size(), 1);
  EXPECT_FALSE(area1.empty());
  // Sub-views are spaced one alignment window apart, even for 1-byte views.
  for (int i = 0; i < area1.num_views(); ++i) {
    ExpectSameAddress(area1.view(i).data(),
                      pointer1 + internal::kAlignmentBytes * i);
    EXPECT_EQ(area1.view(i).size(), 1);
  }
  TF_ASSERT_OK(area1.Reset(view1, 10, internal::kAlignmentBytes));
  EXPECT_EQ(area1.num_views(), 10);
  EXPECT_EQ(area1.view_size(), internal::kAlignmentBytes);
  EXPECT_FALSE(area1.empty());
  for (int i = 0; i < area1.num_views(); ++i) {
    ExpectSameAddress(area1.view(i).data(),
                      pointer1 + internal::kAlignmentBytes * i);
    EXPECT_EQ(area1.view(i).size(), internal::kAlignmentBytes);
  }
  TF_ASSERT_OK(area1.Reset(view2, 2, 2));
  EXPECT_EQ(area1.num_views(), 2);
  EXPECT_EQ(area1.view_size(), 2);
  EXPECT_FALSE(area1.empty());
  for (int i = 0; i < area1.num_views(); ++i) {
    ExpectSameAddress(area1.view(i).data(),
                      pointer2 + internal::kAlignmentBytes * i);
    EXPECT_EQ(area1.view(i).size(), 2);
  }
  MutableAlignedArea area2;
  TF_ASSERT_OK(area2.Reset(view2, 2, internal::kAlignmentBytes));
  EXPECT_EQ(area2.num_views(), 2);
  EXPECT_EQ(area2.view_size(), internal::kAlignmentBytes);
  EXPECT_FALSE(area2.empty());
  for (int i = 0; i < area2.num_views(); ++i) {
    ExpectSameAddress(area2.view(i).data(),
                      pointer2 + internal::kAlignmentBytes * i);
    EXPECT_EQ(area2.view(i).size(), internal::kAlignmentBytes);
  }
}
// Tests that (Mutable)AlignedArea::Reset() fails when the view being split into
// sub-views is too small.
TEST(AlignedAreaTest, ResetInvalid) {
  AlignedView view1;
  TF_ASSERT_OK(view1.Reset(nullptr, 11 * internal::kAlignmentBytes));
  MutableAlignedView view2;
  TF_ASSERT_OK(view2.Reset(nullptr, 2 * internal::kAlignmentBytes));
  // View size larger than available view.  After each expected failure, a
  // valid Reset() confirms the area is still usable.
  AlignedArea area;
  EXPECT_THAT(area.Reset(view1, 1, 11 * internal::kAlignmentBytes + 1),
              test::IsErrorWithSubstr("View is too small for area"));
  TF_ASSERT_OK(area.Reset(view1, 11, 1));
  EXPECT_THAT(area.Reset(view2, 1, 2 * internal::kAlignmentBytes + 1),
              test::IsErrorWithSubstr("View is too small for area"));
  TF_ASSERT_OK(area.Reset(view2, 2, 1));
  // Total size larger than available view.
  EXPECT_THAT(area.Reset(view1, 12, 1),
              test::IsErrorWithSubstr("View is too small for area"));
  TF_ASSERT_OK(area.Reset(view1, 11, 1));
  EXPECT_THAT(area.Reset(view1, 4, 2 * internal::kAlignmentBytes + 1),
              test::IsErrorWithSubstr("View is too small for area"));
  TF_ASSERT_OK(area.Reset(view1, 11, 1));
  EXPECT_THAT(area.Reset(view1, 3, 3 * internal::kAlignmentBytes + 1),
              test::IsErrorWithSubstr("View is too small for area"));
  TF_ASSERT_OK(area.Reset(view1, 11, 1));
  EXPECT_THAT(area.Reset(view1, 2, 5 * internal::kAlignmentBytes + 1),
              test::IsErrorWithSubstr("View is too small for area"));
  TF_ASSERT_OK(area.Reset(view1, 11, 1));
  EXPECT_THAT(area.Reset(view2, 3, 1),
              test::IsErrorWithSubstr("View is too small for area"));
  TF_ASSERT_OK(area.Reset(view2, 2, 1));
  EXPECT_THAT(area.Reset(view2, 2, internal::kAlignmentBytes + 1),
              test::IsErrorWithSubstr("View is too small for area"));
  TF_ASSERT_OK(area.Reset(view2, 2, 1));
}

// Tests that (Mutable)AlignedArea::Reset() can empty the area.
TEST(AlignedAreaTest, ResetEmpty) {
  AlignedView view1;
  TF_ASSERT_OK(view1.Reset(nullptr, 11 * internal::kAlignmentBytes));
  MutableAlignedView view2;
  TF_ASSERT_OK(view2.Reset(nullptr, 2 * internal::kAlignmentBytes));
  // First point to a non-empty byte array, then clear.
  AlignedArea area1;
  TF_ASSERT_OK(area1.Reset(view1, 11, 1));
  TF_ASSERT_OK(area1.Reset(view1, 0, 0));
  EXPECT_EQ(area1.num_views(), 0);
  EXPECT_EQ(area1.view_size(), 0);
  EXPECT_TRUE(area1.empty());
  // An area with zero views is empty regardless of its view size.
  TF_ASSERT_OK(area1.Reset(view2, 2, 1));
  TF_ASSERT_OK(area1.Reset(view2, 0, 100));
  EXPECT_EQ(area1.num_views(), 0);
  EXPECT_EQ(area1.view_size(), 100);
  EXPECT_TRUE(area1.empty());
  TF_ASSERT_OK(area1.Reset(view2, 2, 1));
  TF_ASSERT_OK(area1.Reset(MutableAlignedView(), 0, 1));
  EXPECT_EQ(area1.num_views(), 0);
  EXPECT_EQ(area1.view_size(), 1);
  EXPECT_TRUE(area1.empty());
  MutableAlignedArea area2;
  TF_ASSERT_OK(area2.Reset(view2, 2, 1));
  TF_ASSERT_OK(area2.Reset(view2, 0, 0));
  EXPECT_EQ(area2.num_views(), 0);
  EXPECT_EQ(area2.view_size(), 0);
  EXPECT_TRUE(area2.empty());
  TF_ASSERT_OK(area2.Reset(view2, 2, 1));
  TF_ASSERT_OK(area2.Reset(view2, 0, 100));
  EXPECT_EQ(area2.num_views(), 0);
  EXPECT_EQ(area2.view_size(), 100);
  EXPECT_TRUE(area2.empty());
  TF_ASSERT_OK(area2.Reset(view2, 2, 1));
  TF_ASSERT_OK(area2.Reset(MutableAlignedView(), 0, 1));
  EXPECT_EQ(area2.num_views(), 0);
  EXPECT_EQ(area2.view_size(), 1);
  EXPECT_TRUE(area2.empty());
}
// Tests that (Mutable)AlignedArea supports copy-construction and assignment
// with shallow-copy semantics, and reinterprets from char* to const char*.
TEST(AlignedAreaTest, CopyAndAssign) {
  char *pointer1 = nullptr;
  pointer1 += 3 * internal::kAlignmentBytes;
  const char *pointer2 = nullptr;
  pointer2 += 7 * internal::kAlignmentBytes;
  MutableAlignedView view1;
  TF_ASSERT_OK(view1.Reset(pointer1, ComputeAlignedAreaSize(1, 5)));
  AlignedView view2;
  TF_ASSERT_OK(view2.Reset(pointer2, ComputeAlignedAreaSize(2, 77)));
  MutableAlignedArea area1;
  TF_ASSERT_OK(area1.Reset(view1, 1, 5));
  AlignedArea area2;
  TF_ASSERT_OK(area2.Reset(view2, 2, 77));
  // Copy-construct mutable from mutable.
  MutableAlignedArea area3(area1);
  EXPECT_EQ(area3.num_views(), 1);
  EXPECT_EQ(area3.view_size(), 5);
  EXPECT_FALSE(area3.empty());
  ExpectSameAddress(area3.view(0).data(), pointer1);
  EXPECT_EQ(area3.view(0).size(), 5);
  area3 = MutableAlignedArea();
  EXPECT_EQ(area3.num_views(), 0);
  EXPECT_EQ(area3.view_size(), 0);
  EXPECT_TRUE(area3.empty());
  area3 = area1;
  EXPECT_EQ(area3.num_views(), 1);
  EXPECT_EQ(area3.view_size(), 5);
  EXPECT_FALSE(area3.empty());
  ExpectSameAddress(area3.view(0).data(), pointer1);
  EXPECT_EQ(area3.view(0).size(), 5);
  // Copy-construct const from mutable.
  AlignedArea area4(area1);  // reinterprets type
  EXPECT_EQ(area4.num_views(), 1);
  EXPECT_EQ(area4.view_size(), 5);
  EXPECT_FALSE(area4.empty());
  ExpectSameAddress(area4.view(0).data(), pointer1);
  EXPECT_EQ(area4.view(0).size(), 5);
  area4 = AlignedArea();
  EXPECT_EQ(area4.num_views(), 0);
  EXPECT_EQ(area4.view_size(), 0);
  EXPECT_TRUE(area4.empty());
  area4 = area2;
  EXPECT_EQ(area4.num_views(), 2);
  EXPECT_EQ(area4.view_size(), 77);
  EXPECT_FALSE(area4.empty());
  ExpectSameAddress(area4.view(0).data(), pointer2);
  EXPECT_EQ(area4.view(0).size(), 77);
  // The second view starts at the next alignment boundary after the first.
  ExpectSameAddress(area4.view(1).data(), PadToAlignment(pointer2 + 77));
  EXPECT_EQ(area4.view(1).size(), 77);
  area4 = area1;  // reinterprets type
  EXPECT_EQ(area4.num_views(), 1);
  EXPECT_EQ(area4.view_size(), 5);
  EXPECT_FALSE(area4.empty());
  ExpectSameAddress(area4.view(0).data(), pointer1);
  EXPECT_EQ(area4.view(0).size(), 5);
  area4 = MutableAlignedArea();  // reinterprets type
  EXPECT_EQ(area4.num_views(), 0);
  EXPECT_EQ(area4.view_size(), 0);
  EXPECT_TRUE(area4.empty());
}
} // namespace
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/runtime/array_variable_store.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/cpu_info.h"
namespace syntaxnet {
namespace dragnn {
namespace runtime {
// Serialized format version understood by this binary; compared against
// ArrayVariableStoreSpec.version in Reset().
//
// Increment this if the serialized format changes in an incompatible way that
// can't be detected through other means. For example,
// * If kAlignmentBytes is changed, then kVersion need not change because there
//   is a separate field for detecting alignment mismatch.
// * If ArrayVariableStoreSpec.variable is no longer populated, perhaps replaced
//   by some other approach, then kVersion should be incremented.
const uint32 ArrayVariableStore::kVersion = 0;
// Resets this store to the variables described by |spec|, whose contents are
// slices of the main byte array |data|.  On error, returns non-OK and leaves
// this store unmodified.
//
// Validation performed:
// * All required spec fields are present and match this binary (version,
//   alignment, endian-ness).
// * Every variable has a known format, and flat variables have exactly 1 view.
// * The variables exactly tile |data|, with no overrun and no leftover bytes.
// * Blocked variables have no padding (fast inference kernels require this).
// * No two variables share the same (name, format) key.
tensorflow::Status ArrayVariableStore::Reset(const ArrayVariableStoreSpec &spec,
                                             AlignedView data) {
  if (!spec.has_version() || !spec.has_alignment_bytes() ||
      !spec.has_is_little_endian()) {
    return tensorflow::errors::InvalidArgument(
        "ArrayVariableStoreSpec is missing a required field: ",
        spec.ShortDebugString());
  }
  if (spec.version() != kVersion) {
    return tensorflow::errors::InvalidArgument(
        "ArrayVariableStoreSpec.version (", spec.version(),
        ") does not match the binary (", kVersion, ")");
  }
  if (spec.alignment_bytes() != internal::kAlignmentBytes) {
    return tensorflow::errors::InvalidArgument(
        "ArrayVariableStoreSpec.alignment_bytes (", spec.alignment_bytes(),
        ") does not match the binary (", internal::kAlignmentBytes, ")");
  }

  // TODO(googleuser): It should be possible to correct an endian-ness mismatch.
  // A rough outline is:
  // * VariableStore::Lookup() takes an additional argument set to sizeof(T).
  // * Capture sizeof(T) and write it into the VariableSpec.
  // * Detect endian mismatch and byte-swap variables with multi-byte types.
  if (spec.is_little_endian() != tensorflow::port::kLittleEndian) {
    return tensorflow::errors::InvalidArgument(
        "ArrayVariableStoreSpec.is_little_endian (", spec.is_little_endian(),
        ") does not match the binary (", tensorflow::port::kLittleEndian, ")");
  }

  for (const VariableSpec &variable_spec : spec.variable()) {
    // When the proto parser encounters an unknown enumerator on the wire, it
    // replaces it with the default value (i.e., FORMAT_UNKNOWN).  Therefore,
    // VariableSpec.format() will always return a valid enumerator.
    DCHECK(VariableSpec::Format_IsValid(variable_spec.format()));
    if (variable_spec.format() == VariableSpec::FORMAT_UNKNOWN) {
      return tensorflow::errors::InvalidArgument(
          "Unknown variable format: ", variable_spec.ShortDebugString());
    }

    if (variable_spec.format() == VariableSpec::FORMAT_FLAT &&
        variable_spec.num_views() != 1) {
      return tensorflow::errors::InvalidArgument(
          "Flat variables must have 1 view: ",
          variable_spec.ShortDebugString());
    }
  }

  // Build into a temp mapping to avoid modification on error.
  std::unique_ptr<std::map<Key, Value>> new_variables(
      new std::map<Key, Value>());

  // Slice sub-arrays off of the main byte array.
  const char *base = data.data();
  const char *const end = base + data.size();
  for (const VariableSpec &variable_spec : spec.variable()) {
    const size_t num_views = variable_spec.num_views();
    const size_t view_size = variable_spec.view_size();
    const size_t area_size = ComputeAlignedAreaSize(num_views, view_size);

    // Compare sizes instead of computing |base| + |area_size|: forming a
    // pointer past the end of the underlying array is undefined behavior in
    // C++, and |area_size| comes from an untrusted proto.
    if (area_size > static_cast<size_t>(end - base)) {
      return tensorflow::errors::InvalidArgument(
          "Variable would overrun main byte array: ",
          variable_spec.ShortDebugString());
    }

    AlignedView view;
    TF_RETURN_IF_ERROR(view.Reset(base, area_size));
    base += area_size;  // remove claimed slice

    // Set dimensions from the spec.
    std::vector<size_t> dimensions(variable_spec.dimension().begin(),
                                   variable_spec.dimension().end());
    Value value(std::move(dimensions), AlignedArea());
    AlignedArea &area = value.second;
    TF_RETURN_IF_ERROR(area.Reset(view, num_views, view_size));

    // Currently, blocked variables are meant for fast inference algorithms,
    // which do not tolerate padding.  Raise errors if there is padding.
    if (variable_spec.format() ==
        VariableSpec::FORMAT_COLUMN_BLOCKED_ROW_MAJOR_MATRIX) {
      size_t padding = variable_spec.view_size() % internal::kAlignmentBytes;
      if (padding != 0) {
        return tensorflow::errors::Internal(
            "Currently, fast matrix-vector operations do not support padded "
            "blocked matrices, but variable '",
            variable_spec.name(), "' has padding ", padding);
      }
    }

    const Key key(variable_spec.name(), variable_spec.format());
    // Move |value| into the map; it is not used again in this iteration.
    if (!new_variables->emplace(key, std::move(value)).second) {
      return tensorflow::errors::InvalidArgument(
          "Duplicate variable: ", variable_spec.ShortDebugString());
    }
  }

  if (base != end) {
    return tensorflow::errors::InvalidArgument(
        "Variables do not completely cover main byte array: ", end - base,
        " bytes remaining");
  }

  // Success; make modifications.
  variables_ = std::move(new_variables);
  return tensorflow::Status::OK();
}
// Looks up the variable identified by |name| and |format|.  On success, copies
// its dimensions and aligned area into |dimensions| and |area| and returns OK.
// Fails if the store has not been initialized via Reset(), or if no matching
// variable exists.
tensorflow::Status ArrayVariableStore::Lookup(const string &name,
                                              VariableSpec::Format format,
                                              std::vector<size_t> *dimensions,
                                              AlignedArea *area) {
  if (variables_ == nullptr) {
    return tensorflow::errors::FailedPrecondition(
        "ArrayVariableStore not initialized");
  }

  const auto entry = variables_->find(Key(name, format));
  if (entry == variables_->end()) {
    return tensorflow::errors::NotFound(
        "ArrayVariableStore has no variable with name '", name, "' and format ",
        VariableSpec::Format_Name(format));
  }

  // Success; make modifications.
  *dimensions = entry->second.first;
  *area = entry->second.second;
  return tensorflow::Status::OK();
}
// Releases all variable mappings held by this store.  Fails if the store was
// never initialized via Reset().
tensorflow::Status ArrayVariableStore::Close() {
  if (variables_ == nullptr) {
    return tensorflow::errors::FailedPrecondition(
        "ArrayVariableStore not initialized");
  }

  variables_.reset();  // discard the mapping and return to uninitialized state
  return tensorflow::Status::OK();
}
} // namespace runtime
} // namespace dragnn
} // namespace syntaxnet
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment