"""TensorFlow ops for directed graphs."""

import tensorflow as tf

from syntaxnet.util import check


def ArcPotentialsFromTokens(source_tokens, target_tokens, weights):
  r"""Returns arc potentials computed from token activations and weights.

  For each batch of source and target token activations, computes a scalar
  potential for each arc as the 3-way product between the activation vectors of
  the source and target of the arc and the |weights|.  Specifically,

    arc[b,s,t] =
        \sum_{i,j} source_tokens[b,s,i] * weights[i,j] * target_tokens[b,t,j]

  Note that the token activations can be extended with bias terms to implement a
  "biaffine" model (Dozat and Manning, 2017).

  Args:
    source_tokens: [B,N,S] tensor of batched activations for the source token in
                   each arc.
    target_tokens: [B,N,T] tensor of batched activations for the target token in
                   each arc.
    weights: [S,T] matrix of weights.

    B,N may be statically-unknown, but S,T must be statically-known.  The dtype
    of all arguments must be compatible.

  Returns:
    [B,N,N] tensor A of arc potentials where A_{b,s,t} is the potential of the
    arc from s to t in batch element b.  The dtype of A is the same as that of
    the arguments.  Note that the diagonal entries (i.e., where s==t) represent
    self-loops and may not be meaningful.
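
  Example (an illustrative sketch; all shapes and values are hypothetical):

    source_tokens = tf.random_normal([2, 5, 8])  # B=2, N=5, S=8
    target_tokens = tf.random_normal([2, 5, 6])  # B=2, N=5, T=6
    weights = tf.random_normal([8, 6])           # [S,T]
    arcs = ArcPotentialsFromTokens(source_tokens, target_tokens, weights)
    # arcs has shape [2,5,5]; arcs[b,s,t] scores the arc from s to t.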
  """
  # All arguments must have statically-known rank.
  check.Eq(source_tokens.get_shape().ndims, 3, 'source_tokens must be rank 3')
  check.Eq(target_tokens.get_shape().ndims, 3, 'target_tokens must be rank 3')
  check.Eq(weights.get_shape().ndims, 2, 'weights must be a matrix')

  # All activation dimensions must be statically-known.
  num_source_activations = weights.get_shape().as_list()[0]
  num_target_activations = weights.get_shape().as_list()[1]
  check.NotNone(num_source_activations, 'unknown source activation dimension')
  check.NotNone(num_target_activations, 'unknown target activation dimension')
  check.Eq(source_tokens.get_shape().as_list()[2], num_source_activations,
           'dimension mismatch between weights and source_tokens')
  check.Eq(target_tokens.get_shape().as_list()[2], num_target_activations,
           'dimension mismatch between weights and target_tokens')

  # All arguments must share the same type.
  check.Same([weights.dtype.base_dtype,
              source_tokens.dtype.base_dtype,
              target_tokens.dtype.base_dtype],
             'dtype mismatch')

  source_tokens_shape = tf.shape(source_tokens)
  target_tokens_shape = tf.shape(target_tokens)
  batch_size = source_tokens_shape[0]
  num_tokens = source_tokens_shape[1]
  with tf.control_dependencies([
      tf.assert_equal(batch_size, target_tokens_shape[0]),
      tf.assert_equal(num_tokens, target_tokens_shape[1])]):
    # Flatten out the batch dimension so we can use one big multiplication.
    targets_bnxt = tf.reshape(target_tokens, [-1, num_target_activations])

    # Matrices are row-major, so we arrange for the RHS argument of each matmul
    # to have its transpose flag set.  That way no copying is required to align
    # the rows of the LHS with the columns of the RHS.
    weights_targets_bnxs = tf.matmul(targets_bnxt, weights, transpose_b=True)

    # The next computation is over pairs of tokens within each batch element, so
    # restore the batch dimension.
    weights_targets_bxnxs = tf.reshape(
        weights_targets_bnxs, [batch_size, num_tokens, num_source_activations])

    # Note that this multiplication is repeated across the batch dimension,
    # instead of being one big multiplication as in the first matmul.  There
    # doesn't seem to be a way to arrange this as a single multiplication given
    # the pairwise nature of this computation.
    arcs_bxnxn = tf.matmul(source_tokens, weights_targets_bxnxs,
                           transpose_b=True)
    return arcs_bxnxn


def ArcSourcePotentialsFromTokens(tokens, weights):
  r"""Returns arc source potentials computed from tokens and weights.

  For each batch of token activations, computes a scalar potential for each arc
  as the product between the activations of the source token and the |weights|.
  Specifically,

    arc[b,s,:] = \sum_{i} weights[i] * tokens[b,s,i]

  Args:
    tokens: [B,N,S] tensor of batched activations for source tokens.
    weights: [S] vector of weights.

    B,N may be statically-unknown, but S must be statically-known.  The dtype of
    all arguments must be compatible.

  Returns:
    [B,N,N] tensor A of arc potentials as defined above.  The dtype of A is the
    same as that of the arguments.  Note that the diagonal entries (i.e., where
    s==t) represent self-loops and may not be meaningful.
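
  Example (an illustrative sketch; all shapes and values are hypothetical):

    tokens = tf.random_normal([2, 5, 8])  # B=2, N=5, S=8
    weights = tf.random_normal([8])       # [S]
    arcs = ArcSourcePotentialsFromTokens(tokens, weights)
    # arcs has shape [2,5,5]; each row arcs[b,s,:] is constant because the
    # potential depends only on the source token s.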
  """
  # All arguments must have statically-known rank.
  check.Eq(tokens.get_shape().ndims, 3, 'tokens must be rank 3')
  check.Eq(weights.get_shape().ndims, 1, 'weights must be a vector')

  # All activation dimensions must be statically-known.
  num_source_activations = weights.get_shape().as_list()[0]
  check.NotNone(num_source_activations, 'unknown source activation dimension')
  check.Eq(tokens.get_shape().as_list()[2], num_source_activations,
           'dimension mismatch between weights and tokens')

  # All arguments must share the same type.
  check.Same([weights.dtype.base_dtype,
              tokens.dtype.base_dtype],
             'dtype mismatch')

  tokens_shape = tf.shape(tokens)
  batch_size = tokens_shape[0]
  num_tokens = tokens_shape[1]

  # Flatten out the batch dimension so we can use one big matmul and a tile.
  tokens_bnxs = tf.reshape(tokens, [-1, num_source_activations])
  weights_sx1 = tf.expand_dims(weights, 1)
  sources_bnx1 = tf.matmul(tokens_bnxs, weights_sx1)
  sources_bnxn = tf.tile(sources_bnx1, [1, num_tokens])

  # Restore the batch dimension in the output.
  sources_bxnxn = tf.reshape(sources_bnxn, [batch_size, num_tokens, num_tokens])
  return sources_bxnxn


def RootPotentialsFromTokens(root, tokens, weights):
  r"""Returns root selection potentials computed from tokens and weights.

  For each batch of token activations, computes a scalar potential for each root
  selection as the 3-way product between the activations of the artificial root
  token, the token activations, and the |weights|.  Specifically,

    roots[b,r] = \sum_{i,j} root[i] * weights[i,j] * tokens[b,r,j]

  Args:
    root: [S] vector of activations for the artificial root token.
    tokens: [B,N,T] tensor of batched activations for the candidate root
            tokens.
    weights: [S,T] matrix of weights.

    B,N may be statically-unknown, but S,T must be statically-known.  The dtype
    of all arguments must be compatible.

  Returns:
    [B,N] matrix R of root-selection potentials as defined above.  The dtype of
    R is the same as that of the arguments.
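
  Example (an illustrative sketch; all shapes and values are hypothetical):

    root = tf.random_normal([8])          # [S]
    tokens = tf.random_normal([2, 5, 6])  # B=2, N=5, T=6
    weights = tf.random_normal([8, 6])    # [S,T]
    roots = RootPotentialsFromTokens(root, tokens, weights)
    # roots has shape [2,5]; roots[b,r] scores selecting token r as the root.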
  """
  # All arguments must have statically-known rank.
  check.Eq(root.get_shape().ndims, 1, 'root must be a vector')
  check.Eq(tokens.get_shape().ndims, 3, 'tokens must be rank 3')
  check.Eq(weights.get_shape().ndims, 2, 'weights must be a matrix')

  # All activation dimensions must be statically-known.
  num_source_activations = weights.get_shape().as_list()[0]
  num_target_activations = weights.get_shape().as_list()[1]
  check.NotNone(num_source_activations, 'unknown source activation dimension')
  check.NotNone(num_target_activations, 'unknown target activation dimension')
  check.Eq(root.get_shape().as_list()[0], num_source_activations,
           'dimension mismatch between weights and root')
  check.Eq(tokens.get_shape().as_list()[2], num_target_activations,
           'dimension mismatch between weights and tokens')

  # All arguments must share the same type.
  check.Same([weights.dtype.base_dtype,
              root.dtype.base_dtype,
              tokens.dtype.base_dtype],
             'dtype mismatch')

  root_1xs = tf.expand_dims(root, 0)

  tokens_shape = tf.shape(tokens)
  batch_size = tokens_shape[0]
  num_tokens = tokens_shape[1]

  # Flatten out the batch dimension so we can use a couple big matmuls.
  tokens_bnxt = tf.reshape(tokens, [-1, num_target_activations])
  weights_targets_bnxs = tf.matmul(tokens_bnxt, weights, transpose_b=True)
  roots_1xbn = tf.matmul(root_1xs, weights_targets_bnxs, transpose_b=True)

  # Restore the batch dimension in the output.
  roots_bxn = tf.reshape(roots_1xbn, [batch_size, num_tokens])
  return roots_bxn


def CombineArcAndRootPotentials(arcs, roots):
  """Combines arc and root potentials into a single set of potentials.

  Args:
    arcs: [B,N,N] tensor of batched arc potentials.
    roots: [B,N] matrix of batched root potentials.

  Returns:
    [B,N,N] tensor P of combined potentials where
      P_{b,s,t} = s == t ? roots[b,t] : arcs[b,s,t]
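
  Example (an illustrative sketch; all shapes and values are hypothetical):

    arcs = tf.random_normal([2, 5, 5])  # [B,N,N]
    roots = tf.random_normal([2, 5])    # [B,N]
    potentials = CombineArcAndRootPotentials(arcs, roots)
    # potentials has shape [2,5,5]; its diagonal holds the root potentials
    # and its off-diagonal entries hold the arc potentials.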
  """
  # All arguments must have statically-known rank.
  check.Eq(arcs.get_shape().ndims, 3, 'arcs must be rank 3')
  check.Eq(roots.get_shape().ndims, 2, 'roots must be a matrix')

  # All arguments must share the same type.
  dtype = arcs.dtype.base_dtype
  check.Same([dtype, roots.dtype.base_dtype], 'dtype mismatch')

  roots_shape = tf.shape(roots)
  arcs_shape = tf.shape(arcs)
  batch_size = roots_shape[0]
  num_tokens = roots_shape[1]
  with tf.control_dependencies([
      tf.assert_equal(batch_size, arcs_shape[0]),
      tf.assert_equal(num_tokens, arcs_shape[1]),
      tf.assert_equal(num_tokens, arcs_shape[2])]):
    return tf.matrix_set_diag(arcs, roots)


def LabelPotentialsFromTokens(tokens, weights):
  r"""Computes label potentials from tokens and weights.

  For each batch of token activations, computes a scalar potential for each
  label as the product between each token's activations and the |weights|.
  Specifically,

    labels[b,t,l] = \sum_{i} weights[l,i] * tokens[b,t,i]

  Args:
    tokens: [B,N,T] tensor of batched token activations.
    weights: [L,T] matrix of weights.

    B,N may be dynamic, but L,T must be static.  The dtype of all arguments must
    be compatible.

  Returns:
    [B,N,L] tensor of label potentials as defined above, with the same dtype as
    the arguments.
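
  Example (an illustrative sketch; all shapes and values are hypothetical):

    tokens = tf.random_normal([2, 5, 6])  # B=2, N=5, T=6
    weights = tf.random_normal([3, 6])    # L=3 labels, [L,T]
    labels = LabelPotentialsFromTokens(tokens, weights)
    # labels has shape [2,5,3]; labels[b,t,l] scores label l for token t.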
  """
  check.Eq(tokens.get_shape().ndims, 3, 'tokens must be rank 3')
  check.Eq(weights.get_shape().ndims, 2, 'weights must be a matrix')

  num_labels = weights.get_shape().as_list()[0]
  num_activations = weights.get_shape().as_list()[1]
  check.NotNone(num_labels, 'unknown number of labels')
  check.NotNone(num_activations, 'unknown activation dimension')
  check.Eq(tokens.get_shape().as_list()[2], num_activations,
           'activation mismatch between weights and tokens')
  tokens_shape = tf.shape(tokens)
  batch_size = tokens_shape[0]
  num_tokens = tokens_shape[1]

  check.Same([tokens.dtype.base_dtype,
              weights.dtype.base_dtype],
             'dtype mismatch')

  # Flatten out the batch dimension so we can use one big matmul().
  tokens_bnxt = tf.reshape(tokens, [-1, num_activations])
  labels_bnxl = tf.matmul(tokens_bnxt, weights, transpose_b=True)

  # Restore the batch dimension in the output.
  labels_bxnxl = tf.reshape(labels_bnxl, [batch_size, num_tokens, num_labels])
  return labels_bxnxl


def LabelPotentialsFromTokenPairs(sources, targets, weights):
  r"""Computes label potentials from source and target tokens and weights.

  For each aligned pair of source and target token activations, computes a
  scalar potential for each label on the arc from the source to the target.
  Specifically,

    labels[b,t,l] = \sum_{i,j} sources[b,t,i] * weights[l,i,j] * targets[b,t,j]

  Args:
    sources: [B,N,S] tensor of batched source token activations.
    targets: [B,N,T] tensor of batched target token activations.
    weights: [L,S,T] tensor of weights.

    B,N may be dynamic, but L,S,T must be static.  The dtype of all arguments
    must be compatible.

  Returns:
    [B,N,L] tensor of label potentials as defined above, with the same dtype as
    the arguments.
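
  Example (an illustrative sketch; all shapes and values are hypothetical):

    sources = tf.random_normal([2, 5, 8])  # B=2, N=5, S=8
    targets = tf.random_normal([2, 5, 6])  # B=2, N=5, T=6
    weights = tf.random_normal([3, 8, 6])  # L=3 labels, [L,S,T]
    labels = LabelPotentialsFromTokenPairs(sources, targets, weights)
    # labels has shape [2,5,3]; labels[b,t,l] scores label l on the arc
    # into token t.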
  """
  check.Eq(sources.get_shape().ndims, 3, 'sources must be rank 3')
  check.Eq(targets.get_shape().ndims, 3, 'targets must be rank 3')
  check.Eq(weights.get_shape().ndims, 3, 'weights must be rank 3')

  num_labels = weights.get_shape().as_list()[0]
  num_source_activations = weights.get_shape().as_list()[1]
  num_target_activations = weights.get_shape().as_list()[2]
  check.NotNone(num_labels, 'unknown number of labels')
  check.NotNone(num_source_activations, 'unknown source activation dimension')
  check.NotNone(num_target_activations, 'unknown target activation dimension')
  check.Eq(sources.get_shape().as_list()[2], num_source_activations,
           'activation mismatch between weights and source tokens')
  check.Eq(targets.get_shape().as_list()[2], num_target_activations,
           'activation mismatch between weights and target tokens')

  check.Same([sources.dtype.base_dtype,
              targets.dtype.base_dtype,
              weights.dtype.base_dtype],
             'dtype mismatch')

  sources_shape = tf.shape(sources)
  targets_shape = tf.shape(targets)
  batch_size = sources_shape[0]
  num_tokens = sources_shape[1]
  with tf.control_dependencies([tf.assert_equal(batch_size, targets_shape[0]),
                                tf.assert_equal(num_tokens, targets_shape[1])]):
    # For each token, we must compute a vector-3tensor-vector product.  There is
    # no op for this, but we can use reshape() and matmul() to compute it.

    # Reshape |weights| and |targets| so we can use a single matmul().
    weights_lsxt = tf.reshape(weights, [num_labels * num_source_activations,
                                        num_target_activations])
    targets_bnxt = tf.reshape(targets, [-1, num_target_activations])
    weights_targets_bnxls = tf.matmul(targets_bnxt, weights_lsxt,
                                      transpose_b=True)

    # Restore all dimensions.
    weights_targets_bxnxlxs = tf.reshape(
        weights_targets_bnxls,
        [batch_size, num_tokens, num_labels, num_source_activations])

    # Incorporate the source activations.  In this case, we perform a batched
    # matmul() between the trailing [L,S] matrices of the current result and the
    # trailing [S] vectors of the tokens.
    sources_bxnx1xs = tf.expand_dims(sources, 2)
    labels_bxnxlx1 = tf.matmul(weights_targets_bxnxlxs, sources_bxnx1xs,
                               transpose_b=True)
    labels_bxnxl = tf.squeeze(labels_bxnxlx1, [3])
    return labels_bxnxl